1 """Retention metric — eval on a fixed slice of the replay corpus.
2
3 At every eval step the trainer computes val loss on the current
4 document's held-out split. That tells us how well the model fits
5 recent content. It does NOT tell us whether the model still
6 remembers what it learned two retrains ago — the "catastrophic
7 forgetting" failure mode the replay corpus was designed to prevent.
8
9 This module picks a **stable 5% slice** of the replay corpus at run
10 start, reserves it as eval-only (the trainer never sees it), and
11 reports loss on that slice alongside val loss. A retention_delta >>
12 val_delta between runs is the forgetting signal the UI surfaces.
13
14 Design:
15 - `build_retention_slice(replay_store, *, frac, seed)` returns a list
16 of `IndexEntry` the caller can rehydrate into rows. Seed-stable:
17 same corpus + seed → same slice → same "held-out" across runs.
18 - The slice is disjoint from the training sample (the sampler draws
19 from the remainder of the corpus; trainer threads the retention
20 entries' ids to `exclude`).
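
Usage sketch (illustrative; `replay_store`, `sampler`, `model`,
`evaluate_loss`, and `run_seed` are hypothetical names, not part of
this module):

    entries = replay_store.index_entries()
    slice_ = build_retention_slice(entries, seed=run_seed)
    batch_rows = sampler.sample(entries, exclude=slice_.section_ids)
    retention_loss = evaluate_loss(model, slice_.entries)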
21 """
22
23 from __future__ import annotations
24
25 import hashlib
26 from dataclasses import dataclass
27
28 from dlm.eval.errors import RetentionSliceError
29 from dlm.replay.models import IndexEntry
30
31 _DEFAULT_FRAC = 0.05
32
33
34 @dataclass(frozen=True)
class RetentionSlice:
    """Seed-stable eval-only slice of the replay corpus."""

    entries: list[IndexEntry]
    seed: int
    frac: float

    @property
    def section_ids(self) -> set[str]:
        """The ids the trainer threads to the sampler's `exclude`."""
        return {e.section_id for e in self.entries}


def build_retention_slice(
    entries: list[IndexEntry],
    *,
    frac: float = _DEFAULT_FRAC,
    seed: int = 0,
) -> RetentionSlice:
    """Pick a `frac` fraction of `entries` to reserve as retention-only.

    Entries are sorted by a hash of `(seed, section_id)` and the first
    k are selected. Because the hash order is independent of corpus
    insertion order, growing the corpus perturbs membership only at
    the k-boundary rather than reshuffling the whole slice, and the
    selection is deterministic under the same seed.

    Raises `RetentionSliceError` on empty input or a frac outside (0, 1).
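
    Example (illustrative; assumes `IndexEntry` can be constructed from
    just a `section_id`, which may not match the real dataclass):

        entries = [IndexEntry(section_id=f"s{i}") for i in range(40)]
        a = build_retention_slice(entries, seed=7)
        b = build_retention_slice(entries, seed=7)
        assert a.section_ids == b.section_ids  # same seed, same slice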
    """
    if not 0.0 < frac < 1.0:
        raise RetentionSliceError(f"frac must be in (0, 1), got {frac!r}")
    if not entries:
        raise RetentionSliceError("cannot build retention slice from empty corpus")

    # Round up so small corpora still get a retention signal; the
    # max(1, ...) guard makes the at-least-one intent explicit.
    k = max(1, math.ceil(len(entries) * frac))
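    # e.g. 37 entries at the default frac=0.05: ceil(37 * 0.05) =
    # ceil(1.85) = 2 retention entries, instead of truncating to 1.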
    keyed = sorted(entries, key=lambda e: _retention_key(e.section_id, seed))
    picked = keyed[:k]
    return RetentionSlice(entries=picked, seed=seed, frac=frac)


def _retention_key(section_id: str, seed: int) -> str:
    """Stable sort key: hex SHA-256 digest of the seed and section id."""
    h = hashlib.sha256(f"{seed}\x00{section_id}".encode())
    return h.hexdigest()


def retention_delta(
    *,
    current_retention_loss: float | None,
    previous_retention_loss: float | None,
) -> float | None:
    """`current - previous`; None if either side is missing.

    Reported in the `TrainingSummary.retention_loss_delta` field. A
    positive delta means retention loss went UP — the model is
    forgetting. The magnitude relative to `final_val_loss` is the
    honest forgetting signal; a rising retention loss alongside falling
    val loss is the canonical catastrophic-forgetting fingerprint.
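
    Example (values chosen so the float arithmetic is exact):

        >>> retention_delta(current_retention_loss=2.5,
        ...                 previous_retention_loss=2.25)
        0.25
        >>> retention_delta(current_retention_loss=2.5,
        ...                 previous_retention_loss=None) is None
        True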
    """
    if current_retention_loss is None or previous_retention_loss is None:
        return None
    return current_retention_loss - previous_retention_loss