1 """`TrainingSummary` — post-run report serialized to `logs/train-*.summary.json`.
2
3 One summary per training run. Captures "how did it go?" in a compact
4 form that's cheap to read from the CLI (no torch / HF imports needed)
5 and human-inspectable. The manifest's `training_runs` list links out
6 to the summary file by path so `dlm show` can load details on demand.
7 """
8
9 from __future__ import annotations
10
11 import json
12 from collections.abc import Iterable
13 from dataclasses import dataclass
14 from pathlib import Path
15
16 from pydantic import BaseModel, ConfigDict, Field
17
18 from dlm.io.atomic import write_text
19
20
class ProbeOutput(BaseModel):
    """One probe's prompt + generated response + optional reference."""

    # Unknown keys in the JSON are rejected; instances are immutable
    # after construction.
    model_config = ConfigDict(extra="forbid", frozen=True)

    # Prompt text that was fed to the model.
    prompt: str
    # Text the model generated for `prompt`.
    response: str
    # Gold/expected answer when one exists; None when the probe has
    # no reference to compare against.
    reference: str | None = None
    # Presumably ties the probe back to a source section; "" when
    # unset — TODO confirm against the probe producer.
    section_id: str = ""
30
31
class SourceProvenanceRecord(BaseModel):
    """Serialized per-directive ingestion bookkeeping.

    Mirrors `dlm.directives.SourceProvenance` as a pydantic model so
    the training summary JSON is self-describing. One record per
    `training.sources` entry; `file_count == 0` indicates a directive
    that matched nothing (worth flagging in the CLI).
    """

    # Unknown keys rejected; instances immutable after construction.
    model_config = ConfigDict(extra="forbid", frozen=True)

    # The directive's path/pattern as written in the frontmatter.
    path: str
    # Number of files the directive matched and ingested (>= 0).
    file_count: int = Field(0, ge=0)
    # Total bytes ingested across the matched files (>= 0).
    total_bytes: int = Field(0, ge=0)
    # Counts of files skipped per reason: detected as binary,
    # undecodable text encoding, or over the size limit.
    skipped_binary: int = Field(0, ge=0)
    skipped_encoding: int = Field(0, ge=0)
    skipped_over_size: int = Field(0, ge=0)
49
50
class TrainingSummary(BaseModel):
    """Canonical post-run report."""

    # Unknown keys rejected; instances immutable after construction.
    model_config = ConfigDict(extra="forbid", frozen=True)

    # 1-based run identifier; the manifest's `training_runs` list
    # links to this summary by path (see module docstring).
    run_id: int = Field(..., ge=1)
    # 1-based version of the adapter this run produced.
    adapter_version: int = Field(..., ge=1)
    # RNG seed the run was launched with.
    seed: int
    # Headline metrics. None means "not measured" (e.g. no validation
    # or retention evaluation ran), not zero.
    final_train_loss: float | None = None
    final_val_loss: float | None = None
    final_val_perplexity: float | None = None
    retention_loss: float | None = None
    retention_loss_delta: float | None = None
    # Mixed-mode breakdown: when the training rows mix CPT prose and
    # SFT instruction rows, the aggregate loss hides which side is
    # driving movement. These fields split the final train/val and
    # retention metrics by row mode. All optional — single-mode runs
    # leave them None.
    train_loss_cpt: float | None = None
    train_loss_sft: float | None = None
    val_loss_cpt: float | None = None
    val_loss_sft: float | None = None
    retention_cpt: float | None = None
    retention_sft: float | None = None
    # Qualitative probe transcripts captured after training.
    probes: list[ProbeOutput] = Field(default_factory=list)
    # True when training halted before the full step budget —
    # presumably via an early-stopping criterion; TODO confirm.
    early_stopped: bool = False
    # Optimizer steps actually executed and wall-clock duration.
    steps: int = Field(0, ge=0)
    duration_seconds: float = Field(0.0, ge=0.0)
    # Free-form label for how reproducible the run is; defaults to
    # the weakest class.
    determinism_class: str = "best_effort"
    # Per-directive ingestion provenance. Empty when no
    # `training.sources` declared. Order matches the frontmatter so
    # CLI formatters can line up rows with source entries.
    source_directives: list[SourceProvenanceRecord] = Field(default_factory=list)
84
85
@dataclass(frozen=True)
class LossByMode:
    """Mean loss split by row mode (`cpt` prose vs `sft` instruction).

    `None` where the corresponding row count was zero — the caller
    stores that verbatim in the summary so ``None`` is honest "we had
    no rows of that mode" rather than a zeroed-out number.
    """

    # Mean loss over CPT (prose) rows, or None if there were none.
    cpt: float | None
    # Mean loss over SFT (instruction) rows, or None if there were none.
    sft: float | None
97
98
def split_loss_by_mode(rows: Iterable[tuple[float, str]]) -> LossByMode:
    """Average `(loss, mode)` pairs grouped by mode.

    `mode` is expected to be one of `"cpt"` or `"sft"`; other strings
    are ignored so the caller can pass a single stream containing
    preference/other rows without pre-filtering.
    """
    # Running (sum, count) accumulators keyed by recognized mode;
    # anything outside these two keys is silently dropped.
    totals: dict[str, float] = {"cpt": 0.0, "sft": 0.0}
    counts: dict[str, int] = {"cpt": 0, "sft": 0}
    for loss, mode in rows:
        if mode in totals:
            totals[mode] += loss
            counts[mode] += 1

    def mean_or_none(key: str) -> float | None:
        # None — not 0.0 — when no rows of this mode were seen.
        return totals[key] / counts[key] if counts[key] else None

    return LossByMode(cpt=mean_or_none("cpt"), sft=mean_or_none("sft"))
117
118
def save_summary(path: Path, summary: TrainingSummary) -> None:
    """Atomically serialize `summary` as pretty JSON.

    Uses the atomic-write helper so a concurrent CLI reader never sees
    a torn file.
    """
    # Pretty, key-sorted JSON with a trailing newline so the file is
    # diff-friendly and stable across runs.
    as_dict = summary.model_dump(mode="json")
    text = f"{json.dumps(as_dict, sort_keys=True, indent=2)}\n"
    write_text(path, text)
128
129
def load_summary(path: Path) -> TrainingSummary:
    """Inverse of `save_summary`; raises `pydantic.ValidationError` on drift."""
    # Read the file as UTF-8 text, then let pydantic enforce the schema.
    with path.open(encoding="utf-8") as handle:
        payload = json.load(handle)
    return TrainingSummary.model_validate(payload)
135
136
def summary_path_for(logs_dir: Path, run_id: int, started_iso: str) -> Path:
    """Match the JSONL log file naming so pairs are easy to eyeball.

    `train-<run_id>-<ts>.summary.json` sits next to the `.jsonl` log
    with the same stem.
    """
    # Drop the ISO-8601 separators (":" and "-") in one translate pass.
    compact_ts = started_iso.translate(str.maketrans("", "", ":-"))
    return logs_dir / f"train-{run_id:06d}-{compact_ts}.summary.json"