1 """Unit tests for the ``gradient_ghost`` probe (Sprint 25, F01-style).
2
3 Builds synthetic ``training_state.pt`` + ``adapter_model.safetensors``
4 fixtures so every verdict branch (PASS / FAIL / WARN / SKIP / ERROR)
5 runs without needing a real dlm install. The end-to-end check against
6 a real dlm-store fixture lives in
7 ``tests/integration/test_probe_gradient_ghost.py``.
8 """

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

# torch + safetensors ride the [hf] extra. Skip the whole module
# when missing rather than fail collection — same idiom as
# tests/unit/test_mlx_convert.py.
torch = pytest.importorskip("torch", reason="needs the [hf] extra (torch)")
safetensors_numpy = pytest.importorskip(
    "safetensors.numpy", reason="needs the [hf] extra (safetensors)"
)

from dlm_sway.core.errors import (  # noqa: E402 — import-after-skip
    BackendNotAvailableError,
    MissingTrainingStateError,
)
from dlm_sway.core.result import Verdict  # noqa: E402
from dlm_sway.probes._param_id_mapping import (  # noqa: E402
    ParamMappingError,
    map_param_ids_to_layers,
)
from dlm_sway.probes._training_state import (  # noqa: E402
    TrainingStateError,
    load_training_state,
)
from dlm_sway.probes.base import RunContext, build_probe  # noqa: E402
from dlm_sway.probes.gradient_ghost import GradientGhostProbe  # noqa: E402


def _write_synthetic_safetensors(
    dst: Path,
    *,
    num_layers: int = 4,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    rank: int = 8,
    in_features: int = 64,
    out_features: int = 64,
) -> int:
    """Write a PEFT-shaped safetensors fixture next to the training
    state. Returns the total number of weight keys (matches the
    expected number of optimizer-state params)."""
    weights: dict[str, np.ndarray] = {}
    for layer_idx in range(num_layers):
        for mod in target_modules:
            base = f"base_model.model.model.layers.{layer_idx}.self_attn.{mod}"
            weights[f"{base}.lora_A.weight"] = np.zeros((rank, in_features), dtype=np.float32)
            weights[f"{base}.lora_B.weight"] = np.zeros((out_features, rank), dtype=np.float32)
    safetensors_numpy.save_file(weights, str(dst / "adapter_model.safetensors"))
    return len(weights)

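# The fixture above emits keys layer-major, then module, then lora_A before
# lora_B. With num_layers=2 and target_modules=("q_proj",) that gives:
#
#   pid 0: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
#   pid 1: base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight
#   pid 2: base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight
#   pid 3: base_model.model.model.layers.1.self_attn.q_proj.lora_B.weight
#
# Several tests below assume optimizer param-ids follow this key order; that
# assumption is exactly what TestParamIdMapping checks directly.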

def _write_synthetic_training_state(
    dst: Path,
    *,
    global_step: int,
    num_params: int,
    exp_avg_sq_per_param: list[float] | None = None,
    nan_per_param: bool = False,
) -> None:
    """Write a minimal ``training_state.pt`` whose shape matches
    dlm's contract.

    ``exp_avg_sq_per_param`` lets a test plant per-param means (one
    float per param-id) for the per-layer ratio branches.
    ``nan_per_param=True`` sets every exp_avg_sq tensor to NaN
    (exercises the all-NaN FAIL branch).
    """
    if exp_avg_sq_per_param is None:
        exp_avg_sq_per_param = [1.0] * num_params

    state_dict: dict[int, dict[str, object]] = {}
    for pid, sq_mean in enumerate(exp_avg_sq_per_param):
        if nan_per_param:
            tensor = torch.full((4,), float("nan"), dtype=torch.float32)
        else:
            tensor = torch.full((4,), float(sq_mean), dtype=torch.float32)
        state_dict[pid] = {
            "step": torch.tensor(float(global_step)),
            "exp_avg": torch.zeros((4,), dtype=torch.float32),
            "exp_avg_sq": tensor,
        }

    payload = {
        "optimizer_state_dict": {
            "state": state_dict,
            "param_groups": [{"lr": 1e-4, "params": list(range(num_params))}],
        },
        "scheduler_state_dict": {},
        "scaler_state_dict": None,
        "torch_rng_state": torch.zeros(8, dtype=torch.uint8),
        "cuda_rng_state": None,
        "numpy_rng_state": None,
        "python_random_state": None,
        "global_step": global_step,
        "epoch": float(global_step),
        "best_val_loss": float("inf"),
        "dlm_manifest_hash": None,
        "base_model_revision": "deadbeef",
        "pinned_versions": {"torch": "2.11.0"},
        "use_qlora": False,
    }
    torch.save(payload, str(dst / "training_state.pt"))

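# Sketch of what the loader side sees for the fixture above (illustrative
# REPL usage only, not exercised by these tests):
#
#   payload = torch.load(adapter / "training_state.pt")
#   payload["global_step"]                           # e.g. 200
#   state = payload["optimizer_state_dict"]["state"]
#   state[0]["exp_avg_sq"].mean()                    # tensor(1.) for the flat case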

# === Tests ===


class TestProbeRegistry:
    def test_kind_registered(self) -> None:
        """Probe must be discoverable via build_probe."""
        probe, _ = build_probe(
            {"name": "x", "kind": "gradient_ghost", "adapter_path": "/nonexistent"}
        )
        assert isinstance(probe, GradientGhostProbe)

    def test_needs_backend_false(self) -> None:
        """needs_backend=False enables the runner's skip-backend path."""
        assert GradientGhostProbe.needs_backend is False

    def test_category_calibration(self) -> None:
        """Category must match the sprint's classification."""
        assert GradientGhostProbe.category == "calibration"


class TestVerdictLadder:
    """Each branch in the verdict ladder gets its own test."""

    def test_pass_when_global_step_high_and_distribution_flat(self, tmp_path: Path) -> None:
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # Flat distribution — every param has the same exp_avg_sq.
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=[1.0] * num_keys,
        )

        probe, spec = build_probe(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter)}
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.PASS
        assert result.evidence["global_step"] == 200
        assert result.evidence["frac_layers_undertrained"] == 0.0

    def test_fail_when_global_step_below_threshold(self, tmp_path: Path) -> None:
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys)

        probe, spec = build_probe(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter)}
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.FAIL
        assert result.evidence["global_step"] == 2
        assert result.evidence["primary_signal"] == "global_step_below_threshold"
        assert "severely undertrained" in (result.message or "")

    def test_fail_when_all_exp_avg_sq_nan(self, tmp_path: Path) -> None:
        """Even when global_step >= threshold, an all-NaN per-param
        exp_avg_sq triggers its own FAIL branch — training propagated
        nothing."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            nan_per_param=True,
        )

        probe, spec = build_probe(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter)}
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.FAIL
        assert result.evidence["primary_signal"] == "all_optimizer_state_nan"
        assert result.evidence["num_nonfinite_exp_avg_sq"] == num_keys

    def test_warn_when_some_layers_high_but_under_threshold(self, tmp_path: Path) -> None:
        """A heavy-tailed exp_avg_sq distribution where < layer_failure_frac
        of layers cross the per-layer threshold → WARN."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        # 4 layers × 2 modules × 2 factors = 16 params.
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # Flat baseline; bump layer 0's params to 3× (above the 2×
        # threshold) but only 1 of 4 layers crosses (25%).
        magnitudes = [1.0] * num_keys
        for pid in range(4):  # First 4 params = layer 0
            magnitudes[pid] = 3.0
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=magnitudes,
        )

        probe, spec = build_probe(
            {
                "name": "gg",
                "kind": "gradient_ghost",
                "adapter_path": str(adapter),
                "layer_failure_frac": 0.5,  # Need >50% to FAIL.
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.WARN
        assert result.evidence["num_layers_undertrained"] == 1
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.25)

    def test_fail_when_too_many_layers_high(self, tmp_path: Path) -> None:
        """When more than layer_failure_frac of layers cross the
        per-layer threshold, the secondary signal also FAILs."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # First 12 params = first 3 layers all bumped 3×; last layer flat.
        magnitudes = [3.0] * 12 + [1.0] * 4
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=magnitudes,
        )

        probe, spec = build_probe(
            {
                "name": "gg",
                "kind": "gradient_ghost",
                "adapter_path": str(adapter),
                "layer_failure_frac": 0.3,
            }
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.FAIL
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.75)

    def test_skip_when_training_state_missing(self, tmp_path: Path) -> None:
        """No training_state.pt → SKIP (legitimate for non-dlm
        adapters), not ERROR."""
        adapter = tmp_path / "adapter-no-state"
        adapter.mkdir()
        # adapter_model.safetensors doesn't matter — probe SKIPs first.
        probe, spec = build_probe(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter)}
        )
        result = probe.run(spec, RunContext())
        assert result.verdict == Verdict.SKIP
        assert "training_state.pt" in (result.message or "")


class TestParamIdMapping:
    """The layer-grouping helper is exercised indirectly via probe
    runs above; this class adds direct coverage of edge cases."""

    def test_correct_layer_groupings(self, tmp_path: Path) -> None:
        adapter = tmp_path / "a"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=3, target_modules=("q_proj",))
        # 3 layers × 1 module × 2 factors = 6 keys.
        assert num_keys == 6
        grouping = map_param_ids_to_layers(adapter, num_params=num_keys)
        assert grouping.num_layers == 3
        assert grouping.params_per_layer == 2
        assert [grouping.layer_of[i] for i in range(6)] == [0, 0, 1, 1, 2, 2]
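        # The [0, 0, 1, 1, 2, 2] layout is consistent with a simple
        # pid // params_per_layer grouping; how the helper computes it is
        # its own business, this test only pins the observable contract.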

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        adapter = tmp_path / "empty"
        adapter.mkdir()
        with pytest.raises(ParamMappingError, match="missing"):
            map_param_ids_to_layers(adapter, num_params=10)

    def test_mismatched_param_count_raises(self, tmp_path: Path) -> None:
        adapter = tmp_path / "a"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=2)
        # Pretend the optimizer has fewer params than safetensors keys.
        with pytest.raises(ParamMappingError, match="adapter / state mismatch"):
            map_param_ids_to_layers(adapter, num_params=num_keys - 2)


class TestTrainingStateLoader:
    def test_missing_file_raises_typed(self, tmp_path: Path) -> None:
        with pytest.raises(MissingTrainingStateError):
            load_training_state(tmp_path)

    def test_corrupt_pickle_raises_typed(self, tmp_path: Path) -> None:
        (tmp_path / "training_state.pt").write_bytes(b"not a pickle")
        with pytest.raises(TrainingStateError, match="failed to torch.load"):
            load_training_state(tmp_path)

    def test_unexpected_top_level_shape_raises(self, tmp_path: Path) -> None:
        torch.save({"foo": "bar"}, str(tmp_path / "training_state.pt"))
        with pytest.raises(TrainingStateError, match="missing 'optimizer_state_dict'"):
            load_training_state(tmp_path)


class TestRunnerSkipsBackend:
    """S25 P5 — runner contract when only no-backend probes are scheduled."""

    def test_runs_with_none_backend_when_only_pre_run_probes(self, tmp_path: Path) -> None:
        """A spec containing only gradient_ghost runs with backend=None."""
        from dlm_sway.core.model import ModelSpec
        from dlm_sway.suite.runner import run as run_suite
        from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec

        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys)

        spec = SwaySpec(
            version=1,
            models=SuiteModels(
                base=ModelSpec(base="dummy", kind="dummy"),
                ft=ModelSpec(base="dummy", kind="dummy", adapter=adapter),
            ),
            defaults=SuiteDefaults(seed=0),
            suite=[
                {
                    "name": "gg",
                    "kind": "gradient_ghost",
                    "adapter_path": str(adapter),
                }
            ],
        )
        result = run_suite(spec, backend=None, spec_path="<test>")
        assert len(result.probes) == 1
        assert result.probes[0].verdict == Verdict.FAIL
        assert result.backend_stats == {}  # No backend means no stats.

    def test_raises_when_backend_required_but_none(self, tmp_path: Path) -> None:
        """A spec with delta_kl + None backend → BackendNotAvailableError."""
        from dlm_sway.core.model import ModelSpec
        from dlm_sway.suite.runner import run as run_suite
        from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec

        spec = SwaySpec(
            version=1,
            models=SuiteModels(
                base=ModelSpec(base="dummy", kind="dummy"),
                ft=ModelSpec(base="dummy", kind="dummy"),
            ),
            defaults=SuiteDefaults(seed=0),
            suite=[
                {"name": "dk", "kind": "delta_kl", "prompts": ["x"]},
            ],
        )
        with pytest.raises(BackendNotAvailableError, match="delta_kl"):
            run_suite(spec, backend=None, spec_path="<test>")