tenseleyflow/sway / 73d923c


tests/unit: 30 tests for training_drift probe + helpers + real-fixture parse

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 73d923c37d2568a18eda6eda1170ba7a9512f9cd
Parents: 0586d99
Tree: 7b76342

1 changed file

Status  File                                      +    -
A       tests/unit/test_probe_training_drift.py  508    0
@@ -0,0 +1,508 @@
+"""Tests for :mod:`dlm_sway.probes.training_drift`."""
+
+from __future__ import annotations
+
+import json
+import math
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from dlm_sway.core.result import Verdict
+from dlm_sway.probes.base import RunContext, build_probe
+from dlm_sway.probes.training_drift import (
+    TrainingDriftError,
+    _collect_steps,
+    _compute_metrics,
+    _count_spikes,
+    _downsampled_curve,
+    _verdict_from_metrics,
+)
+
+# ---------------------------------------------------------------------------
+# Fixture helpers
+# ---------------------------------------------------------------------------
+
+
+def _write_jsonl(
+    path: Path, *, banner: bool = True, steps: list[tuple[int, float]] | None = None
+) -> Path:
+    """Build a dlm-shaped train-*.jsonl fixture file.
+
+    Mirrors the real format:
+    - Optional banner line (type=banner).
+    - Per-step lines (type=step) with step + loss + lr + grad_norm.
+    - No closing line — real runs may crash mid-step; the probe
+      should tolerate truncated files.
+    """
+    path.parent.mkdir(parents=True, exist_ok=True)
+    lines: list[str] = []
+    if banner:
+        lines.append(json.dumps({"type": "banner", "run_id": 1, "seed": 42}))
+    for step, loss in steps or []:
+        lines.append(
+            json.dumps({"type": "step", "step": step, "loss": loss, "lr": 1e-4, "grad_norm": 0.5})
+        )
+    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+    return path
+
+
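+# For reference, a fixture produced by _write_jsonl looks like this on
+# disk (loss values below are illustrative, not from a real run):
+#
+#   {"type": "banner", "run_id": 1, "seed": 42}
+#   {"type": "step", "step": 0, "loss": 5.0, "lr": 0.0001, "grad_norm": 0.5}
+#   {"type": "step", "step": 1, "loss": 4.86, "lr": 0.0001, "grad_norm": 0.5}
+
+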
+def _smooth_decay(num_steps: int = 50) -> list[tuple[int, float]]:
+    """Clean exponential decay — should pass every threshold."""
+    return [(i, 5.0 * math.exp(-i / 20.0) + 0.1) for i in range(num_steps)]
+
+
+def _spiky_curve(num_steps: int = 50, *, spike_at: int = 25) -> list[tuple[int, float]]:
+    """Clean decay with one obvious loss-increase spike injected.
+
+    The spike is a loss *increase* relative to the prior step (the
+    semantically meaningful "training instability" signal — fast
+    convergence steps are not instabilities).
+    """
+    base = _smooth_decay(num_steps)
+    out = list(base)
+    # Inject a sharp upward spike: loss jumps from ~base value to 5x
+    # of itself. The next step's delta back down doesn't count as a
+    # spike (loss decreases aren't instabilities by design).
+    out[spike_at] = (spike_at, base[spike_at][1] * 5.0)
+    return out
+
+
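+# Editor's sketch (not part of the probe): the robust spike rule the
+# fixtures above are built around, assuming a plain global-MAD baseline.
+# The real _count_spikes also keeps a rolling window; the helper name
+# and the epsilon floor below are illustrative assumptions only.
+def _sketch_count_spikes_global(deltas: np.ndarray, sigma: float = 3.0) -> int:
+    if deltas.size == 0:
+        return 0
+    med = float(np.median(deltas))
+    mad = float(np.median(np.abs(deltas - med)))
+    # Floor the scale so a lone jump above an otherwise flat baseline
+    # still registers (MAD collapses to 0 when most deltas are equal).
+    threshold = med + sigma * max(1.4826 * mad, 1e-12)
+    # Only loss *increases* count; big decreases are fast convergence.
+    return int(np.sum((deltas > 0) & (deltas > threshold)))
+
+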
+# ---------------------------------------------------------------------------
+# End-to-end probe behavior
+# ---------------------------------------------------------------------------
+
+
+class TestProbeBehavior:
+    def test_pass_on_smooth_curve(self, tmp_path: Path) -> None:
+        store = tmp_path / "store"
+        _write_jsonl(
+            store / "logs" / "train-000001-20260101T000000.jsonl",
+            steps=_smooth_decay(60),
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.PASS, result.message
+        assert result.evidence["instability_events"] == 0
+        assert result.evidence["smoothness"] >= 0.7
+        # Initial loss is the curve's first sample, final is the last.
+        assert result.evidence["initial_loss"] == pytest.approx(5.1, abs=0.1)
+        assert result.evidence["final_loss"] < 1.0
+
+    def test_warn_on_spiky_curve(self, tmp_path: Path) -> None:
+        store = tmp_path / "store"
+        _write_jsonl(store / "logs" / "train-000001-20260101T000000.jsonl", steps=_spiky_curve(60))
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.WARN
+        assert result.evidence["instability_events"] >= 1
+        assert "instability_events" in result.message
+
+    def test_skip_when_no_store_path(self) -> None:
+        probe, spec = build_probe({"name": "td", "kind": "training_drift"})
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.SKIP
+        assert "no store_path" in result.message
+
+    def test_skip_when_logs_dir_missing(self, tmp_path: Path) -> None:
+        store = tmp_path / "store"
+        store.mkdir()
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.SKIP
+        assert "no logs/" in result.message
+
+    def test_skip_when_no_jsonl(self, tmp_path: Path) -> None:
+        store = tmp_path / "store"
+        (store / "logs").mkdir(parents=True)
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.SKIP
+        assert "no train-*.jsonl" in result.message
+
+    def test_skip_when_too_few_steps(self, tmp_path: Path) -> None:
+        """Default min_steps=10 — a 3-step curve must SKIP, not produce
+        a misleading verdict."""
+        store = tmp_path / "store"
+        _write_jsonl(
+            store / "logs" / "train-000001-20260101T000000.jsonl",
+            steps=[(0, 5.0), (1, 4.5), (2, 4.0)],
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.SKIP
+        assert "too short" in result.message
+
+    def test_resumed_runs_dedupe_keep_latest(self, tmp_path: Path) -> None:
+        """Two log files with overlapping step numbers: the second one
+        wins (mirrors dlm metrics' resume semantics). See the merge
+        sketch after this class."""
+        store = tmp_path / "store"
+        # First run: steps 0..9 with high losses
+        _write_jsonl(
+            store / "logs" / "train-000001-20260101T000000.jsonl",
+            steps=[(i, 10.0 + i) for i in range(10)],
+        )
+        # Resumed run: steps 5..14 with low losses (the resume picked
+        # up at step 5, so its values for 5..9 overwrite the originals)
+        _write_jsonl(
+            store / "logs" / "train-000001-20260101T010000.jsonl",
+            steps=[(i, 1.0) for i in range(5, 15)],
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        # Final loss should be from the resumed run (1.0), not the
+        # first run's step 14 (which doesn't exist).
+        assert result.evidence["final_loss"] == 1.0
+        # Step 5 should carry the resumed value, not the original.
+        curve = result.evidence["curve_sampled"]
+        step_5 = next((loss for s, loss in curve if s == 5), None)
+        assert step_5 == 1.0, f"step 5 should be from resumed run; got {step_5}"
+
+    def test_curve_downsampled_when_long(self, tmp_path: Path) -> None:
+        """A 1500-step run should land in evidence with curve_sampled <= 512."""
+        store = tmp_path / "store"
+        _write_jsonl(
+            store / "logs" / "train-000001-20260101T000000.jsonl",
+            steps=[(i, 5.0 * math.exp(-i / 500.0)) for i in range(1500)],
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.evidence["num_steps"] == 1500
+        curve = result.evidence["curve_sampled"]
+        assert len(curve) <= 512
+        # Endpoints preserved.
+        assert curve[0][0] == 0
+        assert curve[-1][0] == 1499
+
+    def test_corrupt_first_line_errors(self, tmp_path: Path) -> None:
+        store = tmp_path / "store"
+        log_dir = store / "logs"
+        log_dir.mkdir(parents=True)
+        (log_dir / "train-000001-20260101T000000.jsonl").write_text(
+            "not even json\n", encoding="utf-8"
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.ERROR
+        assert "not valid JSON" in result.message
+
+    def test_truncated_trailing_line_tolerated(self, tmp_path: Path) -> None:
+        """A crashed-mid-line trainer leaves a partial JSON tail. The
+        probe should consume the good lines and skip the bad one."""
+        store = tmp_path / "store"
+        log = store / "logs" / "train-000001-20260101T000000.jsonl"
+        log.parent.mkdir(parents=True)
+        good_lines = [
+            json.dumps({"type": "banner", "run_id": 1}),
+        ] + [
+            json.dumps({"type": "step", "step": i, "loss": 5.0 - i * 0.05, "lr": 1e-4})
+            for i in range(60)
+        ]
+        # Trailing partial line a crashed trainer might emit.
+        log.write_text(
+            "\n".join(good_lines) + '\n{"type": "step", "step": 60, "lo', encoding="utf-8"
+        )
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+            }
+        )
+        result = probe.run(spec, RunContext())
+        # Verdict can be PASS or WARN depending on the curve, but it
+        # must not be ERROR — the partial line shouldn't break the run.
+        assert result.verdict in {Verdict.PASS, Verdict.WARN}
+        assert result.evidence["num_steps"] == 60
+
+
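+# Editor's sketch (not the real _collect_steps): the resume-dedupe rule
+# test_resumed_runs_dedupe_keep_latest pins down. Per-file curves are
+# merged in filename (timestamp) order into one {step: loss} map, so a
+# later file's value for an overlapping step overwrites the earlier one.
+def _sketch_merge_runs(curves: list[dict[int, float]]) -> dict[int, float]:
+    merged: dict[int, float] = {}
+    for curve in curves:  # earliest run first, resumed runs last
+        merged.update(curve)  # later files win on overlapping steps
+    return merged
+
+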
+# ---------------------------------------------------------------------------
+# Pure-math metric helpers
+# ---------------------------------------------------------------------------
+
+
+class TestComputeMetrics:
+    def test_smooth_decay_metrics(self) -> None:
+        losses = np.array([5.0 * math.exp(-i / 20.0) + 0.1 for i in range(50)])
+        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
+        assert m["instability_events"] == 0
+        assert m["smoothness"] > 0.95
+        assert 0.0 < m["convergence_ratio"] < 0.2
+
+    def test_constant_loss_marked_unsmooth(self) -> None:
+        """A perfectly flat curve is NOT 'smooth' — it's a stuck run."""
+        losses = np.full(50, 5.0)
+        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
+        assert m["smoothness"] == 0.0
+        assert m["convergence_ratio"] == 1.0
+
+    def test_nan_loss_counts_as_instability(self) -> None:
+        """A NaN in the curve should count as an instability event but
+        not crash the metric computation (unchecked NaN propagation
+        would otherwise poison every downstream statistic). See the
+        forward-fill sketch after this class."""
+        losses = np.array([5.0 - i * 0.1 for i in range(50)])
+        losses[20] = float("nan")
+        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
+        assert m["instability_events"] >= 1
+        # The forward-fill kept downstream stats finite.
+        assert math.isfinite(m["smoothness"])
+        assert math.isfinite(m["final_loss"])
+
+    def test_all_nan_returns_zero_smoothness(self) -> None:
+        losses = np.array([float("nan")] * 10)
+        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
+        assert m["smoothness"] == 0.0
+        assert m["instability_events"] == 10
+
+    def test_zero_initial_loss_returns_inf_ratio(self) -> None:
+        """Initial loss of 0 (degenerate) → convergence ratio is inf
+        rather than ZeroDivisionError."""
+        losses = np.array([0.0, 0.5, 1.0, 1.5])
+        m = _compute_metrics(losses, rolling_window=2, spike_sigma=3.0)
+        assert m["convergence_ratio"] == float("inf")
+
+
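+# Editor's sketch, assuming the forward-fill the tests above describe:
+# each NaN counts as one instability event, then the last finite value
+# is carried forward so downstream stats stay finite. Falling back to
+# 0.0 for leading NaNs is an illustrative assumption, not the probe's
+# documented rule.
+def _sketch_forward_fill(losses: np.ndarray) -> tuple[np.ndarray, int]:
+    nan_mask = np.isnan(losses)
+    events = int(nan_mask.sum())
+    # Classic numpy forward-fill: index of the most recent finite value.
+    idx = np.where(~nan_mask, np.arange(losses.size), 0)
+    np.maximum.accumulate(idx, out=idx)
+    filled = losses[idx]
+    filled[np.isnan(filled)] = 0.0  # leading NaNs: nothing to carry
+    return filled, events
+
+
+# A convergence ratio consistent with the tests above: final loss over
+# initial loss, guarding the degenerate zero-initial case with +inf.
+def _sketch_convergence_ratio(losses: np.ndarray) -> float:
+    initial, final = float(losses[0]), float(losses[-1])
+    return float("inf") if initial == 0.0 else final / initial
+
+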
+class TestCountSpikes:
+    def test_no_spikes_in_smoothly_decaying_curve(self) -> None:
+        """Loss going down — not an instability, regardless of |Δ|."""
+        deltas = np.array([-0.05] * 50)
+        assert _count_spikes(deltas, window=10, sigma=3.0) == 0
+
+    def test_no_spikes_in_constant_curve(self) -> None:
+        deltas = np.zeros(50)
+        assert _count_spikes(deltas, window=10, sigma=3.0) == 0
+
+    def test_loss_increase_outlier_caught(self) -> None:
+        """A genuine training spike: loss-up event much larger than typical."""
+        deltas = np.array([-0.05] * 50)
+        deltas[25] = 1.5  # loss jumped UP
+        assert _count_spikes(deltas, window=10, sigma=3.0) == 1
+
+    def test_loss_decrease_outlier_ignored(self) -> None:
+        """A 'fast convergence' step (loss going down hard) is NOT an
+        instability — only loss-up events count."""
+        deltas = np.array([-0.05] * 50)
+        deltas[25] = -2.0  # huge negative delta — fast convergence
+        assert _count_spikes(deltas, window=10, sigma=3.0) == 0
+
+    def test_short_curve_uses_global_baseline(self) -> None:
+        # 5 deltas, window=10 → falls back to global MAD.
+        deltas = np.array([-0.01, -0.01, 1.0, -0.01, -0.01])
+        spikes = _count_spikes(deltas, window=10, sigma=2.0)
+        assert spikes == 1
+
+    def test_empty_deltas_returns_zero(self) -> None:
+        assert _count_spikes(np.array([]), window=10, sigma=3.0) == 0
+
+
+class TestDownsampledCurve:
+    def test_short_curve_unchanged(self) -> None:
+        steps = np.array([0, 1, 2, 3])
+        losses = np.array([5.0, 4.0, 3.0, 2.0])
+        out = _downsampled_curve(steps, losses, cap=10)
+        assert len(out) == 4
+        assert out == [(0, 5.0), (1, 4.0), (2, 3.0), (3, 2.0)]
+
+    def test_long_curve_capped_with_endpoints_preserved(self) -> None:
+        steps = np.arange(2000)
+        losses = np.linspace(5.0, 0.5, 2000)
+        out = _downsampled_curve(steps, losses, cap=100)
+        assert len(out) <= 110  # cap=100 plus a little slack for preserved endpoints
+        assert out[0][0] == 0
+        assert out[-1][0] == 1999
+
+
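+# Editor's sketch of an endpoint-preserving downsample consistent with
+# the expectations above; the exact index rule in _downsampled_curve is
+# unknown here, so treat this as an assumption-labeled illustration.
+def _sketch_downsample(
+    steps: np.ndarray, losses: np.ndarray, cap: int
+) -> list[tuple[int, float]]:
+    if steps.size <= cap:
+        return list(zip(steps.tolist(), losses.tolist()))
+    # Evenly spaced indices that always include the first and last sample.
+    idx = np.unique(np.linspace(0, steps.size - 1, num=cap).round().astype(int))
+    return [(int(steps[i]), float(losses[i])) for i in idx]
+
+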
+class TestVerdictFromMetrics:
+    def _spec(self, **kwargs: object) -> object:
+        from dlm_sway.probes.training_drift import TrainingDriftSpec
+
+        return TrainingDriftSpec(name="td", kind="training_drift", **kwargs)  # type: ignore[arg-type]
+
+    def test_pass_when_all_thresholds_clear(self) -> None:
+        spec = self._spec()
+        v, _, msg = _verdict_from_metrics(
+            {
+                "smoothness": 0.9,
+                "convergence_ratio": 0.4,
+                "instability_events": 0,
+                "final_loss": 0.5,
+            },
+            spec,  # type: ignore[arg-type]
+        )
+        assert v == Verdict.PASS
+        assert "smoothness=0.90" in msg
+        assert "warnings:" not in msg
+
+    def test_warn_lists_each_failed_threshold(self) -> None:
+        spec = self._spec()
+        v, _, msg = _verdict_from_metrics(
+            {
+                "smoothness": 0.5,
+                "convergence_ratio": 0.9,
+                "instability_events": 3,
+                "final_loss": 4.5,
+            },
+            spec,  # type: ignore[arg-type]
+        )
+        assert v == Verdict.WARN
+        assert "smoothness=0.50" in msg
+        assert "convergence_ratio=0.90" in msg
+        assert "instability_events=3" in msg
+
+
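+# Editor's sketch, not the probe's _verdict_from_metrics (that one
+# returns a three-tuple and reads its thresholds from TrainingDriftSpec).
+# This shows the message shape the assertions above pin down: every
+# failing threshold is repeated after a "warnings:" marker. The default
+# thresholds here are illustrative guesses.
+def _sketch_verdict_message(
+    metrics: dict[str, float], *, min_smoothness: float = 0.7, max_ratio: float = 0.5
+) -> tuple[Verdict, str]:
+    failed: list[str] = []
+    if metrics["smoothness"] < min_smoothness:
+        failed.append(f"smoothness={metrics['smoothness']:.2f}")
+    if metrics["convergence_ratio"] > max_ratio:
+        failed.append(f"convergence_ratio={metrics['convergence_ratio']:.2f}")
+    if metrics["instability_events"] > 0:
+        failed.append(f"instability_events={int(metrics['instability_events'])}")
+    summary = f"smoothness={metrics['smoothness']:.2f}"
+    if failed:
+        return Verdict.WARN, summary + " warnings: " + ", ".join(failed)
+    return Verdict.PASS, summary
+
+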
+# ---------------------------------------------------------------------------
+# Real-fixture parse + collect-steps (JSONL parsing edge cases)
+# ---------------------------------------------------------------------------
+
+
+class TestRealDlmFixture:
+    """Validate the probe against a JSONL captured from a real dlm run.
+
+    The fixture under ``tests/fixtures/dlm_train_log_fixture.jsonl`` is
+    a captured-from-disk shape: leading banner, an interleaved
+    ``type=delta`` (doc-change record), 30 ``type=step`` records, and
+    a closing ``type=run_complete``. If this test breaks, dlm's log
+    format has shifted and the probe needs an update — that's
+    exactly the regression signal we want.
+    """
+
+    def test_parses_real_fixture_to_pass_verdict(self, tmp_path: Path) -> None:
+        fixture = (
+            Path(__file__).resolve().parent.parent / "fixtures" / "dlm_train_log_fixture.jsonl"
+        )
+        store = tmp_path / "store"
+        store.mkdir()
+        (store / "logs").mkdir()
+        (store / "logs" / "train-000001-20260426T062514.jsonl").write_bytes(fixture.read_bytes())
+
+        probe, spec = build_probe(
+            {
+                "name": "td",
+                "kind": "training_drift",
+                "store_path": str(store),
+                # The fixture's tail flattens out (loss converges) so
+                # the curve has a stable plateau. Permissive convergence
+                # threshold to focus the assertion on format compat.
+                "assert_convergence_ratio_lte": 0.5,
+            }
+        )
+        result = probe.run(spec, RunContext())
+        assert result.verdict == Verdict.PASS, result.message
+        assert result.evidence["num_steps"] == 30
+        assert result.evidence["instability_events"] == 0
+        # Final loss was 1.911 in the fixture; just check the right
+        # ballpark so future fixture tweaks don't spuriously fail.
+        assert 1.8 < result.evidence["final_loss"] < 2.0
+        assert result.evidence["initial_loss"] > 5.0
+
+
+class TestCollectSteps:
+    def test_filters_non_step_records(self, tmp_path: Path) -> None:
+        log = tmp_path / "train-000001.jsonl"
+        log.write_text(
+            "\n".join(
+                [
+                    json.dumps({"type": "banner", "run_id": 1}),
+                    json.dumps({"type": "step", "step": 0, "loss": 5.0}),
+                    json.dumps({"type": "delta", "new": [], "removed": []}),
+                    json.dumps({"type": "step", "step": 1, "loss": 4.0}),
+                    json.dumps({"type": "run_complete", "elapsed_seconds": 10.0}),
+                ]
+            )
+            + "\n",
+            encoding="utf-8",
+        )
+        out = _collect_steps([log])
+        assert out == {0: 5.0, 1: 4.0}
+
+    def test_missing_step_key_skipped(self, tmp_path: Path) -> None:
+        """A 'step' record missing required fields is dropped — the
+        parser doesn't crash the run on a single bad record."""
+        log = tmp_path / "train.jsonl"
+        log.write_text(
+            "\n".join(
+                [
+                    json.dumps({"type": "step", "loss": 5.0}),  # no `step`
+                    json.dumps({"type": "step", "step": 1, "loss": 4.0}),
+                    json.dumps({"type": "step", "step": 2}),  # no `loss`
+                ]
+            )
+            + "\n",
+            encoding="utf-8",
+        )
+        out = _collect_steps([log])
+        assert out == {1: 4.0}
+
+    def test_nan_loss_recorded_as_inf(self, tmp_path: Path) -> None:
+        """NaN loss should land as +inf in the curve so the spike
+        detector flags it as instability without numpy NaN poisoning."""
+        log = tmp_path / "train.jsonl"
+        log.write_text(
+            "\n".join(
+                [
+                    json.dumps({"type": "step", "step": 0, "loss": 5.0}),
493
+                    # Real dlm logs encode NaN as the literal NaN; json
494
+                    # itself doesn't permit it, so simulate via Infinity
495
+                    # which json.loads accepts in non-strict mode.
496
+                    '{"type": "step", "step": 1, "loss": NaN}',
497
+                ]
498
+            )
499
+            + "\n",
500
+            encoding="utf-8",
501
+        )
502
+        out = _collect_steps([log])
503
+        assert out[0] == 5.0
504
+        assert math.isinf(out[1])
505
+
506
+    def test_missing_file_raises(self, tmp_path: Path) -> None:
507
+        with pytest.raises(TrainingDriftError, match="failed to read"):
508
+            _collect_steps([tmp_path / "nonexistent.jsonl"])
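+
+
+# Editor's closing sketch: the line-level parsing contract the tests in
+# this module pin down, gathered into one hypothetical reader (NOT the
+# real _collect_steps, which works on file paths and raises "failed to
+# read" for missing files). Shown here: a corrupt first line is a hard
+# error, a truncated trailing line is tolerated, non-step or incomplete
+# records are dropped, and a NaN loss is coerced to +inf so the spike
+# detector flags it instead of poisoning the stats.
+def _sketch_parse_lines(lines: list[str]) -> dict[int, float]:
+    curve: dict[int, float] = {}
+    for lineno, raw in enumerate(lines):
+        try:
+            record = json.loads(raw)
+        except json.JSONDecodeError:
+            if lineno == 0:
+                raise TrainingDriftError(f"line 1 is not valid JSON: {raw[:40]!r}")
+            continue  # tolerate a crashed-mid-write tail
+        if not isinstance(record, dict) or record.get("type") != "step":
+            continue  # banner / delta / run_complete records
+        step, loss = record.get("step"), record.get("loss")
+        if step is None or loss is None:
+            continue  # incomplete step record: drop it, don't crash
+        loss = float(loss)
+        curve[int(step)] = float("inf") if math.isnan(loss) else loss
+    return curve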