| 1 |
"""`ReplayStore` — high-level facade over `corpus.zst` + `index.json`. |
| 2 |
|
| 3 |
Binds the low-level primitives (`corpus.append_snapshot`, |
| 4 |
`index.load_index`, `sampler.sample`, `eviction.evict_until`) to a |
| 5 |
concrete store path so callers don't juggle file paths themselves. The |
| 6 |
store-level exclusive lock must be held for mutating operations — |
| 7 |
this module doesn't acquire it, to avoid fighting the outer |
| 8 |
training-run lifecycle. |
| 9 |
|
| 10 |
Also provides `sample_rows()` — the glue that feeds |
| 11 |
`build_dataset(..., replay_rows=...)` without the caller having to |
| 12 |
understand snapshot → row shape herself. |
| 13 |
""" |
| 14 |
|
| 15 |
from __future__ import annotations |
| 16 |
|
| 17 |
from dataclasses import dataclass |
| 18 |
from pathlib import Path |
| 19 |
from typing import TYPE_CHECKING, Any |
| 20 |
|
| 21 |
from dlm.replay.corpus import append_snapshot, iter_snapshots |
| 22 |
from dlm.replay.index import load_index, save_index |
| 23 |
from dlm.replay.models import IndexEntry, SectionSnapshot |
| 24 |
|
| 25 |
if TYPE_CHECKING: |
| 26 |
import random |
| 27 |
from datetime import datetime |
| 28 |
|
| 29 |
from dlm.replay.sampler import Scheme |
| 30 |
|
| 31 |
Row = dict[str, Any] |
| 32 |
|
| 33 |
|
| 34 |
@dataclass(frozen=True)
class ReplayStore:
    """Facade bound to one store's `replay/` subdir.

    Construct via `ReplayStore.at(store_path.replay_corpus,
    store_path.replay_index)` — the path pair is kept explicit so the
    `StorePath` accessor remains the single source of truth for
    filesystem layout.

    NOTE: the store-level exclusive lock must already be held for any
    mutating operation; this class never acquires it itself (see the
    module docstring).
    """

    corpus_path: Path
    index_path: Path

    @classmethod
    def at(cls, corpus_path: Path, index_path: Path) -> ReplayStore:
        """Bind to a `(corpus, index)` path pair, creating parent dirs.

        Fix: previously only `corpus_path.parent` was created, so the
        first `save()` failed with a missing-directory error whenever
        `index_path` was configured to live in a different directory
        than the corpus. Both parents are now created up front.
        """
        corpus_path.parent.mkdir(parents=True, exist_ok=True)
        index_path.parent.mkdir(parents=True, exist_ok=True)
        return cls(corpus_path=corpus_path, index_path=index_path)

    # --- index ---------------------------------------------------------------

    def load(self) -> list[IndexEntry]:
        """Load and return all index entries for this store."""
        return load_index(self.index_path)

    def save(self, entries: list[IndexEntry]) -> None:
        """Persist `entries` as the complete index (full rewrite)."""
        save_index(self.index_path, entries)

    # --- corpus --------------------------------------------------------------

    def append(self, snapshot: SectionSnapshot) -> IndexEntry:
        """Append one snapshot, persist an updated index, return its entry.

        Index save happens on every append so a crash mid-training
        leaves the corpus + index consistent.

        **Performance:** each call does a full
        `load_index → append → save_index` cycle, which is O(n) in the
        existing index size. Fine for the one-shot append the trainer
        makes after each training cycle; **not** fine for loops like
        corpus imports or recovery flows. Use `append_many` whenever
        you have more than a handful of snapshots to add — the batch
        variant saves the index exactly once.
        """
        entry = append_snapshot(self.corpus_path, snapshot)
        self.save([*self.load(), entry])
        return entry

    def append_many(self, snapshots: list[SectionSnapshot]) -> list[IndexEntry]:
        """Batch variant of `append`: one index load + one save total."""
        existing = self.load()
        new_entries = [append_snapshot(self.corpus_path, s) for s in snapshots]
        self.save([*existing, *new_entries])
        return new_entries

    # --- sampling → rows -----------------------------------------------------

    def sample_rows(
        self,
        *,
        k: int,
        now: datetime,
        rng: random.Random,
        scheme: Scheme = "recency",
    ) -> list[Row]:
        """Sample `k` snapshots and expand each to `sections_to_rows` dicts.

        A single INSTRUCTION snapshot can fan out to multiple rows (one
        per Q/A pair); same for PREFERENCE. The returned list is
        already flat — plug directly into
        `dlm.data.build_dataset(..., replay_rows=...)`.

        Each row's `_dlm_section_id` is prefixed with `replay:` and
        suffixed with the snapshot's `last_seen_at` timestamp. This
        prevents a rehydrated replay section from colliding with the
        same content in the current document under the splitter's
        `(seed, id, sub_index)` hash.
        """
        # Imported lazily, mirroring the original module's choice to keep
        # sampler off the import path until sampling is actually requested.
        from dlm.replay.sampler import sample

        entries = self.load()
        picked = sample(entries, k=k, now=now, rng=rng, scheme=scheme)
        rows: list[Row] = []
        # Stream the decode — no need to materialize all snapshots first.
        for snap in iter_snapshots(self.corpus_path, picked):
            rows.extend(_snapshot_to_rows(snap))
        return rows

    def sample_preference_rows(
        self,
        *,
        k: int,
        now: datetime,
        rng: random.Random,
        include_auto_mined: bool = True,
        scheme: Scheme = "recency",
    ) -> list[Row]:
        """Sample `k` *preference* snapshots; emit DPO-shaped rows.

        Mirrors `sample_rows` but pre-filters the index to
        preference-only entries before the weighted-reservoir draw.
        Falls back to an empty list when the corpus has no preference
        snapshots — callers at DPO-time decide whether zero replay is
        acceptable or not.

        `IndexEntry` doesn't carry `section_type` today, so we decode
        snapshots to partition. For the corpus sizes DLM realistically
        stores (<1k sections after eviction) the full decode is
        negligible compared to the training step itself.
        """
        from dlm.replay.sampler import sample

        entries = self.load()
        if not entries:
            return []

        snapshots = list(iter_snapshots(self.corpus_path, entries))
        preference_entries: list[IndexEntry] = []
        by_section_id: dict[str, SectionSnapshot] = {}
        # strict=True: a corpus/index length mismatch is corruption and
        # should surface loudly rather than silently truncate.
        for entry, snap in zip(entries, snapshots, strict=True):
            if snap.section_type != "preference":
                continue
            if not include_auto_mined and snap.auto_mined:
                continue
            preference_entries.append(entry)
            by_section_id[entry.section_id] = snap
        if not preference_entries:
            return []

        picked = sample(preference_entries, k=k, now=now, rng=rng, scheme=scheme)
        rows: list[Row] = []
        for entry in picked:
            snap = by_section_id[entry.section_id]
            rows.extend(_snapshot_to_rows(snap))
        return rows
| 167 |
|
| 168 |
|
| 169 |
def _snapshot_to_rows(snap: SectionSnapshot) -> list[Row]: |
| 170 |
"""Expand one snapshot to its row-shape list. |
| 171 |
|
| 172 |
Mirrors `dlm.data.sections_to_rows._section_to_rows` but emits a |
| 173 |
replay-namespaced `_dlm_section_id` so replay rows don't collide |
| 174 |
with current-document rows of identical content. |
| 175 |
""" |
| 176 |
replay_sid = f"replay:{snap.section_id}:{snap.last_seen_at.isoformat()}" |
| 177 |
|
| 178 |
if snap.section_type == "prose": |
| 179 |
text = snap.content.strip() |
| 180 |
if not text: |
| 181 |
return [] |
| 182 |
return [{"text": text, "_dlm_section_id": replay_sid}] |
| 183 |
|
| 184 |
if snap.section_type == "instruction": |
| 185 |
from dlm.data.instruction_parser import parse_instruction_body |
| 186 |
|
| 187 |
pairs = parse_instruction_body(snap.content, section_id=snap.section_id) |
| 188 |
return [ |
| 189 |
{ |
| 190 |
"messages": [ |
| 191 |
{"role": "user", "content": pair.question}, |
| 192 |
{"role": "assistant", "content": pair.answer}, |
| 193 |
], |
| 194 |
"_dlm_section_id": replay_sid, |
| 195 |
} |
| 196 |
for pair in pairs |
| 197 |
] |
| 198 |
|
| 199 |
# preference |
| 200 |
from dlm.data.preference_parser import parse_preference_body |
| 201 |
|
| 202 |
triples = parse_preference_body(snap.content, section_id=snap.section_id) |
| 203 |
return [ |
| 204 |
{ |
| 205 |
"prompt": t.prompt, |
| 206 |
"chosen": t.chosen, |
| 207 |
"rejected": t.rejected, |
| 208 |
"_dlm_section_id": replay_sid, |
| 209 |
} |
| 210 |
for t in triples |
| 211 |
] |