| 1 | """Training-drift — pre-run probe that reads dlm's per-step loss curve. |
| 2 | |
| 3 | Sister probe to :mod:`gradient_ghost`. Where ``gradient_ghost`` reads |
| 4 | the optimizer state at the end of training, ``training_drift`` reads |
| 5 | the loss curve *during* training: the rich signal about *how* the |
| 6 | adapter trained, not just what it produced. |
| 7 | |
| 8 | Four metrics extracted from the curve: |
| 9 | |
| 10 | - ``final_loss`` — last recorded step's loss. |
| 11 | - ``convergence_ratio`` — ``final_loss / initial_loss``. Lower is |
| 12 | better; a healthy adapter cuts loss by half or more. |
| 13 | - ``smoothness`` — ``1 - (var(Δloss) / var(loss))``. Values near 1 |
| 14 | mean the curve descends smoothly; near 0 means each step's |
| 15 | change-in-loss is comparable to the loss range itself (spiky). |
- ``instability_events`` — count of steps where the loss *increases*
  by more than ``spike_sigma`` times the rolling median absolute
  Δloss (a robust baseline; see ``_count_spikes``). Spikes that
  survive that baseline are real — they correlate with silent
  adapter degradation.
| 19 | |
| 20 | Verdict: PASS when all three of (smoothness ≥ 0.7, instability_events |
| 21 | == 0, convergence_ratio ≤ 0.7); else WARN. The thresholds are |
| 22 | hand-tuned defaults; spec fields override them. |
| 23 | |
| 24 | ## Why no null calibration |
| 25 | |
| 26 | Mirrors ``prompt_collapse`` and ``multi_turn_coherence_decay``: a null |
| 27 | adapter doesn't *train* — it has no loss curve. The null distribution |
| 28 | of "smoothness on a noise adapter" is undefined. Fixed-threshold |
| 29 | verdicts are the published path; users override per-spec. |
| 30 | |
| 31 | ## Log-format note |
| 32 | |
| 33 | dlm writes one JSONL per run at |
| 34 | ``<store_path>/logs/train-NNNNNN-YYYYMMDDTHHMMSS.jsonl``. Each line is |
| 35 | ``{"type": "<banner|step|run_complete|...>", ...}``. This probe |
| 36 | filters for ``type == "step"`` records and reads ``step`` + ``loss``. |
| 37 | The sibling ``*.summary.json`` has run aggregates; we don't consume |
| 38 | it here — the curve is richer. |
| 39 | """ |
| 40 | |
| 41 | from __future__ import annotations |
| 42 | |
| 43 | import json |
| 44 | import math |
| 45 | from pathlib import Path |
| 46 | from typing import ClassVar, Literal |
| 47 | |
| 48 | import numpy as np |
| 49 | from pydantic import Field |
| 50 | |
| 51 | from dlm_sway.core.errors import SwayError |
| 52 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 53 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 54 | |
| 55 | |
class TrainingDriftError(SwayError):
    """Structurally unparseable training log.

    Deliberately not a "missing data" condition: when the logs
    directory is simply absent (no .dlm context, or an adapter that
    never trained) the probe SKIPs quietly. This error is for the
    opposite case — training data exists but cannot be read
    (corrupted JSONL lines, wrong record shape, I/O failure) — which
    deserves a loud ERROR rather than a silent skip.
    """
| 65 | |
| 66 | |
class TrainingDriftSpec(ProbeSpec):
    """Spec for ``kind: training_drift``.

    The three ``assert_*`` thresholds combine conjunctively: the probe
    PASSes only when smoothness, convergence ratio, and spike count all
    clear their bounds, and otherwise WARNs (it never FAILs — see
    ``_verdict_from_metrics``). A missing ``store_path``, logs
    directory, or log file yields SKIP rather than a verdict.
    """

    kind: Literal["training_drift"] = "training_drift"
    store_path: str | None = None
    """Path to a dlm store root (the directory containing
    ``logs/`` and ``adapter/``). When ``None`` the probe SKIPs —
    typically populated by the .dlm autogen path which knows the
    store from the dlm_id resolution."""
    min_steps: int = Field(default=10, ge=2)
    """Skip when the curve has fewer steps than this. Short runs
    (3-step early-stops, smoke tests) produce undefined fits;
    surfacing those as PASS/FAIL would be misleading."""
    # Consumed by _count_spikes as the centered-window width for its
    # median-absolute-delta baseline (see NOTE under spike_sigma).
    rolling_window: int = Field(default=10, ge=2)
    """Window for the rolling std used in spike detection. Larger
    windows mean tighter spike detection at the cost of missing
    rapid oscillations. 10 is a balance for typical 100–1000 step
    runs."""
    # NOTE(review): despite the wording below, _count_spikes compares
    # positive deltas against a rolling *median absolute delta*, not a
    # rolling std — spike_sigma scales that robust baseline.
    spike_sigma: float = Field(default=3.0, gt=0.0)
    """A step is an instability event when ``|Δloss|`` exceeds this
    many rolling standard deviations. 3σ is the conventional
    "outlier" boundary; lower for more sensitivity, higher for less
    noise tolerance."""
    assert_smoothness_gte: float = Field(default=0.7, ge=0.0, le=1.0)
    """Minimum smoothness for PASS."""
    assert_convergence_ratio_lte: float = Field(default=0.7, gt=0.0)
    """Maximum ``final_loss / initial_loss`` for PASS. A "well-
    trained" run typically halves loss; permissive default tolerates
    noisier data."""
    assert_instability_events_lte: int = Field(default=0, ge=0)
    """Maximum allowed spike count. Default 0 — any spike → WARN."""
| 98 | |
| 99 | |
class TrainingDriftProbe(Probe):
    """Pre-run probe: judges adapter health from the shape of its loss curve."""

    kind = "training_drift"
    spec_cls = TrainingDriftSpec
    category = "calibration"
    needs_backend: ClassVar[bool] = False
    # Log-file-only probe: no model load, finishes in well under 100ms,
    # so it can front-run the heavy probes in a ``sway check`` pass.
    # Same posture as gradient_ghost.

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # Backend / sections / null_stats are never consumed here.
        assert isinstance(spec, TrainingDriftSpec)

        def skip(message: str, **extra) -> ProbeResult:
            # Every SKIP exit shares the same shape; only the message
            # (and, for the short-curve case, evidence) varies.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=message,
                **extra,
            )

        if spec.store_path is None:
            return skip("no store_path provided (no .dlm context)")

        store_path = Path(spec.store_path).expanduser().resolve()
        logs_dir = store_path / "logs"
        if not logs_dir.is_dir():
            return skip(f"no logs/ directory under {store_path}")

        log_paths = sorted(logs_dir.glob("train-*.jsonl"))
        if not log_paths:
            return skip(f"no train-*.jsonl files under {logs_dir}")

        # Merge step records across every run in chronological order
        # (the timestamp suffix makes the sorted glob chronological).
        # Resumed runs repeat step numbers; the latest run's value wins.
        try:
            loss_by_step = _collect_steps(log_paths)
        except TrainingDriftError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
                evidence={"log_paths": [str(p) for p in log_paths]},
            )

        if len(loss_by_step) < spec.min_steps:
            return skip(
                f"only {len(loss_by_step)} step records "
                f"(< min_steps={spec.min_steps}); "
                f"curve too short to fit reliably",
                evidence={"num_steps": len(loss_by_step)},
            )

        # Deduped (step, loss) pairs in ascending step order.
        ordered_pairs = sorted(loss_by_step.items())
        steps = np.asarray([pair[0] for pair in ordered_pairs], dtype=np.int64)
        losses = np.asarray([pair[1] for pair in ordered_pairs], dtype=np.float64)

        # Headline metrics; the helper tolerates degenerate inputs
        # (all-equal losses, non-finite entries) without raising.
        metrics = _compute_metrics(
            losses, rolling_window=spec.rolling_window, spike_sigma=spec.spike_sigma
        )
        verdict, score, message = _verdict_from_metrics(metrics, spec)

        # Ship a bounded curve in evidence so a 10k-step run doesn't
        # bloat the JSON report.
        curve = _downsampled_curve(steps, losses, cap=512)

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=metrics["smoothness"],
            evidence={
                "final_loss": metrics["final_loss"],
                "initial_loss": metrics["initial_loss"],
                "convergence_ratio": metrics["convergence_ratio"],
                "smoothness": metrics["smoothness"],
                "instability_events": metrics["instability_events"],
                "num_steps": len(losses),
                "num_log_files": len(log_paths),
                "curve_sampled": curve,
                "thresholds": {
                    "smoothness_gte": spec.assert_smoothness_gte,
                    "convergence_ratio_lte": spec.assert_convergence_ratio_lte,
                    "instability_events_lte": spec.assert_instability_events_lte,
                },
                "weight": spec.weight,
            },
            message=message,
            critical_fields=(),
            # ``raw`` (smoothness) stays bounded: the metric helper pins
            # smoothness to 0.0 on degenerate curves rather than NaN, so
            # enabling the critical-field guard here would only null the
            # score in cases it can't improve.
        )
| 224 | |
| 225 | |
| 226 | # --------------------------------------------------------------------------- |
| 227 | # JSONL parsing |
| 228 | # --------------------------------------------------------------------------- |
| 229 | |
| 230 | |
| 231 | def _collect_steps(log_paths: list[Path]) -> dict[int, float]: |
| 232 | """Parse every JSONL, dedupe-by-step, return ``{step: loss}``. |
| 233 | |
| 234 | Resumed runs append to a fresh JSONL but share step numbers with |
| 235 | the prior run. Dedupe-by-keep-latest matches dlm's own ``dlm |
| 236 | metrics`` semantics: the most recent run for a given step wins. |
| 237 | """ |
| 238 | out: dict[int, float] = {} |
| 239 | for path in log_paths: |
| 240 | try: |
| 241 | with path.open("r", encoding="utf-8") as f: |
| 242 | for line_no, raw in enumerate(f, start=1): |
| 243 | raw = raw.strip() |
| 244 | if not raw: |
| 245 | continue |
| 246 | try: |
| 247 | rec = json.loads(raw) |
| 248 | except json.JSONDecodeError as exc: |
| 249 | # A trailing partial line from a crashed |
| 250 | # trainer is the typical cause. Skip it but |
| 251 | # surface as ERROR if EVERY line is broken |
| 252 | # (caller checks ``out`` emptiness). |
| 253 | if line_no == 1: |
| 254 | raise TrainingDriftError( |
| 255 | f"first line of {path.name} is not valid JSON: {exc}" |
| 256 | ) from exc |
| 257 | continue |
| 258 | if not isinstance(rec, dict): |
| 259 | continue |
| 260 | if rec.get("type") != "step": |
| 261 | continue |
| 262 | try: |
| 263 | step = int(rec["step"]) |
| 264 | loss = float(rec["loss"]) |
| 265 | except (KeyError, TypeError, ValueError): |
| 266 | continue |
| 267 | if not math.isfinite(loss): |
| 268 | # NaN loss is a real signal — record it as |
| 269 | # +inf so the spike detector flags the |
| 270 | # instability without crashing on np.log. |
| 271 | loss = math.inf |
| 272 | out[step] = loss |
| 273 | except OSError as exc: |
| 274 | raise TrainingDriftError(f"failed to read {path}: {exc}") from exc |
| 275 | return out |
| 276 | |
| 277 | |
| 278 | # --------------------------------------------------------------------------- |
| 279 | # Metric computation |
| 280 | # --------------------------------------------------------------------------- |
| 281 | |
| 282 | |
| 283 | def _compute_metrics( |
| 284 | losses: np.ndarray, |
| 285 | *, |
| 286 | rolling_window: int, |
| 287 | spike_sigma: float, |
| 288 | ) -> dict[str, float]: |
| 289 | """Compute the four headline metrics from the loss array. |
| 290 | |
| 291 | All metrics are robust to NaN/inf: non-finite step losses are |
| 292 | replaced with the most-recent finite value before fitting (so a |
| 293 | single bad batch doesn't poison the curve), but their *positions* |
| 294 | are still counted as instability events. |
| 295 | """ |
| 296 | losses = losses.astype(np.float64, copy=True) |
| 297 | instability_from_nan = int(np.sum(~np.isfinite(losses))) |
| 298 | |
| 299 | if instability_from_nan > 0: |
| 300 | # Forward-fill non-finite values from the previous finite |
| 301 | # entry so downstream stats don't blow up. The first entry |
| 302 | # MUST be finite (training records the first batch's loss |
| 303 | # before anything could go wrong); guard anyway. |
| 304 | finite_mask = np.isfinite(losses) |
| 305 | if not finite_mask.any(): |
| 306 | # Pathological: every step recorded NaN. Return all-NaN |
| 307 | # metrics so the verdict path can surface the failure. |
| 308 | return { |
| 309 | "initial_loss": float("nan"), |
| 310 | "final_loss": float("nan"), |
| 311 | "convergence_ratio": float("nan"), |
| 312 | "smoothness": 0.0, |
| 313 | "instability_events": float(len(losses)), |
| 314 | } |
| 315 | last_good = float(losses[finite_mask][0]) |
| 316 | for i, v in enumerate(losses): |
| 317 | if not math.isfinite(v): |
| 318 | losses[i] = last_good |
| 319 | else: |
| 320 | last_good = float(v) |
| 321 | |
| 322 | initial_loss = float(losses[0]) |
| 323 | final_loss = float(losses[-1]) |
| 324 | convergence_ratio = float(final_loss / initial_loss) if initial_loss != 0.0 else float("inf") |
| 325 | |
| 326 | deltas = np.diff(losses) |
| 327 | var_loss = float(losses.var()) |
| 328 | var_delta = float(deltas.var()) if deltas.size > 0 else 0.0 |
| 329 | if var_loss > 0.0: # noqa: SIM108 — branch comments are load-bearing |
| 330 | # Clip into [0, 1]: a curve where var(Δloss) > var(loss) |
| 331 | # implies the per-step jitter dominates the overall sweep — |
| 332 | # treat as fully-spiky (smoothness=0) rather than emit a |
| 333 | # negative number. |
| 334 | smoothness = max(0.0, 1.0 - var_delta / var_loss) |
| 335 | else: |
| 336 | # Identical losses across every step: the run never |
| 337 | # progressed. Smoothness is formally 1.0 (perfectly flat), |
| 338 | # but that's misleading — surface as 0.0 so the verdict path |
| 339 | # can flag it. |
| 340 | smoothness = 0.0 |
| 341 | |
| 342 | instability_events = _count_spikes(deltas, window=rolling_window, sigma=spike_sigma) |
| 343 | instability_events += instability_from_nan |
| 344 | |
| 345 | return { |
| 346 | "initial_loss": initial_loss, |
| 347 | "final_loss": final_loss, |
| 348 | "convergence_ratio": convergence_ratio, |
| 349 | "smoothness": smoothness, |
| 350 | "instability_events": float(instability_events), |
| 351 | } |
| 352 | |
| 353 | |
| 354 | def _count_spikes(deltas: np.ndarray, *, window: int, sigma: float) -> int: |
| 355 | """Count loss-increase events that exceed a robust noise threshold. |
| 356 | |
| 357 | The naive ``|Δloss| > sigma · rolling_std`` heuristic is broken for |
| 358 | exponential decay: deltas span orders of magnitude across training, |
| 359 | so within-window std stays tiny while absolute deltas are large — |
| 360 | every step trips the threshold. The semantically correct |
| 361 | "instability event" is a loss *increase*, not a faster-than-typical |
| 362 | decrease. |
| 363 | |
| 364 | Heuristic: |
| 365 | |
| 366 | 1. Restrict to positive deltas (``delta > 0`` ⇒ loss went up). Loss |
| 367 | going down faster than usual isn't an instability — it's the |
| 368 | happy path. |
| 369 | 2. For each positive delta, compare to the median absolute delta |
| 370 | in a centered window. Flag when ``delta > sigma · MAD``, where |
| 371 | MAD is the median absolute deviation (robust to outliers). |
| 372 | 3. Short curves fall back to global MAD; constant-loss curves |
| 373 | (every delta zero) report 0 spikes. |
| 374 | |
| 375 | The ``sigma`` parameter retains its semantic meaning ("how many |
| 376 | typical deltas does this exceed") but operates against the |
| 377 | median-absolute-delta scale rather than std-of-delta. |
| 378 | """ |
| 379 | if deltas.size == 0: |
| 380 | return 0 |
| 381 | |
| 382 | # Median absolute delta — the "typical movement scale" we judge |
| 383 | # spikes against. Median is robust to a few outliers; if the |
| 384 | # whole curve is flat, MAD is 0 and we report no spikes. |
| 385 | abs_deltas = np.abs(deltas) |
| 386 | |
| 387 | if deltas.size < window: |
| 388 | baseline = float(np.median(abs_deltas)) |
| 389 | if baseline == 0.0: |
| 390 | return 0 |
| 391 | return int(np.sum((deltas > 0.0) & (deltas > sigma * baseline))) |
| 392 | |
| 393 | spikes = 0 |
| 394 | half = window // 2 |
| 395 | for i in range(deltas.size): |
| 396 | if deltas[i] <= 0.0: |
| 397 | continue # Loss went down — not an instability event. |
| 398 | lo = max(0, i - half) |
| 399 | hi = min(deltas.size, i + half + 1) |
| 400 | window_slice = abs_deltas[lo:hi] |
| 401 | # Exclude the point itself so a real outlier doesn't inflate |
| 402 | # the baseline it's measured against. |
| 403 | if window_slice.size > 1: |
| 404 | mask = np.ones(window_slice.size, dtype=bool) |
| 405 | mask[i - lo] = False |
| 406 | window_slice = window_slice[mask] |
| 407 | baseline = float(np.median(window_slice)) |
| 408 | if baseline == 0.0: |
| 409 | # Surrounding deltas are all zero — any nonzero positive |
| 410 | # delta is an instability event by construction. |
| 411 | spikes += 1 |
| 412 | continue |
| 413 | if float(deltas[i]) > sigma * baseline: |
| 414 | spikes += 1 |
| 415 | return spikes |
| 416 | |
| 417 | |
| 418 | # --------------------------------------------------------------------------- |
| 419 | # Verdict mapping |
| 420 | # --------------------------------------------------------------------------- |
| 421 | |
| 422 | |
def _verdict_from_metrics(
    metrics: dict[str, float], spec: TrainingDriftSpec
) -> tuple[Verdict, float, str]:
    """Turn the metric dict into ``(verdict, score, message)``.

    PASS requires every spec threshold to hold simultaneously; any
    miss produces WARN (this probe never emits FAIL — the thresholds
    are heuristic defaults, not hard contracts). The score is the
    smoothness clamped into [0, 1], halved on WARN so a report can
    rank borderline runs below healthy ones.
    """
    smoothness = metrics["smoothness"]
    ratio = metrics["convergence_ratio"]
    spikes = int(metrics["instability_events"])

    # ``not (x >= t)`` rather than ``x < t`` so a NaN metric is
    # treated as a threshold miss instead of a silent pass.
    problems: list[str] = []
    if not (smoothness >= spec.assert_smoothness_gte):
        problems.append(f"smoothness={smoothness:.2f} < {spec.assert_smoothness_gte}")
    if not (ratio <= spec.assert_convergence_ratio_lte):
        problems.append(f"convergence_ratio={ratio:.2f} > {spec.assert_convergence_ratio_lte}")
    if not (spikes <= spec.assert_instability_events_lte):
        problems.append(f"instability_events={spikes} > {spec.assert_instability_events_lte}")

    summary = (
        f"smoothness={smoothness:.2f}, convergence_ratio={ratio:.2f}, "
        f"instability_events={spikes}, final_loss={metrics['final_loss']:.3f}"
    )

    if not problems:
        # Clean sweep of all three thresholds.
        return Verdict.PASS, float(min(1.0, max(0.0, smoothness))), summary

    # At least one threshold missed → WARN with a halved score so
    # warned runs sort below passing ones.
    score = float(min(1.0, max(0.0, smoothness * 0.5)))
    return Verdict.WARN, score, f"{summary}; warnings: {'; '.join(problems)}"
| 460 | |
| 461 | |
| 462 | # --------------------------------------------------------------------------- |
| 463 | # Curve downsampling for evidence |
| 464 | # --------------------------------------------------------------------------- |
| 465 | |
| 466 | |
| 467 | def _downsampled_curve( |
| 468 | steps: np.ndarray, losses: np.ndarray, *, cap: int |
| 469 | ) -> list[tuple[int, float]]: |
| 470 | """Uniform-stride downsample so a 10k-step run still fits in the JSON report. |
| 471 | |
| 472 | Always preserves the first and last point so initial/final loss |
| 473 | are visible regardless of cap. For curves shorter than the cap, |
| 474 | returns the full series unchanged. |
| 475 | """ |
| 476 | n = int(len(losses)) |
| 477 | if n <= cap: |
| 478 | return [(int(s), float(loss)) for s, loss in zip(steps, losses, strict=True)] |
| 479 | # Stride to keep at most ``cap`` points: stride = ceil(n / cap). |
| 480 | # The +1 accounts for always-appending the final point even when |
| 481 | # ``stride * (cap - 1) < n - 1``. |
| 482 | stride = max(1, (n + cap - 1) // cap) |
| 483 | idx = list(range(0, n, stride)) |
| 484 | if idx[-1] != n - 1: |
| 485 | idx.append(n - 1) |
| 486 | return [(int(steps[i]), float(losses[i])) for i in idx] |