1 """Tests for :mod:`dlm_sway.probes.training_drift`."""
2
3 from __future__ import annotations
4
5 import json
6 import math
7 from pathlib import Path
8
9 import numpy as np
10 import pytest
11
12 from dlm_sway.core.result import Verdict
13 from dlm_sway.probes.base import RunContext, build_probe
14 from dlm_sway.probes.training_drift import (
15 TrainingDriftError,
16 _collect_steps,
17 _compute_metrics,
18 _count_spikes,
19 _downsampled_curve,
20 _verdict_from_metrics,
21 )
22
23 # ---------------------------------------------------------------------------
24 # Fixture helpers
25 # ---------------------------------------------------------------------------
26
27
28 def _write_jsonl(
29 path: Path, *, banner: bool = True, steps: list[tuple[int, float]] | None = None
30 ) -> Path:
31 """Build a dlm-shaped train-*.jsonl fixture file.
32
33 Mirrors the real format:
34 - Optional banner line (type=banner).
35 - Per-step lines (type=step) with step + loss + lr + grad_norm.
36 - No closing line — real runs may crash mid-step; the probe
37 should tolerate truncated files.
38 """
39 path.parent.mkdir(parents=True, exist_ok=True)
40 lines: list[str] = []
41 if banner:
42 lines.append(json.dumps({"type": "banner", "run_id": 1, "seed": 42}))
43 for step, loss in steps or []:
44 lines.append(
45 json.dumps({"type": "step", "step": step, "loss": loss, "lr": 1e-4, "grad_norm": 0.5})
46 )
47 path.write_text("\n".join(lines) + "\n", encoding="utf-8")
48 return path
49
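# A produced step line, for reference (deterministic json.dumps output of
# the helper above):
#   {"type": "step", "step": 0, "loss": 5.0, "lr": 0.0001, "grad_norm": 0.5}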

def _smooth_decay(num_steps: int = 50) -> list[tuple[int, float]]:
    """Clean exponential decay — should pass every threshold."""
    return [(i, 5.0 * math.exp(-i / 20.0) + 0.1) for i in range(num_steps)]


def _spiky_curve(num_steps: int = 50, *, spike_at: int = 25) -> list[tuple[int, float]]:
    """Clean decay with one obvious loss-increase spike injected.

    The spike is a loss *increase* relative to the prior step (the
    semantically meaningful "training instability" signal — fast
    convergence steps are not instabilities).
    """
    base = _smooth_decay(num_steps)
    out = list(base)
    # Inject a sharp upward spike: loss jumps to 5x its base value.
    # The following step's delta back down doesn't count as a spike
    # (loss decreases aren't instabilities by design).
    out[spike_at] = (spike_at, base[spike_at][1] * 5.0)
    return out


# ---------------------------------------------------------------------------
# End-to-end probe behavior
# ---------------------------------------------------------------------------


class TestProbeBehavior:
    def test_pass_on_smooth_curve(self, tmp_path: Path) -> None:
        store = tmp_path / "store"
        _write_jsonl(
            store / "logs" / "train-000001-20260101T000000.jsonl",
            steps=_smooth_decay(60),
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.PASS, result.message
        assert result.evidence["instability_events"] == 0
        assert result.evidence["smoothness"] >= 0.7
        # Initial loss is the curve's first sample, final is the last.
        assert result.evidence["initial_loss"] == pytest.approx(5.1, abs=0.1)
        assert result.evidence["final_loss"] < 1.0

    def test_warn_on_spiky_curve(self, tmp_path: Path) -> None:
        store = tmp_path / "store"
        _write_jsonl(store / "logs" / "train-000001-20260101T000000.jsonl", steps=_spiky_curve(60))
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.WARN
        assert result.evidence["instability_events"] >= 1
        assert "instability_events" in result.message

    def test_skip_when_no_store_path(self) -> None:
        probe, spec = build_probe({"name": "td", "kind": "training_drift"})
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.SKIP
        assert "no store_path" in result.message

    def test_skip_when_logs_dir_missing(self, tmp_path: Path) -> None:
        store = tmp_path / "store"
        store.mkdir()
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.SKIP
        assert "no logs/" in result.message

    def test_skip_when_no_jsonl(self, tmp_path: Path) -> None:
        store = tmp_path / "store"
        (store / "logs").mkdir(parents=True)
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.SKIP
        assert "no train-*.jsonl" in result.message

    def test_skip_when_too_few_steps(self, tmp_path: Path) -> None:
        """Default min_steps=10 — a 3-step curve must SKIP, not produce
        a misleading verdict."""
        store = tmp_path / "store"
        _write_jsonl(
            store / "logs" / "train-000001-20260101T000000.jsonl",
            steps=[(0, 5.0), (1, 4.5), (2, 4.0)],
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.SKIP
        assert "too short" in result.message

    def test_resumed_runs_dedupe_keep_latest(self, tmp_path: Path) -> None:
        """Two log files with overlapping step numbers — second one wins
        (mirrors dlm metrics' resume semantics)."""
        store = tmp_path / "store"
        # First run: steps 0..9 with high losses
        _write_jsonl(
            store / "logs" / "train-000001-20260101T000000.jsonl",
            steps=[(i, 10.0 + i) for i in range(10)],
        )
        # Resumed run: steps 5..14 with low losses; the resumed
        # values for steps 5..9 should overwrite the originals.
        _write_jsonl(
            store / "logs" / "train-000001-20260101T010000.jsonl",
            steps=[(i, 1.0) for i in range(5, 15)],
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        # Final loss should be from the resumed run (1.0), not the
        # first run's step 14 (which doesn't exist).
        assert result.evidence["final_loss"] == 1.0
        # Step 5 should carry the resumed value, not the original.
        curve = result.evidence["curve_sampled"]
        step_5 = next((loss for s, loss in curve if s == 5), None)
        assert step_5 == 1.0, f"step 5 should be from resumed run; got {step_5}"
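
        # (Assumed contract behind these assertions: _collect_steps merges
        # the logs in filename order into one {step: loss} map, so the
        # resumed file's values overwrite overlapping steps.)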

    def test_curve_downsampled_when_long(self, tmp_path: Path) -> None:
        """A 1500-step run should land in evidence with curve_sampled <= 512."""
        store = tmp_path / "store"
        _write_jsonl(
            store / "logs" / "train-000001-20260101T000000.jsonl",
            steps=[(i, 5.0 * math.exp(-i / 500.0)) for i in range(1500)],
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.evidence["num_steps"] == 1500
        curve = result.evidence["curve_sampled"]
        assert len(curve) <= 512
        # Endpoints preserved.
        assert curve[0][0] == 0
        assert curve[-1][0] == 1499

    def test_corrupt_first_line_errors(self, tmp_path: Path) -> None:
        store = tmp_path / "store"
        log_dir = store / "logs"
        log_dir.mkdir(parents=True)
        (log_dir / "train-000001-20260101T000000.jsonl").write_text(
            "not even json\n", encoding="utf-8"
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.ERROR
        assert "not valid JSON" in result.message

    def test_truncated_trailing_line_tolerated(self, tmp_path: Path) -> None:
        """A crashed-mid-line trainer leaves a partial JSON tail. The
        probe should consume the good lines and skip the bad one."""
        store = tmp_path / "store"
        log = store / "logs" / "train-000001-20260101T000000.jsonl"
        log.parent.mkdir(parents=True)
        good_lines = [
            json.dumps({"type": "banner", "run_id": 1}),
        ] + [
            json.dumps({"type": "step", "step": i, "loss": 5.0 - i * 0.05, "lr": 1e-4})
            for i in range(60)
        ]
        # Trailing partial line a crashed trainer might emit.
        log.write_text(
            "\n".join(good_lines) + '\n{"type": "step", "step": 60, "lo', encoding="utf-8"
        )
        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
            }
        )
        result = probe.run(spec, RunContext())
        # Verdict can be PASS or WARN depending on the curve, but it
        # must not be ERROR — the partial line shouldn't break the run.
        assert result.verdict in {Verdict.PASS, Verdict.WARN}
        assert result.evidence["num_steps"] == 60


# ---------------------------------------------------------------------------
# Pure-math metric helpers
# ---------------------------------------------------------------------------


class TestComputeMetrics:
    def test_smooth_decay_metrics(self) -> None:
        losses = np.array([5.0 * math.exp(-i / 20.0) + 0.1 for i in range(50)])
        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
        assert m["instability_events"] == 0
        assert m["smoothness"] > 0.95
        assert 0.0 < m["convergence_ratio"] < 0.2

    def test_constant_loss_marked_unsmooth(self) -> None:
        """A perfectly flat curve is NOT 'smooth' — it's a stuck run."""
        losses = np.full(50, 5.0)
        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
        assert m["smoothness"] == 0.0
        assert m["convergence_ratio"] == 1.0

    def test_nan_loss_counts_as_instability(self) -> None:
        """A NaN in the curve should count as an instability event but
        not crash the metric computation (NaN propagation breaks
        everything otherwise)."""
        losses = np.array([5.0 - i * 0.1 for i in range(50)])
        losses[20] = float("nan")
        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
        assert m["instability_events"] >= 1
        # The forward-fill kept downstream stats finite.
        assert math.isfinite(m["smoothness"])
        assert math.isfinite(m["final_loss"])

    def test_all_nan_returns_zero_smoothness(self) -> None:
        losses = np.array([float("nan")] * 10)
        m = _compute_metrics(losses, rolling_window=10, spike_sigma=3.0)
        assert m["smoothness"] == 0.0
        assert m["instability_events"] == 10

    def test_zero_initial_loss_returns_inf_ratio(self) -> None:
        """Initial loss of 0 (degenerate) → convergence ratio is inf
        rather than ZeroDivisionError."""
        losses = np.array([0.0, 0.5, 1.0, 1.5])
        m = _compute_metrics(losses, rolling_window=2, spike_sigma=3.0)
        assert m["convergence_ratio"] == float("inf")
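
    # Note: the assertions in this class treat convergence_ratio as
    # final_loss / initial_loss. That reading is consistent with every
    # case above (flat curve -> 1.0, zero initial loss -> inf) but is an
    # assumption about the metric, not a documented guarantee.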


class TestCountSpikes:
    def test_no_spikes_in_smoothly_decaying_curve(self) -> None:
        """Loss going down — not an instability, regardless of |Δ|."""
        deltas = np.array([-0.05] * 50)
        assert _count_spikes(deltas, window=10, sigma=3.0) == 0

    def test_no_spikes_in_constant_curve(self) -> None:
        deltas = np.zeros(50)
        assert _count_spikes(deltas, window=10, sigma=3.0) == 0

    def test_loss_increase_outlier_caught(self) -> None:
        """A genuine training spike: loss-up event much larger than typical."""
        deltas = np.array([-0.05] * 50)
        deltas[25] = 1.5  # loss jumped UP
        assert _count_spikes(deltas, window=10, sigma=3.0) == 1

    def test_loss_decrease_outlier_ignored(self) -> None:
        """A 'fast convergence' step (loss going down hard) is NOT an
        instability — only loss-up events count."""
        deltas = np.array([-0.05] * 50)
        deltas[25] = -2.0  # huge negative delta — fast convergence
        assert _count_spikes(deltas, window=10, sigma=3.0) == 0

    def test_short_curve_uses_global_baseline(self) -> None:
        # 5 deltas, window=10 → falls back to global MAD.
        deltas = np.array([-0.01, -0.01, 1.0, -0.01, -0.01])
        spikes = _count_spikes(deltas, window=10, sigma=2.0)
        assert spikes == 1

    def test_empty_deltas_returns_zero(self) -> None:
        assert _count_spikes(np.array([]), window=10, sigma=3.0) == 0


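# For orientation, a minimal sketch of the global-MAD fallback rule the
# spike tests above assume. This is a hypothetical reimplementation, not
# the real logic (_count_spikes also has a rolling-window mode this sketch
# omits); only positive, loss-up deltas can ever exceed the threshold.
def _mad_spike_sketch(deltas: np.ndarray, sigma: float, eps: float = 1e-12) -> int:
    if deltas.size == 0:
        return 0
    # Median absolute deviation: a robust spread estimate for the deltas.
    mad = float(np.median(np.abs(deltas - np.median(deltas))))
    # Count loss-increase deltas beyond sigma * MAD (eps guards a zero MAD).
    return int(np.count_nonzero(deltas > sigma * max(mad, eps)))

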
class TestDownsampledCurve:
    def test_short_curve_unchanged(self) -> None:
        steps = np.array([0, 1, 2, 3])
        losses = np.array([5.0, 4.0, 3.0, 2.0])
        out = _downsampled_curve(steps, losses, cap=10)
        assert len(out) == 4
        assert out == [(0, 5.0), (1, 4.0), (2, 3.0), (3, 2.0)]

    def test_long_curve_capped_with_endpoints_preserved(self) -> None:
        steps = np.arange(2000)
        losses = np.linspace(5.0, 0.5, 2000)
        out = _downsampled_curve(steps, losses, cap=100)
        assert len(out) <= 110  # loose bound: cap plus a small endpoint/rounding allowance
        assert out[0][0] == 0
        assert out[-1][0] == 1999
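
    # Assumed shape of the downsampling contract pinned by these tests:
    # roughly evenly spaced samples with both endpoints always kept, e.g.
    #
    #   idx = np.unique(np.linspace(0, len(steps) - 1, cap).astype(int))
    #
    # np.unique sorts and deduplicates, so indices 0 and len - 1 survive
    # any cap; the exact implementation may differ.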


class TestVerdictFromMetrics:
    def _spec(self, **kwargs: object) -> object:
        from dlm_sway.probes.training_drift import TrainingDriftSpec

        return TrainingDriftSpec(name="td", kind="training_drift", **kwargs)  # type: ignore[arg-type]

    def test_pass_when_all_thresholds_clear(self) -> None:
        spec = self._spec()
        v, _, msg = _verdict_from_metrics(
            {
                "smoothness": 0.9,
                "convergence_ratio": 0.4,
                "instability_events": 0,
                "final_loss": 0.5,
            },
            spec,  # type: ignore[arg-type]
        )
        assert v == Verdict.PASS
        assert "smoothness=0.90" in msg
        assert "warnings:" not in msg

    def test_warn_lists_each_failed_threshold(self) -> None:
        spec = self._spec()
        v, _, msg = _verdict_from_metrics(
            {
                "smoothness": 0.5,
                "convergence_ratio": 0.9,
                "instability_events": 3,
                "final_loss": 4.5,
            },
            spec,  # type: ignore[arg-type]
        )
        assert v == Verdict.WARN
        assert "smoothness=0.50" in msg
        assert "convergence_ratio=0.90" in msg
        assert "instability_events=3" in msg


# ---------------------------------------------------------------------------
# Collect steps (JSONL parsing edge cases)
# ---------------------------------------------------------------------------


class TestRealDlmFixture:
    """Validate the probe against a JSONL captured from a real dlm run.

    The fixture under ``tests/fixtures/dlm_train_log_fixture.jsonl`` is
    a captured-from-disk shape: leading banner, an interleaved
    ``type=delta`` (doc-change record), 30 ``type=step`` records, and
    a closing ``type=run_complete``. If this test breaks, dlm's log
    format has shifted and the probe needs an update — that's
    exactly the regression signal we want.
    """

    def test_parses_real_fixture_to_pass_verdict(self, tmp_path: Path) -> None:
        fixture = (
            Path(__file__).resolve().parent.parent / "fixtures" / "dlm_train_log_fixture.jsonl"
        )
        store = tmp_path / "store"
        store.mkdir()
        (store / "logs").mkdir()
        (store / "logs" / "train-000001-20260426T062514.jsonl").write_bytes(fixture.read_bytes())

        probe, spec = build_probe(
            {
                "name": "td",
                "kind": "training_drift",
                "store_path": str(store),
                # The fixture's tail flattens out (loss converges), so
                # the curve has a stable plateau. A permissive convergence
                # threshold keeps the assertion focused on format compatibility.
435 "assert_convergence_ratio_lte": 0.5,
436 }
437 )
438 result = probe.run(spec, RunContext())
439 assert result.verdict == Verdict.PASS, result.message
440 assert result.evidence["num_steps"] == 30
441 assert result.evidence["instability_events"] == 0
442 # Final loss was 1.911 in the fixture; just check the right
443 # ballpark so future fixture tweaks don't spuriously fail.
444 assert 1.8 < result.evidence["final_loss"] < 2.0
445 assert result.evidence["initial_loss"] > 5.0
446
447
448 class TestCollectSteps:
449 def test_filters_non_step_records(self, tmp_path: Path) -> None:
450 log = tmp_path / "train-000001.jsonl"
451 log.write_text(
452 "\n".join(
453 [
454 json.dumps({"type": "banner", "run_id": 1}),
455 json.dumps({"type": "step", "step": 0, "loss": 5.0}),
456 json.dumps({"type": "delta", "new": [], "removed": []}),
457 json.dumps({"type": "step", "step": 1, "loss": 4.0}),
458 json.dumps({"type": "run_complete", "elapsed_seconds": 10.0}),
459 ]
460 )
461 + "\n",
462 encoding="utf-8",
463 )
464 out = _collect_steps([log])
465 assert out == {0: 5.0, 1: 4.0}
466
467 def test_missing_step_key_skipped(self, tmp_path: Path) -> None:
468 """A 'step' record missing required fields is dropped — the
469 parser doesn't crash the run on a single bad record."""
470 log = tmp_path / "train.jsonl"
471 log.write_text(
472 "\n".join(
473 [
474 json.dumps({"type": "step", "loss": 5.0}), # no `step`
475 json.dumps({"type": "step", "step": 1, "loss": 4.0}),
476 json.dumps({"type": "step", "step": 2}), # no `loss`
477 ]
478 )
479 + "\n",
480 encoding="utf-8",
481 )
482 out = _collect_steps([log])
483 assert out == {1: 4.0}
484
485 def test_nan_loss_recorded_as_inf(self, tmp_path: Path) -> None:
486 """NaN loss should land as +inf in the curve so the spike
487 detector flags it as instability without numpy NaN poisoning."""
488 log = tmp_path / "train.jsonl"
489 log.write_text(
490 "\n".join(
491 [
492 json.dumps({"type": "step", "step": 0, "loss": 5.0}),
                    # Real dlm logs emit the bare NaN literal. Strict
                    # JSON doesn't permit it, but Python's json.loads
                    # accepts NaN by default, so write the line raw.
                    '{"type": "step", "step": 1, "loss": NaN}',
                ]
            )
            + "\n",
            encoding="utf-8",
        )
        out = _collect_steps([log])
        assert out[0] == 5.0
        assert math.isinf(out[1])
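
        # Mapping NaN to +inf is the assumed design here (per the docstring
        # above): the step-0 -> step-1 delta becomes +inf, which the
        # loss-up spike rule flags, whereas a raw NaN would poison the
        # median/mean statistics downstream.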

    def test_missing_file_raises(self, tmp_path: Path) -> None:
        with pytest.raises(TrainingDriftError, match="failed to read"):
            _collect_steps([tmp_path / "nonexistent.jsonl"])