Python · 7601 bytes Raw Blame History
1 """`ReplayStore` — high-level facade over `corpus.zst` + `index.json`.
2
3 Binds the low-level primitives (`corpus.append_snapshot`,
4 `index.load_index`, `sampler.sample`, `eviction.evict_until`) to a
5 concrete store path so callers don't juggle file paths themselves. The
6 store-level exclusive lock must be held for mutating operations —
7 this module doesn't acquire it, to avoid fighting the outer
8 training-run lifecycle.
9
10 Also provides `sample_rows()` — the glue that feeds
11 `build_dataset(..., replay_rows=...)` without the caller having to
12 understand snapshot → row shape herself.
13 """
14
15 from __future__ import annotations
16
17 from dataclasses import dataclass
18 from pathlib import Path
19 from typing import TYPE_CHECKING, Any
20
21 from dlm.replay.corpus import append_snapshot, iter_snapshots
22 from dlm.replay.index import load_index, save_index
23 from dlm.replay.models import IndexEntry, SectionSnapshot
24
25 if TYPE_CHECKING:
26 import random
27 from datetime import datetime
28
29 from dlm.replay.sampler import Scheme
30
# One flat training-example dict — the shape `dlm.data.build_dataset`
# accepts via `replay_rows=` (text / messages / prompt-chosen-rejected keys,
# plus `_dlm_section_id`; see `_snapshot_to_rows` below).
Row = dict[str, Any]
32
33
@dataclass(frozen=True)
class ReplayStore:
    """Facade bound to one store's `replay/` subdir.

    Construct via `ReplayStore.at(store_path.replay_corpus,
    store_path.replay_index)` — the path pair is kept explicit so the
    `StorePath` accessor remains the single source of truth for
    filesystem layout.
    """

    corpus_path: Path
    index_path: Path

    @classmethod
    def at(cls, corpus_path: Path, index_path: Path) -> ReplayStore:
        """Bind to the path pair, creating any missing parent directories.

        Both parents are created — they are normally the same `replay/`
        dir, but nothing in this API forces that, so creating only the
        corpus parent would let the first `save()` fail on a missing
        index directory.
        """
        corpus_path.parent.mkdir(parents=True, exist_ok=True)
        index_path.parent.mkdir(parents=True, exist_ok=True)
        return cls(corpus_path=corpus_path, index_path=index_path)

    # --- index ---------------------------------------------------------------

    def load(self) -> list[IndexEntry]:
        """Read and return the current index entries."""
        return load_index(self.index_path)

    def save(self, entries: list[IndexEntry]) -> None:
        """Persist `entries` as the new index, replacing the old one."""
        save_index(self.index_path, entries)

    # --- corpus --------------------------------------------------------------

    def append(self, snapshot: SectionSnapshot) -> IndexEntry:
        """Append one snapshot, persist an updated index, return its entry.

        Index save happens on every append so a crash mid-training
        leaves the corpus + index consistent.

        **Performance:** each call does a full
        `load_index → append → save_index` cycle, which is O(n) in the
        existing index size. Fine for the one-shot append the trainer
        makes after each training cycle; **not** fine for loops like
        corpus imports or recovery flows. Use `append_many` whenever
        you have more than a handful of snapshots to add — the batch
        variant saves the index exactly once.
        """
        entry = append_snapshot(self.corpus_path, snapshot)
        self.save([*self.load(), entry])
        return entry

    def append_many(self, snapshots: list[SectionSnapshot]) -> list[IndexEntry]:
        """Batch variant of `append`: one index save at the end.

        An empty batch is a true no-op — no corpus write and no
        redundant index rewrite.
        """
        if not snapshots:
            return []
        existing = self.load()
        new_entries = [append_snapshot(self.corpus_path, s) for s in snapshots]
        self.save([*existing, *new_entries])
        return new_entries

    # --- sampling → rows -----------------------------------------------------

    def sample_rows(
        self,
        *,
        k: int,
        now: datetime,
        rng: random.Random,
        scheme: Scheme = "recency",
    ) -> list[Row]:
        """Sample `k` snapshots and expand each to `sections_to_rows` dicts.

        A single INSTRUCTION snapshot can fan out to multiple rows (one
        per Q/A pair); same for PREFERENCE. The returned list is
        already flat — plug directly into
        `dlm.data.build_dataset(..., replay_rows=...)`.

        Each row's `_dlm_section_id` is prefixed with `replay:` and
        suffixed with the snapshot's `last_seen_at` timestamp. This
        prevents a rehydrated replay section from colliding with the
        same content in the current document under the splitter's
        `(seed, id, sub_index)` hash.
        """
        from dlm.replay.sampler import sample

        entries = self.load()
        picked = sample(entries, k=k, now=now, rng=rng, scheme=scheme)
        rows: list[Row] = []
        # Stream straight off the decoder — no need to materialize the
        # snapshot list first.
        for snap in iter_snapshots(self.corpus_path, picked):
            rows.extend(_snapshot_to_rows(snap))
        return rows

    def sample_preference_rows(
        self,
        *,
        k: int,
        now: datetime,
        rng: random.Random,
        include_auto_mined: bool = True,
        scheme: Scheme = "recency",
    ) -> list[Row]:
        """Sample `k` *preference* snapshots; emit DPO-shaped rows.

        Mirrors `sample_rows` but pre-filters the index to
        preference-only entries before the weighted-reservoir draw.
        Falls back to an empty list when the corpus has no preference
        snapshots — callers at DPO-time decide whether zero replay is
        acceptable or not.

        `IndexEntry` doesn't carry `section_type` today, so we decode
        snapshots to partition. For the corpus sizes DLM realistically
        stores (<1k sections after eviction) the full decode is
        negligible compared to the training step itself.
        """
        from dlm.replay.sampler import sample

        entries = self.load()
        if not entries:
            return []

        snapshots = list(iter_snapshots(self.corpus_path, entries))
        preference_entries: list[IndexEntry] = []
        by_section_id: dict[str, SectionSnapshot] = {}
        for entry, snap in zip(entries, snapshots, strict=True):
            if snap.section_type != "preference":
                continue
            if not include_auto_mined and snap.auto_mined:
                continue
            preference_entries.append(entry)
            # NOTE(review): keyed by section_id, so if the index ever
            # held two snapshots of the same section the later decode
            # would silently win — confirm section_id stays unique
            # after eviction.
            by_section_id[entry.section_id] = snap
        if not preference_entries:
            return []

        picked = sample(preference_entries, k=k, now=now, rng=rng, scheme=scheme)
        rows: list[Row] = []
        for entry in picked:
            snap = by_section_id[entry.section_id]
            rows.extend(_snapshot_to_rows(snap))
        return rows
167
168
169 def _snapshot_to_rows(snap: SectionSnapshot) -> list[Row]:
170 """Expand one snapshot to its row-shape list.
171
172 Mirrors `dlm.data.sections_to_rows._section_to_rows` but emits a
173 replay-namespaced `_dlm_section_id` so replay rows don't collide
174 with current-document rows of identical content.
175 """
176 replay_sid = f"replay:{snap.section_id}:{snap.last_seen_at.isoformat()}"
177
178 if snap.section_type == "prose":
179 text = snap.content.strip()
180 if not text:
181 return []
182 return [{"text": text, "_dlm_section_id": replay_sid}]
183
184 if snap.section_type == "instruction":
185 from dlm.data.instruction_parser import parse_instruction_body
186
187 pairs = parse_instruction_body(snap.content, section_id=snap.section_id)
188 return [
189 {
190 "messages": [
191 {"role": "user", "content": pair.question},
192 {"role": "assistant", "content": pair.answer},
193 ],
194 "_dlm_section_id": replay_sid,
195 }
196 for pair in pairs
197 ]
198
199 # preference
200 from dlm.data.preference_parser import parse_preference_body
201
202 triples = parse_preference_body(snap.content, section_id=snap.section_id)
203 return [
204 {
205 "prompt": t.prompt,
206 "chosen": t.chosen,
207 "rejected": t.rejected,
208 "_dlm_section_id": replay_sid,
209 }
210 for t in triples
211 ]