1 """Training-drift — pre-run probe that reads dlm's per-step loss curve.
2
3 Sister probe to :mod:`gradient_ghost`. Where ``gradient_ghost`` reads
4 the optimizer state at the end of training, ``training_drift`` reads
5 the loss curve *during* training: the rich signal about *how* the
6 adapter trained, not just what it produced.
7
8 Four metrics extracted from the curve:
9
10 - ``final_loss`` — last recorded step's loss.
11 - ``convergence_ratio`` — ``final_loss / initial_loss``. Lower is
12 better; a healthy adapter cuts loss by half or more.
13 - ``smoothness`` — ``1 - (var(Δloss) / var(loss))``. Values near 1
14 mean the curve descends smoothly; near 0 means each step's
15 change-in-loss is comparable to the loss range itself (spiky).
16 - ``instability_events`` — count of steps where
17 ``|Δloss| > 3 · rolling_std``. Spikes that survive the rolling
18 window are real — they correlate with silent adapter degradation.
19
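For intuition, a toy computation of the two curve-shape metrics (an
illustrative sketch; ``_compute_metrics`` below is the source of
truth)::

    import numpy as np

    losses = np.array([2.0, 1.5, 1.2, 1.0, 0.9])
    deltas = np.diff(losses)
    smoothness = 1.0 - deltas.var() / losses.var()  # ≈ 0.86: smooth
    ratio = losses[-1] / losses[0]                  # 0.45: converged
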
Verdict: PASS when all three hold (smoothness ≥ 0.7, instability_events
== 0, convergence_ratio ≤ 0.7); else WARN. The thresholds are
hand-tuned defaults; spec fields override them, as in the sketch below.

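A hedged override sketch (field names are the :class:`TrainingDriftSpec`
fields defined below; ``name`` is assumed to be inherited from
``ProbeSpec``, as its use in ``run`` suggests)::

    spec = TrainingDriftSpec(
        name="training-drift-strict",
        store_path="~/stores/my-adapter",
        assert_smoothness_gte=0.9,
        assert_convergence_ratio_lte=0.5,
    )
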
## Why no null calibration

Mirrors ``prompt_collapse`` and ``multi_turn_coherence_decay``: a null
adapter doesn't *train* — it has no loss curve. The null distribution
of "smoothness on a noise adapter" is undefined. Fixed-threshold
verdicts are the published path; users override per-spec.

## Log-format note

dlm writes one JSONL per run at
``<store_path>/logs/train-NNNNNN-YYYYMMDDTHHMMSS.jsonl``. Each line is
``{"type": "<banner|step|run_complete|...>", ...}``. This probe
filters for ``type == "step"`` records and reads ``step`` + ``loss``.
The sibling ``*.summary.json`` has run aggregates; we don't consume
it here — the curve is richer.
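
A hedged sketch of the line shapes involved (fields beyond
``type``/``step``/``loss`` are illustrative, not dlm's actual schema)::

    {"type": "banner", ...}
    {"type": "step", "step": 1, "loss": 2.3102}
    {"type": "step", "step": 2, "loss": 2.2871}
    {"type": "run_complete", ...}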
39 """
40
41 from __future__ import annotations
42
43 import json
44 import math
45 from pathlib import Path
46 from typing import ClassVar, Literal
47
48 import numpy as np
49 from pydantic import Field
50
51 from dlm_sway.core.errors import SwayError
52 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
53 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
54
55
56 class TrainingDriftError(SwayError):
57 """Raised when a training log file is structurally unparseable.
58
59 Distinct from :class:`MissingTrainingStateError`-style absences:
60 a missing logs directory is a SKIP (no .dlm context, or a
61 pre-training adapter), but a JSONL with corrupted lines or the
62 wrong shape is an ERROR — the user has training data but we
63 can't read it, which deserves a noisy failure.
64 """
65
66
67 class TrainingDriftSpec(ProbeSpec):
68 """Spec for ``kind: training_drift``."""
69
70 kind: Literal["training_drift"] = "training_drift"
71 store_path: str | None = None
72 """Path to a dlm store root (the directory containing
73 ``logs/`` and ``adapter/``). When ``None`` the probe SKIPs —
74 typically populated by the .dlm autogen path which knows the
75 store from the dlm_id resolution."""
76 min_steps: int = Field(default=10, ge=2)
77 """Skip when the curve has fewer steps than this. Short runs
78 (3-step early-stops, smoke tests) produce undefined fits;
79 surfacing those as PASS/FAIL would be misleading."""
80 rolling_window: int = Field(default=10, ge=2)
81 """Window for the rolling std used in spike detection. Larger
82 windows mean tighter spike detection at the cost of missing
83 rapid oscillations. 10 is a balance for typical 100–1000 step
84 runs."""
85 spike_sigma: float = Field(default=3.0, gt=0.0)
86 """A step is an instability event when ``|Δloss|`` exceeds this
87 many rolling standard deviations. 3σ is the conventional
88 "outlier" boundary; lower for more sensitivity, higher for less
89 noise tolerance."""
90 assert_smoothness_gte: float = Field(default=0.7, ge=0.0, le=1.0)
91 """Minimum smoothness for PASS."""
92 assert_convergence_ratio_lte: float = Field(default=0.7, gt=0.0)
93 """Maximum ``final_loss / initial_loss`` for PASS. A "well-
94 trained" run typically halves loss; permissive default tolerates
95 noisier data."""
96 assert_instability_events_lte: int = Field(default=0, ge=0)
97 """Maximum allowed spike count. Default 0 — any spike → WARN."""
98
99
100 class TrainingDriftProbe(Probe):
101 """The "did this adapter train smoothly?" pre-run probe."""
102
103 kind = "training_drift"
104 spec_cls = TrainingDriftSpec
105 category = "calibration"
106 needs_backend: ClassVar[bool] = False
107 # Pre-run probe: no model load, runs in <100ms, ideal for
108 # ``sway check`` first-pass before any heavy probe fires.
109 # Mirrors gradient_ghost's posture.
110
111 def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # No backend / sections / null_stats consumed.
        assert isinstance(spec, TrainingDriftSpec)

        if spec.store_path is None:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no store_path provided (no .dlm context)",
            )

        store_path = Path(spec.store_path).expanduser().resolve()
        logs_dir = store_path / "logs"
        if not logs_dir.is_dir():
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=f"no logs/ directory under {store_path}",
            )

        log_paths = sorted(logs_dir.glob("train-*.jsonl"))
        if not log_paths:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=f"no train-*.jsonl files under {logs_dir}",
            )

        # Concatenate steps across all runs in chronological order.
        # Resumed runs may produce duplicate step numbers — dedupe by
        # keeping the latest occurrence (the most recent run's value
        # for that step). Sorted glob already gives chronological
        # order via the timestamp suffix.
        try:
            steps_by_idx = _collect_steps(log_paths)
        except TrainingDriftError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
                evidence={"log_paths": [str(p) for p in log_paths]},
            )

        if len(steps_by_idx) < spec.min_steps:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"only {len(steps_by_idx)} step records "
                    f"(< min_steps={spec.min_steps}); "
                    f"curve too short to fit reliably"
                ),
                evidence={"num_steps": len(steps_by_idx)},
            )

        # Sorted (step, loss) pairs.
        ordered = sorted(steps_by_idx.items())
        steps = np.asarray([s for s, _ in ordered], dtype=np.int64)
        losses = np.asarray([loss for _, loss in ordered], dtype=np.float64)

        # All four metrics. Each fails gracefully (NaN-only loss
        # column, all-equal losses, etc.); degenerate inputs surface
        # through the metric values themselves rather than through
        # safe_finalize's critical-field guard (see the note on
        # ``critical_fields`` below).
        metrics = _compute_metrics(
            losses, rolling_window=spec.rolling_window, spike_sigma=spec.spike_sigma
        )

        verdict, score, message = _verdict_from_metrics(metrics, spec)

        # Bound the curve we ship in evidence: a 10k-step run shouldn't
        # explode the JSON report. Downsample uniformly to the cap.
        curve = _downsampled_curve(steps, losses, cap=512)

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=metrics["smoothness"],
            evidence={
                "final_loss": metrics["final_loss"],
                "initial_loss": metrics["initial_loss"],
                "convergence_ratio": metrics["convergence_ratio"],
                "smoothness": metrics["smoothness"],
                "instability_events": metrics["instability_events"],
                "num_steps": len(losses),
                "num_log_files": len(log_paths),
                "curve_sampled": curve,
                "thresholds": {
                    "smoothness_gte": spec.assert_smoothness_gte,
                    "convergence_ratio_lte": spec.assert_convergence_ratio_lte,
                    "instability_events_lte": spec.assert_instability_events_lte,
                },
                "weight": spec.weight,
            },
            message=message,
            critical_fields=(),
            # ``raw`` (smoothness) is bounded [0, 1] by construction:
            # ``_compute_metrics`` pins it to 0.0 even on degenerate
            # input (an all-non-finite loss column), so the critical-
            # field guard has nothing to catch here; enabling it would
            # only risk nulling out a valid score.
        )


# ---------------------------------------------------------------------------
# JSONL parsing
# ---------------------------------------------------------------------------


def _collect_steps(log_paths: list[Path]) -> dict[int, float]:
232 """Parse every JSONL, dedupe-by-step, return ``{step: loss}``.
233
234 Resumed runs append to a fresh JSONL but share step numbers with
235 the prior run. Dedupe-by-keep-latest matches dlm's own ``dlm
236 metrics`` semantics: the most recent run for a given step wins.
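
    For example (hypothetical loss values), if run A logs steps 1–3
    and a resumed run B re-logs steps 3–4, B's file sorts later and
    its value wins for step 3::

        run A  → {1: 2.0, 2: 1.8, 3: 1.6}
        run B  → {3: 1.5, 4: 1.4}
        merged → {1: 2.0, 2: 1.8, 3: 1.5, 4: 1.4}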
237 """
238 out: dict[int, float] = {}
239 for path in log_paths:
240 try:
241 with path.open("r", encoding="utf-8") as f:
242 for line_no, raw in enumerate(f, start=1):
243 raw = raw.strip()
244 if not raw:
245 continue
246 try:
247 rec = json.loads(raw)
248 except json.JSONDecodeError as exc:
249 # A trailing partial line from a crashed
250 # trainer is the typical cause. Skip it but
251 # surface as ERROR if EVERY line is broken
252 # (caller checks ``out`` emptiness).
253 if line_no == 1:
254 raise TrainingDriftError(
255 f"first line of {path.name} is not valid JSON: {exc}"
256 ) from exc
257 continue
258 if not isinstance(rec, dict):
259 continue
260 if rec.get("type") != "step":
261 continue
262 try:
263 step = int(rec["step"])
264 loss = float(rec["loss"])
265 except (KeyError, TypeError, ValueError):
266 continue
267 if not math.isfinite(loss):
268 # NaN loss is a real signal — record it as
269 # +inf so the spike detector flags the
270 # instability without crashing on np.log.
271 loss = math.inf
272 out[step] = loss
273 except OSError as exc:
274 raise TrainingDriftError(f"failed to read {path}: {exc}") from exc
275 return out
276
277
278 # ---------------------------------------------------------------------------
279 # Metric computation
280 # ---------------------------------------------------------------------------
281
282
283 def _compute_metrics(
284 losses: np.ndarray,
285 *,
286 rolling_window: int,
287 spike_sigma: float,
288 ) -> dict[str, float]:
289 """Compute the four headline metrics from the loss array.
290
291 All metrics are robust to NaN/inf: non-finite step losses are
292 replaced with the most-recent finite value before fitting (so a
293 single bad batch doesn't poison the curve), but their *positions*
294 are still counted as instability events.
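
    E.g. (toy numbers) a recorded curve ``[2.0, inf, 1.5]`` is
    forward-filled to ``[2.0, 2.0, 1.5]``, and the replaced position
    counts as one instability event.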
295 """
296 losses = losses.astype(np.float64, copy=True)
297 instability_from_nan = int(np.sum(~np.isfinite(losses)))
298
299 if instability_from_nan > 0:
300 # Forward-fill non-finite values from the previous finite
301 # entry so downstream stats don't blow up. The first entry
302 # MUST be finite (training records the first batch's loss
303 # before anything could go wrong); guard anyway.
304 finite_mask = np.isfinite(losses)
305 if not finite_mask.any():
306 # Pathological: every step recorded NaN. Return all-NaN
307 # metrics so the verdict path can surface the failure.
308 return {
309 "initial_loss": float("nan"),
310 "final_loss": float("nan"),
311 "convergence_ratio": float("nan"),
312 "smoothness": 0.0,
313 "instability_events": float(len(losses)),
314 }
315 last_good = float(losses[finite_mask][0])
316 for i, v in enumerate(losses):
317 if not math.isfinite(v):
318 losses[i] = last_good
319 else:
320 last_good = float(v)
321
322 initial_loss = float(losses[0])
323 final_loss = float(losses[-1])
324 convergence_ratio = float(final_loss / initial_loss) if initial_loss != 0.0 else float("inf")
325
326 deltas = np.diff(losses)
327 var_loss = float(losses.var())
328 var_delta = float(deltas.var()) if deltas.size > 0 else 0.0
329 if var_loss > 0.0: # noqa: SIM108 — branch comments are load-bearing
330 # Clip into [0, 1]: a curve where var(Δloss) > var(loss)
331 # implies the per-step jitter dominates the overall sweep —
332 # treat as fully-spiky (smoothness=0) rather than emit a
333 # negative number.
334 smoothness = max(0.0, 1.0 - var_delta / var_loss)
335 else:
336 # Identical losses across every step: the run never
337 # progressed. Smoothness is formally 1.0 (perfectly flat),
338 # but that's misleading — surface as 0.0 so the verdict path
339 # can flag it.
340 smoothness = 0.0
341
342 instability_events = _count_spikes(deltas, window=rolling_window, sigma=spike_sigma)
343 instability_events += instability_from_nan
344
345 return {
346 "initial_loss": initial_loss,
347 "final_loss": final_loss,
348 "convergence_ratio": convergence_ratio,
349 "smoothness": smoothness,
350 "instability_events": float(instability_events),
351 }
352
353
354 def _count_spikes(deltas: np.ndarray, *, window: int, sigma: float) -> int:
355 """Count loss-increase events that exceed a robust noise threshold.
356
357 The naive ``|Δloss| > sigma · rolling_std`` heuristic is broken for
358 exponential decay: deltas span orders of magnitude across training,
359 so within-window std stays tiny while absolute deltas are large —
360 every step trips the threshold. The semantically correct
361 "instability event" is a loss *increase*, not a faster-than-typical
362 decrease.
363
364 Heuristic:
365
366 1. Restrict to positive deltas (``delta > 0`` ⇒ loss went up). Loss
367 going down faster than usual isn't an instability — it's the
368 happy path.
369 2. For each positive delta, compare to the median absolute delta
370 in a centered window. Flag when ``delta > sigma · MAD``, where
371 MAD is the median absolute deviation (robust to outliers).
372 3. Short curves fall back to global MAD; constant-loss curves
373 (every delta zero) report 0 spikes.
374
375 The ``sigma`` parameter retains its semantic meaning ("how many
376 typical deltas does this exceed") but operates against the
377 median-absolute-delta scale rather than std-of-delta.
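
    Toy example (curve shorter than ``window``, so the global-median
    branch applies; numbers are illustrative)::

        deltas = [-0.2, -0.2, +1.0, -0.2]   # one upward loss jump
        median(|Δ|) = 0.2; at sigma=3, 1.0 > 0.6 → 1 spike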
378 """
379 if deltas.size == 0:
380 return 0
381
382 # Median absolute delta — the "typical movement scale" we judge
383 # spikes against. Median is robust to a few outliers; if the
384 # whole curve is flat, MAD is 0 and we report no spikes.
385 abs_deltas = np.abs(deltas)
386
387 if deltas.size < window:
388 baseline = float(np.median(abs_deltas))
389 if baseline == 0.0:
390 return 0
391 return int(np.sum((deltas > 0.0) & (deltas > sigma * baseline)))
392
393 spikes = 0
394 half = window // 2
395 for i in range(deltas.size):
396 if deltas[i] <= 0.0:
397 continue # Loss went down — not an instability event.
398 lo = max(0, i - half)
399 hi = min(deltas.size, i + half + 1)
400 window_slice = abs_deltas[lo:hi]
401 # Exclude the point itself so a real outlier doesn't inflate
402 # the baseline it's measured against.
403 if window_slice.size > 1:
404 mask = np.ones(window_slice.size, dtype=bool)
405 mask[i - lo] = False
406 window_slice = window_slice[mask]
407 baseline = float(np.median(window_slice))
408 if baseline == 0.0:
409 # Surrounding deltas are all zero — any nonzero positive
410 # delta is an instability event by construction.
411 spikes += 1
412 continue
413 if float(deltas[i]) > sigma * baseline:
414 spikes += 1
415 return spikes
416
417
418 # ---------------------------------------------------------------------------
419 # Verdict mapping
420 # ---------------------------------------------------------------------------
421
422
423 def _verdict_from_metrics(
424 metrics: dict[str, float], spec: TrainingDriftSpec
425 ) -> tuple[Verdict, float, str]:
426 """Map the four metrics to a (verdict, score, message)."""
427 smooth = metrics["smoothness"]
428 conv = metrics["convergence_ratio"]
429 instability = int(metrics["instability_events"])
430
431 smoothness_pass = smooth >= spec.assert_smoothness_gte
432 convergence_pass = conv <= spec.assert_convergence_ratio_lte
433 instability_pass = instability <= spec.assert_instability_events_lte
434
435 failures: list[str] = []
436 if not smoothness_pass:
437 failures.append(f"smoothness={smooth:.2f} < {spec.assert_smoothness_gte}")
438 if not convergence_pass:
439 failures.append(f"convergence_ratio={conv:.2f} > {spec.assert_convergence_ratio_lte}")
440 if not instability_pass:
441 failures.append(f"instability_events={instability} > {spec.assert_instability_events_lte}")
442
443 headline = (
444 f"smoothness={smooth:.2f}, convergence_ratio={conv:.2f}, "
445 f"instability_events={instability}, final_loss={metrics['final_loss']:.3f}"
446 )
447
448 if not failures:
449 # All three thresholds clear → PASS with a normalized score
450 # blending the three signals (smoothness contributes most;
451 # convergence and instability are guard rails).
452 score = float(min(1.0, max(0.0, smooth)))
453 return Verdict.PASS, score, headline
454
455 # Score: a continuous blend of how far we are from each threshold,
456 # so the report can rank "borderline warn" against "actively bad."
457 # Doesn't influence the verdict — that's already FAIL/WARN.
458 score = float(min(1.0, max(0.0, smooth * 0.5)))
459 return Verdict.WARN, score, f"{headline}; warnings: {'; '.join(failures)}"
460
461
462 # ---------------------------------------------------------------------------
463 # Curve downsampling for evidence
464 # ---------------------------------------------------------------------------
465
466
467 def _downsampled_curve(
468 steps: np.ndarray, losses: np.ndarray, *, cap: int
469 ) -> list[tuple[int, float]]:
470 """Uniform-stride downsample so a 10k-step run still fits in the JSON report.
471
472 Always preserves the first and last point so initial/final loss
473 are visible regardless of cap. For curves shorter than the cap,
474 returns the full series unchanged.
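
    E.g. ``n=10_000, cap=512`` → ``stride = ceil(10_000 / 512) = 20``,
    giving 500 strided points plus the appended final index, i.e. 501
    entries (one over the cap; see the comment below).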
475 """
476 n = int(len(losses))
477 if n <= cap:
478 return [(int(s), float(loss)) for s, loss in zip(steps, losses, strict=True)]
479 # Stride to keep at most ``cap`` points: stride = ceil(n / cap).
480 # The +1 accounts for always-appending the final point even when
481 # ``stride * (cap - 1) < n - 1``.
482 stride = max(1, (n + cap - 1) // cap)
483 idx = list(range(0, n, stride))
484 if idx[-1] != n - 1:
485 idx.append(n - 1)
486 return [(int(steps[i]), float(losses[i])) for i in idx]