tenseleyflow/sway / 3c0ffd9

Browse files

probes/training_drift: parse dlm per-step JSONLs, score smoothness + spikes

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
3c0ffd9732247a18b90d96b0cef99a8ceef6c60d
Parents
777d726
Tree
4366a76

1 changed file

StatusFile+-
A src/dlm_sway/probes/training_drift.py 486 0
src/dlm_sway/probes/training_drift.pyadded
@@ -0,0 +1,486 @@
1
+"""Training-drift — pre-run probe that reads dlm's per-step loss curve.
2
+
3
+Sister probe to :mod:`gradient_ghost`. Where ``gradient_ghost`` reads
4
+the optimizer state at the end of training, ``training_drift`` reads
5
+the loss curve *during* training: the rich signal about *how* the
6
+adapter trained, not just what it produced.
7
+
8
+Four metrics extracted from the curve:
9
+
10
+- ``final_loss`` — last recorded step's loss.
11
+- ``convergence_ratio`` — ``final_loss / initial_loss``. Lower is
12
+  better; a healthy adapter cuts loss by half or more.
13
+- ``smoothness`` — ``1 - (var(Δloss) / var(loss))``. Values near 1
14
+  mean the curve descends smoothly; near 0 means each step's
15
+  change-in-loss is comparable to the loss range itself (spiky).
16
+- ``instability_events`` — count of loss-*increase* steps whose
+  ``Δloss`` exceeds ``spike_sigma`` × the rolling median absolute
+  delta; spikes that survive that robust baseline are real — they
+  correlate with silent adapter degradation.
19
+
20
+Verdict: PASS when all three of (smoothness ≥ 0.7, instability_events
21
+== 0, convergence_ratio ≤ 0.7); else WARN. The thresholds are
22
+hand-tuned defaults; spec fields override them.
23
+
24
+## Why no null calibration
25
+
26
+Mirrors ``prompt_collapse`` and ``multi_turn_coherence_decay``: a null
27
+adapter doesn't *train* — it has no loss curve. The null distribution
28
+of "smoothness on a noise adapter" is undefined. Fixed-threshold
29
+verdicts are the published path; users override per-spec.
30
+
31
+## Log-format note
32
+
33
+dlm writes one JSONL per run at
34
+``<store_path>/logs/train-NNNNNN-YYYYMMDDTHHMMSS.jsonl``. Each line is
35
+``{"type": "<banner|step|run_complete|...>", ...}``. This probe
36
+filters for ``type == "step"`` records and reads ``step`` + ``loss``.
37
+The sibling ``*.summary.json`` has run aggregates; we don't consume
38
+it here — the curve is richer.
39
+"""
40
+
41
+from __future__ import annotations
42
+
43
+import json
44
+import math
45
+from pathlib import Path
46
+from typing import ClassVar, Literal
47
+
48
+import numpy as np
49
+from pydantic import Field
50
+
51
+from dlm_sway.core.errors import SwayError
52
+from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
53
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
54
+
55
+
56
class TrainingDriftError(SwayError):
    """Structural failure while parsing a training log file.

    Deliberately distinct from the absence cases: a missing logs
    directory (no .dlm context, or an adapter that never trained) is
    a quiet SKIP, whereas a JSONL that exists but holds corrupted
    lines or the wrong record shape is an ERROR — the user *has*
    training data we cannot read, and that deserves a loud failure.
    """
65
+
66
+
67
class TrainingDriftSpec(ProbeSpec):
    """Spec for ``kind: training_drift``."""

    kind: Literal["training_drift"] = "training_drift"
    store_path: str | None = None
    """Root of a dlm store (the directory holding ``logs/`` and
    ``adapter/``). ``None`` makes the probe SKIP — the .dlm autogen
    path normally fills this in, since it learns the store from the
    dlm_id resolution."""
    min_steps: int = Field(default=10, ge=2)
    """SKIP curves shorter than this many steps. Tiny runs (smoke
    tests, 3-step early-stops) produce undefined fits; reporting
    PASS/FAIL on them would mislead."""
    rolling_window: int = Field(default=10, ge=2)
    """Window for the rolling statistic used in spike detection. A
    larger window tightens spike detection but can miss rapid
    oscillations; 10 balances both for typical 100–1000 step runs."""
    spike_sigma: float = Field(default=3.0, gt=0.0)
    """How many multiples of the rolling noise scale ``|Δloss|`` must
    exceed to count as an instability event. 3σ is the conventional
    outlier boundary; lower it for sensitivity, raise it for more
    noise tolerance."""
    assert_smoothness_gte: float = Field(default=0.7, ge=0.0, le=1.0)
    """Minimum smoothness for PASS."""
    assert_convergence_ratio_lte: float = Field(default=0.7, gt=0.0)
    """Maximum ``final_loss / initial_loss`` for PASS. Well-trained
    runs typically halve loss; the permissive default tolerates
    noisier data."""
    assert_instability_events_lte: int = Field(default=0, ge=0)
    """Maximum allowed spike count. Default 0 — any spike → WARN."""
98
+
99
+
100
class TrainingDriftProbe(Probe):
    """The "did this adapter train smoothly?" pre-run probe."""

    kind = "training_drift"
    spec_cls = TrainingDriftSpec
    category = "calibration"
    needs_backend: ClassVar[bool] = False
    # Pre-run posture, mirroring gradient_ghost: no model load and a
    # sub-100ms runtime, so ``sway check`` can fire this first, before
    # any heavy probe spins up.

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        del ctx  # No backend / sections / null_stats consumed.
        assert isinstance(spec, TrainingDriftSpec)

        def skip(message: str, evidence: dict | None = None) -> ProbeResult:
            # One shape for every early-exit SKIP below; evidence is
            # attached only when a caller supplies it.
            kwargs: dict = {
                "name": spec.name,
                "kind": spec.kind,
                "verdict": Verdict.SKIP,
                "score": None,
                "message": message,
            }
            if evidence is not None:
                kwargs["evidence"] = evidence
            return ProbeResult(**kwargs)

        if spec.store_path is None:
            return skip("no store_path provided (no .dlm context)")

        store_path = Path(spec.store_path).expanduser().resolve()
        logs_dir = store_path / "logs"
        if not logs_dir.is_dir():
            return skip(f"no logs/ directory under {store_path}")

        log_paths = sorted(logs_dir.glob("train-*.jsonl"))
        if not log_paths:
            return skip(f"no train-*.jsonl files under {logs_dir}")

        # Steps are concatenated across every run in chronological
        # order (the sorted glob orders files by timestamp suffix);
        # resumed runs that repeat a step number are deduped by
        # keeping the most recent run's value for that step.
        try:
            steps_by_idx = _collect_steps(log_paths)
        except TrainingDriftError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=str(exc),
                evidence={"log_paths": [str(p) for p in log_paths]},
            )

        if len(steps_by_idx) < spec.min_steps:
            return skip(
                f"only {len(steps_by_idx)} step records "
                f"(< min_steps={spec.min_steps}); "
                f"curve too short to fit reliably",
                evidence={"num_steps": len(steps_by_idx)},
            )

        # (step, loss) pairs in step order.
        pairs = sorted(steps_by_idx.items())
        steps = np.asarray([step for step, _ in pairs], dtype=np.int64)
        losses = np.asarray([value for _, value in pairs], dtype=np.float64)

        # The metric helpers degrade gracefully (all-NaN losses,
        # constant curves, ...) rather than raise; any degenerate
        # shape surfaces through the verdict mapping below.
        metrics = _compute_metrics(
            losses, rolling_window=spec.rolling_window, spike_sigma=spec.spike_sigma
        )
        verdict, score, message = _verdict_from_metrics(metrics, spec)

        # Keep the evidence payload bounded — a 10k-step run must not
        # balloon the JSON report, so the shipped curve is uniformly
        # downsampled to the cap.
        curve = _downsampled_curve(steps, losses, cap=512)

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=metrics["smoothness"],
            evidence={
                "final_loss": metrics["final_loss"],
                "initial_loss": metrics["initial_loss"],
                "convergence_ratio": metrics["convergence_ratio"],
                "smoothness": metrics["smoothness"],
                "instability_events": metrics["instability_events"],
                "num_steps": len(losses),
                "num_log_files": len(log_paths),
                "curve_sampled": curve,
                "thresholds": {
                    "smoothness_gte": spec.assert_smoothness_gte,
                    "convergence_ratio_lte": spec.assert_convergence_ratio_lte,
                    "instability_events_lte": spec.assert_instability_events_lte,
                },
                "weight": spec.weight,
            },
            message=message,
            critical_fields=(),
            # Empty on purpose: ``raw`` (smoothness) is [0, 1] except on
            # degenerate inputs that already routed through SKIP/WARN,
            # and letting the critical-field guard null out the score on
            # a NaN smoothness would confuse more than it protects.
        )
224
+
225
+
226
+# ---------------------------------------------------------------------------
227
+# JSONL parsing
228
+# ---------------------------------------------------------------------------
229
+
230
+
231
+def _collect_steps(log_paths: list[Path]) -> dict[int, float]:
232
+    """Parse every JSONL, dedupe-by-step, return ``{step: loss}``.
233
+
234
+    Resumed runs append to a fresh JSONL but share step numbers with
235
+    the prior run. Dedupe-by-keep-latest matches dlm's own ``dlm
236
+    metrics`` semantics: the most recent run for a given step wins.
237
+    """
238
+    out: dict[int, float] = {}
239
+    for path in log_paths:
240
+        try:
241
+            with path.open("r", encoding="utf-8") as f:
242
+                for line_no, raw in enumerate(f, start=1):
243
+                    raw = raw.strip()
244
+                    if not raw:
245
+                        continue
246
+                    try:
247
+                        rec = json.loads(raw)
248
+                    except json.JSONDecodeError as exc:
249
+                        # A trailing partial line from a crashed
250
+                        # trainer is the typical cause. Skip it but
251
+                        # surface as ERROR if EVERY line is broken
252
+                        # (caller checks ``out`` emptiness).
253
+                        if line_no == 1:
254
+                            raise TrainingDriftError(
255
+                                f"first line of {path.name} is not valid JSON: {exc}"
256
+                            ) from exc
257
+                        continue
258
+                    if not isinstance(rec, dict):
259
+                        continue
260
+                    if rec.get("type") != "step":
261
+                        continue
262
+                    try:
263
+                        step = int(rec["step"])
264
+                        loss = float(rec["loss"])
265
+                    except (KeyError, TypeError, ValueError):
266
+                        continue
267
+                    if not math.isfinite(loss):
268
+                        # NaN loss is a real signal — record it as
269
+                        # +inf so the spike detector flags the
270
+                        # instability without crashing on np.log.
271
+                        loss = math.inf
272
+                    out[step] = loss
273
+        except OSError as exc:
274
+            raise TrainingDriftError(f"failed to read {path}: {exc}") from exc
275
+    return out
276
+
277
+
278
+# ---------------------------------------------------------------------------
279
+# Metric computation
280
+# ---------------------------------------------------------------------------
281
+
282
+
283
def _compute_metrics(
    losses: np.ndarray,
    *,
    rolling_window: int,
    spike_sigma: float,
) -> dict[str, float]:
    """Compute the headline metrics from the loss array.

    Non-finite step losses are forward-filled from the most recent
    finite value before any statistics run, so a single bad batch
    does not poison the whole curve — but every non-finite position
    still counts as an instability event.
    """
    losses = losses.astype(np.float64, copy=True)
    finite = np.isfinite(losses)
    nonfinite_count = int((~finite).sum())

    if nonfinite_count > 0:
        if not finite.any():
            # Pathological: every step logged NaN/inf. Emit all-NaN
            # loss metrics so the verdict path surfaces the failure.
            return {
                "initial_loss": float("nan"),
                "final_loss": float("nan"),
                "convergence_ratio": float("nan"),
                "smoothness": 0.0,
                "instability_events": float(len(losses)),
            }
        # Forward-fill each bad entry from the previous finite one;
        # any *leading* bad entries borrow the first finite value.
        fill = float(losses[finite][0])
        for idx in range(len(losses)):
            if math.isfinite(losses[idx]):
                fill = float(losses[idx])
            else:
                losses[idx] = fill

    initial_loss = float(losses[0])
    final_loss = float(losses[-1])
    if initial_loss != 0.0:
        convergence_ratio = float(final_loss / initial_loss)
    else:
        convergence_ratio = float("inf")

    deltas = np.diff(losses)
    loss_var = float(losses.var())
    delta_var = float(deltas.var()) if deltas.size > 0 else 0.0
    if loss_var > 0.0:
        # Clamp at 0: var(Δloss) exceeding var(loss) means per-step
        # jitter dominates the overall sweep — report fully spiky
        # rather than emit a negative number.
        smoothness = max(0.0, 1.0 - delta_var / loss_var)
    else:
        # Every loss identical → the run never progressed. Formally
        # smoothness would be 1.0 (perfectly flat); report 0.0 so the
        # verdict path can flag it instead.
        smoothness = 0.0

    spikes = _count_spikes(deltas, window=rolling_window, sigma=spike_sigma)

    return {
        "initial_loss": initial_loss,
        "final_loss": final_loss,
        "convergence_ratio": convergence_ratio,
        "smoothness": smoothness,
        "instability_events": float(spikes + nonfinite_count),
    }
352
+
353
+
354
+def _count_spikes(deltas: np.ndarray, *, window: int, sigma: float) -> int:
355
+    """Count loss-increase events that exceed a robust noise threshold.
356
+
357
+    The naive ``|Δloss| > sigma · rolling_std`` heuristic is broken for
358
+    exponential decay: deltas span orders of magnitude across training,
359
+    so within-window std stays tiny while absolute deltas are large —
360
+    every step trips the threshold. The semantically correct
361
+    "instability event" is a loss *increase*, not a faster-than-typical
362
+    decrease.
363
+
364
+    Heuristic:
365
+
366
+    1. Restrict to positive deltas (``delta > 0`` ⇒ loss went up). Loss
367
+       going down faster than usual isn't an instability — it's the
368
+       happy path.
369
+    2. For each positive delta, compare to the median absolute delta
370
+       in a centered window. Flag when ``delta > sigma · MAD``, where
371
+       MAD is the median absolute deviation (robust to outliers).
372
+    3. Short curves fall back to global MAD; constant-loss curves
373
+       (every delta zero) report 0 spikes.
374
+
375
+    The ``sigma`` parameter retains its semantic meaning ("how many
376
+    typical deltas does this exceed") but operates against the
377
+    median-absolute-delta scale rather than std-of-delta.
378
+    """
379
+    if deltas.size == 0:
380
+        return 0
381
+
382
+    # Median absolute delta — the "typical movement scale" we judge
383
+    # spikes against. Median is robust to a few outliers; if the
384
+    # whole curve is flat, MAD is 0 and we report no spikes.
385
+    abs_deltas = np.abs(deltas)
386
+
387
+    if deltas.size < window:
388
+        baseline = float(np.median(abs_deltas))
389
+        if baseline == 0.0:
390
+            return 0
391
+        return int(np.sum((deltas > 0.0) & (deltas > sigma * baseline)))
392
+
393
+    spikes = 0
394
+    half = window // 2
395
+    for i in range(deltas.size):
396
+        if deltas[i] <= 0.0:
397
+            continue  # Loss went down — not an instability event.
398
+        lo = max(0, i - half)
399
+        hi = min(deltas.size, i + half + 1)
400
+        window_slice = abs_deltas[lo:hi]
401
+        # Exclude the point itself so a real outlier doesn't inflate
402
+        # the baseline it's measured against.
403
+        if window_slice.size > 1:
404
+            mask = np.ones(window_slice.size, dtype=bool)
405
+            mask[i - lo] = False
406
+            window_slice = window_slice[mask]
407
+        baseline = float(np.median(window_slice))
408
+        if baseline == 0.0:
409
+            # Surrounding deltas are all zero — any nonzero positive
410
+            # delta is an instability event by construction.
411
+            spikes += 1
412
+            continue
413
+        if float(deltas[i]) > sigma * baseline:
414
+            spikes += 1
415
+    return spikes
416
+
417
+
418
+# ---------------------------------------------------------------------------
419
+# Verdict mapping
420
+# ---------------------------------------------------------------------------
421
+
422
+
423
def _verdict_from_metrics(
    metrics: dict[str, float], spec: TrainingDriftSpec
) -> tuple[Verdict, float, str]:
    """Map the metrics to a ``(verdict, score, message)`` triple.

    PASS requires all three thresholds to clear; any miss yields a
    WARN (this probe never emits FAIL). The score is smoothness on
    PASS and half-smoothness on WARN so reports can rank borderline
    warns against actively bad runs; it never changes the verdict.
    NaN metrics fail their comparisons and thus land in WARN.
    """
    smooth = metrics["smoothness"]
    conv = metrics["convergence_ratio"]
    instability = int(metrics["instability_events"])

    failures: list[str] = []
    # ``not (x >= t)`` rather than ``x < t`` so NaN counts as a miss.
    if not (smooth >= spec.assert_smoothness_gte):
        failures.append(f"smoothness={smooth:.2f} < {spec.assert_smoothness_gte}")
    if not (conv <= spec.assert_convergence_ratio_lte):
        failures.append(f"convergence_ratio={conv:.2f} > {spec.assert_convergence_ratio_lte}")
    if not (instability <= spec.assert_instability_events_lte):
        failures.append(f"instability_events={instability} > {spec.assert_instability_events_lte}")

    headline = (
        f"smoothness={smooth:.2f}, convergence_ratio={conv:.2f}, "
        f"instability_events={instability}, final_loss={metrics['final_loss']:.3f}"
    )

    if failures:
        warn_score = float(min(1.0, max(0.0, smooth * 0.5)))
        return Verdict.WARN, warn_score, f"{headline}; warnings: {'; '.join(failures)}"

    pass_score = float(min(1.0, max(0.0, smooth)))
    return Verdict.PASS, pass_score, headline
460
+
461
+
462
+# ---------------------------------------------------------------------------
463
+# Curve downsampling for evidence
464
+# ---------------------------------------------------------------------------
465
+
466
+
467
+def _downsampled_curve(
468
+    steps: np.ndarray, losses: np.ndarray, *, cap: int
469
+) -> list[tuple[int, float]]:
470
+    """Uniform-stride downsample so a 10k-step run still fits in the JSON report.
471
+
472
+    Always preserves the first and last point so initial/final loss
473
+    are visible regardless of cap. For curves shorter than the cap,
474
+    returns the full series unchanged.
475
+    """
476
+    n = int(len(losses))
477
+    if n <= cap:
478
+        return [(int(s), float(loss)) for s, loss in zip(steps, losses, strict=True)]
479
+    # Stride to keep at most ``cap`` points: stride = ceil(n / cap).
480
+    # The +1 accounts for always-appending the final point even when
481
+    # ``stride * (cap - 1) < n - 1``.
482
+    stride = max(1, (n + cap - 1) // cap)
483
+    idx = list(range(0, n, stride))
484
+    if idx[-1] != n - 1:
485
+        idx.append(n - 1)
486
+    return [(int(steps[i]), float(losses[i])) for i in idx]