"""Retention metric — eval on a fixed slice of the replay corpus.

At every eval step the trainer computes val loss on the current
document's held-out split. That tells us how well the model fits
recent content. It does NOT tell us whether the model still
remembers what it learned two retrains ago — the "catastrophic
forgetting" failure mode the replay corpus was designed to prevent.

This module picks a **stable 5% slice** of the replay corpus at run
start, reserves it as eval-only (the trainer never sees it), and
reports loss on that slice alongside val loss. A retention_delta >>
val_delta between runs is the forgetting signal the UI surfaces.

Design:
- `build_retention_slice(entries, *, frac, seed)` returns a
  `RetentionSlice` whose `entries` the caller can rehydrate into
  rows. Seed-stable: same corpus + seed → same slice → same
  "held-out" across runs.
- The slice is disjoint from the training sample (the sampler draws
  from the remainder of the corpus; the trainer threads the retention
  entries' ids to `exclude`).
"""

from __future__ import annotations

import hashlib
import math
from dataclasses import dataclass

from dlm.eval.errors import RetentionSliceError
from dlm.replay.models import IndexEntry

_DEFAULT_FRAC = 0.05


@dataclass(frozen=True)
class RetentionSlice:
    """Seed-stable eval-only slice of the replay corpus."""

    entries: list[IndexEntry]
    seed: int
    frac: float

    @property
    def section_ids(self) -> set[str]:
        return {e.section_id for e in self.entries}


def build_retention_slice(
    entries: list[IndexEntry],
    *,
    frac: float = _DEFAULT_FRAC,
    seed: int = 0,
) -> RetentionSlice:
    """Pick a `frac` fraction of `entries` to reserve as retention-only.

    Entries are hashed against `(seed, section_id)` and the k entries
    with the smallest hashes are selected. This is stable — adding a
    section never reshuffles the relative order of existing entries,
    so corpus growth only churns the slice at the k-th boundary — and
    deterministic under the same seed.

    Raises `RetentionSliceError` on empty input or a frac outside (0, 1).
    """
    if not 0.0 < frac < 1.0:
        raise RetentionSliceError(f"frac must be in (0, 1), got {frac!r}")
    if not entries:
        raise RetentionSliceError("cannot build retention slice from empty corpus")

    # Always pick at least one; round up (math.ceil) so small corpora
    # still have a retention signal.
    k = max(1, math.ceil(len(entries) * frac))
    keyed = sorted(entries, key=lambda e: _retention_key(e.section_id, seed))
    picked = keyed[:k]
    return RetentionSlice(entries=picked, seed=seed, frac=frac)
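
# Hedged wiring sketch: how a caller might thread the slice into a run.
# `store`, `cfg`, `sampler.sample(...)`, and `evaluate(...)` are
# illustrative names implied by the module docstring, not real APIs in
# this package:
#
#     retention = build_retention_slice(store.entries(), seed=cfg.seed)
#     train_rows = sampler.sample(exclude=retention.section_ids)
#     retention_loss = evaluate(model, store.rehydrate(retention.entries))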


def _retention_key(section_id: str, seed: int) -> str:
    # sha256 rather than built-in hash(): Python salts str hashing per
    # process (PYTHONHASHSEED), so hash() would not be stable across runs.
    # The NUL separator keeps (seed, section_id) pairs unambiguous.
    h = hashlib.sha256(f"{seed}\x00{section_id}".encode())
    return h.hexdigest()


def retention_delta(
    *,
    current_retention_loss: float | None,
    previous_retention_loss: float | None,
) -> float | None:
    """`current - previous`; None if either side is missing.

    Reported in the `TrainingSummary.retention_loss_delta` field. A
    positive delta means retention loss went UP — the model is
    forgetting. Its magnitude relative to the run-over-run change in
    `final_val_loss` is the honest forgetting signal; rising retention
    loss alongside falling val loss is the canonical
    catastrophic-forgetting fingerprint.
    """
    if current_retention_loss is None or previous_retention_loss is None:
        return None
    return current_retention_loss - previous_retention_loss
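

if __name__ == "__main__":
    # Illustrative smoke test, not part of the trainer path. It leans on
    # the fact that build_retention_slice only reads `.section_id`, so a
    # local stand-in replaces a real IndexEntry (whose constructor may
    # differ); the `# type: ignore` markers acknowledge the substitution.
    from dataclasses import dataclass as _dc

    @_dc(frozen=True)
    class _FakeEntry:
        section_id: str

    corpus = [_FakeEntry(f"sec-{i:03d}") for i in range(200)]
    a = build_retention_slice(corpus, seed=7)  # type: ignore[arg-type]
    b = build_retention_slice(corpus, seed=7)  # type: ignore[arg-type]
    assert a.section_ids == b.section_ids  # seed-stable: same corpus, same slice

    # Growing the corpus churns the slice only at the k-th boundary; the
    # existing picks keep their relative hash order.
    grown = corpus + [_FakeEntry("sec-200")]
    c = build_retention_slice(grown, seed=7)  # type: ignore[arg-type]
    print(f"kept {len(a.section_ids & c.section_ids)}/{len(a.section_ids)} after growth")

    # Delta semantics: positive means retention loss rose (forgetting).
    print(retention_delta(current_retention_loss=2.41, previous_retention_loss=2.30))  # ~0.11
    print(retention_delta(current_retention_loss=2.41, previous_retention_loss=None))  # None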