| 1 | """Unit tests for the ``gradient_ghost`` probe (Sprint 25, F01-style). |
| 2 | |
| 3 | Builds synthetic ``training_state.pt`` + ``adapter_model.safetensors`` |
| 4 | fixtures so every verdict branch (PASS / FAIL / WARN / SKIP / ERROR) |
| 5 | runs without needing a real dlm install. The end-to-end check against |
| 6 | a real dlm-store fixture lives in |
| 7 | ``tests/integration/test_probe_gradient_ghost.py``. |
| 8 | """ |
| 9 | |
| 10 | from __future__ import annotations |
| 11 | |
| 12 | from pathlib import Path |
| 13 | |
| 14 | import numpy as np |
| 15 | import pytest |
| 16 | |
| 17 | # torch + safetensors ride the [hf] extra. Skip the whole module |
| 18 | # when missing rather than fail collection — same idiom as |
| 19 | # tests/unit/test_mlx_convert.py. |
| 20 | torch = pytest.importorskip("torch", reason="needs the [hf] extra (torch)") |
| 21 | safetensors_numpy = pytest.importorskip( |
| 22 | "safetensors.numpy", reason="needs the [hf] extra (safetensors)" |
| 23 | ) |
| 24 | |
| 25 | from dlm_sway.core.errors import ( # noqa: E402 — import-after-skip |
| 26 | BackendNotAvailableError, |
| 27 | MissingTrainingStateError, |
| 28 | ) |
| 29 | from dlm_sway.core.result import Verdict # noqa: E402 |
| 30 | from dlm_sway.probes._param_id_mapping import ( # noqa: E402 |
| 31 | ParamMappingError, |
| 32 | map_param_ids_to_layers, |
| 33 | ) |
| 34 | from dlm_sway.probes._training_state import ( # noqa: E402 |
| 35 | TrainingStateError, |
| 36 | load_training_state, |
| 37 | ) |
| 38 | from dlm_sway.probes.base import RunContext, build_probe # noqa: E402 |
| 39 | from dlm_sway.probes.gradient_ghost import GradientGhostProbe # noqa: E402 |
| 40 | |
| 41 | |
def _write_synthetic_safetensors(
    dst: Path,
    *,
    num_layers: int = 4,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    rank: int = 8,
    in_features: int = 64,
    out_features: int = 64,
) -> int:
    """Create a PEFT-shaped ``adapter_model.safetensors`` fixture in *dst*.

    Returns the total number of weight keys, which equals the number of
    optimizer-state params the probe expects to see.
    """
    # lora_A precedes lora_B per module — key insertion order is what the
    # param-id mapping relies on, so keep it stable.
    factor_shapes = {
        "lora_A.weight": (rank, in_features),
        "lora_B.weight": (out_features, rank),
    }
    tensors: dict[str, np.ndarray] = {}
    for layer_idx in range(num_layers):
        for module_name in target_modules:
            prefix = f"base_model.model.model.layers.{layer_idx}.self_attn.{module_name}"
            for suffix, shape in factor_shapes.items():
                tensors[f"{prefix}.{suffix}"] = np.zeros(shape, dtype=np.float32)
    safetensors_numpy.save_file(tensors, str(dst / "adapter_model.safetensors"))
    return len(tensors)
| 62 | |
| 63 | |
| 64 | def _write_synthetic_training_state( |
| 65 | dst: Path, |
| 66 | *, |
| 67 | global_step: int, |
| 68 | num_params: int, |
| 69 | exp_avg_sq_per_param: list[float] | None = None, |
| 70 | nan_per_param: bool = False, |
| 71 | ) -> None: |
| 72 | """Write a minimal ``training_state.pt`` whose shape matches |
| 73 | dlm's contract. |
| 74 | |
| 75 | ``exp_avg_sq_per_param`` lets a test plant per-param means (one |
| 76 | float per param-id) for the per-layer ratio branches. |
| 77 | ``nan_per_param=True`` sets every exp_avg_sq tensor to NaN |
| 78 | (proves the all-NaN FAIL branch). |
| 79 | """ |
| 80 | if exp_avg_sq_per_param is None: |
| 81 | exp_avg_sq_per_param = [1.0] * num_params |
| 82 | |
| 83 | state_dict: dict[int, dict[str, object]] = {} |
| 84 | for pid, sq_mean in enumerate(exp_avg_sq_per_param): |
| 85 | if nan_per_param: |
| 86 | tensor = torch.full((4,), float("nan"), dtype=torch.float32) |
| 87 | else: |
| 88 | tensor = torch.full((4,), float(sq_mean), dtype=torch.float32) |
| 89 | state_dict[pid] = { |
| 90 | "step": torch.tensor(float(global_step)), |
| 91 | "exp_avg": torch.zeros((4,), dtype=torch.float32), |
| 92 | "exp_avg_sq": tensor, |
| 93 | } |
| 94 | |
| 95 | payload = { |
| 96 | "optimizer_state_dict": { |
| 97 | "state": state_dict, |
| 98 | "param_groups": [{"lr": 1e-4, "params": list(range(num_params))}], |
| 99 | }, |
| 100 | "scheduler_state_dict": {}, |
| 101 | "scaler_state_dict": None, |
| 102 | "torch_rng_state": torch.zeros(8, dtype=torch.uint8), |
| 103 | "cuda_rng_state": None, |
| 104 | "numpy_rng_state": None, |
| 105 | "python_random_state": None, |
| 106 | "global_step": global_step, |
| 107 | "epoch": float(global_step), |
| 108 | "best_val_loss": float("inf"), |
| 109 | "dlm_manifest_hash": None, |
| 110 | "base_model_revision": "deadbeef", |
| 111 | "pinned_versions": {"torch": "2.11.0"}, |
| 112 | "use_qlora": False, |
| 113 | } |
| 114 | torch.save(payload, str(dst / "training_state.pt")) |
| 115 | |
| 116 | |
| 117 | # === Tests === |
| 118 | |
| 119 | |
class TestProbeRegistry:
    """Registration metadata for the gradient_ghost probe kind."""

    def test_kind_registered(self) -> None:
        """build_probe must resolve the 'gradient_ghost' kind to our class."""
        instance, _spec = build_probe(
            {"name": "x", "kind": "gradient_ghost", "adapter_path": "/nonexistent"}
        )
        assert isinstance(instance, GradientGhostProbe)

    def test_needs_backend_false(self) -> None:
        """The runner's skip-backend path relies on needs_backend=False."""
        assert GradientGhostProbe.needs_backend is False

    def test_category_calibration(self) -> None:
        """Classification must match what the sprint assigned."""
        assert GradientGhostProbe.category == "calibration"
| 135 | |
| 136 | |
class TestVerdictLadder:
    """One test per branch of the probe's verdict ladder."""

    @staticmethod
    def _run_probe(adapter_dir: Path, **extra_cfg: object):
        """Build a gradient_ghost probe against *adapter_dir* and run it."""
        cfg: dict[str, object] = {
            "name": "gg",
            "kind": "gradient_ghost",
            "adapter_path": str(adapter_dir),
        }
        cfg.update(extra_cfg)
        probe, spec = build_probe(cfg)
        return probe.run(spec, RunContext())

    def test_pass_when_global_step_high_and_distribution_flat(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # A perfectly flat exp_avg_sq distribution: no layer stands out.
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=[1.0] * key_count,
        )

        outcome = self._run_probe(adapter_dir)
        assert outcome.verdict == Verdict.PASS
        assert outcome.evidence["global_step"] == 200
        assert outcome.evidence["frac_layers_undertrained"] == 0.0

    def test_fail_when_global_step_below_threshold(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        _write_synthetic_training_state(adapter_dir, global_step=2, num_params=key_count)

        outcome = self._run_probe(adapter_dir)
        assert outcome.verdict == Verdict.FAIL
        assert outcome.evidence["global_step"] == 2
        assert outcome.evidence["primary_signal"] == "global_step_below_threshold"
        assert "severely undertrained" in (outcome.message or "")

    def test_fail_when_all_exp_avg_sq_nan(self, tmp_path: Path) -> None:
        """A high-enough global_step does not save an all-NaN optimizer
        state — nothing propagated during training, so a separate FAIL
        branch fires."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            nan_per_param=True,
        )

        outcome = self._run_probe(adapter_dir)
        assert outcome.verdict == Verdict.FAIL
        assert outcome.evidence["primary_signal"] == "all_optimizer_state_nan"
        assert outcome.evidence["num_nonfinite_exp_avg_sq"] == key_count

    def test_warn_when_some_layers_high_but_under_threshold(self, tmp_path: Path) -> None:
        """A heavy-tailed exp_avg_sq distribution where fewer than
        layer_failure_frac of the layers cross the per-layer threshold
        yields a WARN."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        # 4 layers × 2 modules × 2 factors = 16 params.
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # Flat baseline, then lift layer 0's params to 3× the mean — that
        # clears the 2× per-layer threshold, but only 1 of 4 layers (25%).
        sq_means = [1.0] * key_count
        sq_means[:4] = [3.0] * 4  # First 4 params = layer 0
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=sq_means,
        )

        outcome = self._run_probe(adapter_dir, layer_failure_frac=0.5)  # Need >50% to FAIL.
        assert outcome.verdict == Verdict.WARN
        assert outcome.evidence["num_layers_undertrained"] == 1
        assert outcome.evidence["frac_layers_undertrained"] == pytest.approx(0.25)

    def test_fail_when_too_many_layers_high(self, tmp_path: Path) -> None:
        """Once more than layer_failure_frac of the layers cross the
        per-layer threshold, the secondary signal FAILs too."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # First 12 params (= first 3 layers) lifted 3×; last layer flat.
        sq_means = [3.0] * 12 + [1.0] * 4
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=sq_means,
        )

        outcome = self._run_probe(adapter_dir, layer_failure_frac=0.3)
        assert outcome.verdict == Verdict.FAIL
        assert outcome.evidence["frac_layers_undertrained"] == pytest.approx(0.75)

    def test_skip_when_training_state_missing(self, tmp_path: Path) -> None:
        """Missing training_state.pt is legitimate for non-dlm adapters,
        so the probe SKIPs rather than ERRORs."""
        adapter_dir = tmp_path / "adapter-no-state"
        adapter_dir.mkdir()
        # No safetensors fixture needed — the probe bails out before it.
        outcome = self._run_probe(adapter_dir)
        assert outcome.verdict == Verdict.SKIP
        assert "training_state.pt" in (outcome.message or "")
| 267 | |
| 268 | |
class TestParamIdMapping:
    """Direct edge-case coverage for the layer-grouping helper (the
    happy path is already exercised through the probe runs above)."""

    def test_correct_layer_groupings(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "a"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(
            adapter_dir, num_layers=3, target_modules=("q_proj",)
        )
        # 3 layers × 1 module × 2 LoRA factors = 6 keys.
        assert key_count == 6
        mapping = map_param_ids_to_layers(adapter_dir, num_params=key_count)
        assert mapping.num_layers == 3
        assert mapping.params_per_layer == 2
        assert [mapping.layer_of[pid] for pid in range(6)] == [0, 0, 1, 1, 2, 2]

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "empty"
        adapter_dir.mkdir()
        with pytest.raises(ParamMappingError, match="missing"):
            map_param_ids_to_layers(adapter_dir, num_params=10)

    def test_mismatched_param_count_raises(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "a"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=2)
        # Claim the optimizer held fewer params than safetensors keys.
        with pytest.raises(ParamMappingError, match="adapter / state mismatch"):
            map_param_ids_to_layers(adapter_dir, num_params=key_count - 2)
| 297 | |
| 298 | |
class TestTrainingStateLoader:
    """Typed-error contract of ``load_training_state``."""

    def test_missing_file_raises_typed(self, tmp_path: Path) -> None:
        """An absent training_state.pt surfaces as the dedicated error."""
        with pytest.raises(MissingTrainingStateError):
            load_training_state(tmp_path)

    def test_corrupt_pickle_raises_typed(self, tmp_path: Path) -> None:
        """Garbage bytes become TrainingStateError, not a raw pickle error."""
        state_file = tmp_path / "training_state.pt"
        state_file.write_bytes(b"not a pickle")
        with pytest.raises(TrainingStateError, match="failed to torch.load"):
            load_training_state(tmp_path)

    def test_unexpected_top_level_shape_raises(self, tmp_path: Path) -> None:
        """A payload without the expected top-level key is rejected."""
        torch.save({"foo": "bar"}, str(tmp_path / "training_state.pt"))
        with pytest.raises(TrainingStateError, match="missing 'optimizer_state_dict'"):
            load_training_state(tmp_path)
| 313 | |
| 314 | |
| 315 | class TestRunnerSkipsBackend: |
| 316 | """S25 P5 — runner contract when only no-backend probes scheduled.""" |
| 317 | |
| 318 | def test_runs_with_none_backend_when_only_pre_run_probes(self, tmp_path: Path) -> None: |
| 319 | """A spec containing only gradient_ghost runs with backend=None.""" |
| 320 | from dlm_sway.core.model import ModelSpec |
| 321 | from dlm_sway.suite.runner import run as run_suite |
| 322 | from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec |
| 323 | |
| 324 | adapter = tmp_path / "adapter" |
| 325 | adapter.mkdir() |
| 326 | num_keys = _write_synthetic_safetensors(adapter, num_layers=4) |
| 327 | _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys) |
| 328 | |
| 329 | spec = SwaySpec( |
| 330 | version=1, |
| 331 | models=SuiteModels( |
| 332 | base=ModelSpec(base="dummy", kind="dummy"), |
| 333 | ft=ModelSpec(base="dummy", kind="dummy", adapter=adapter), |
| 334 | ), |
| 335 | defaults=SuiteDefaults(seed=0), |
| 336 | suite=[ |
| 337 | { |
| 338 | "name": "gg", |
| 339 | "kind": "gradient_ghost", |
| 340 | "adapter_path": str(adapter), |
| 341 | } |
| 342 | ], |
| 343 | ) |
| 344 | result = run_suite(spec, backend=None, spec_path="<test>") |
| 345 | assert len(result.probes) == 1 |
| 346 | assert result.probes[0].verdict == Verdict.FAIL |
| 347 | assert result.backend_stats == {} # No backend means no stats. |
| 348 | |
| 349 | def test_raises_when_backend_required_but_none(self, tmp_path: Path) -> None: |
| 350 | """A spec with delta_kl + None backend → BackendNotAvailableError.""" |
| 351 | from dlm_sway.core.model import ModelSpec |
| 352 | from dlm_sway.suite.runner import run as run_suite |
| 353 | from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec |
| 354 | |
| 355 | spec = SwaySpec( |
| 356 | version=1, |
| 357 | models=SuiteModels( |
| 358 | base=ModelSpec(base="dummy", kind="dummy"), |
| 359 | ft=ModelSpec(base="dummy", kind="dummy"), |
| 360 | ), |
| 361 | defaults=SuiteDefaults(seed=0), |
| 362 | suite=[ |
| 363 | {"name": "dk", "kind": "delta_kl", "prompts": ["x"]}, |
| 364 | ], |
| 365 | ) |
| 366 | with pytest.raises(BackendNotAvailableError, match="delta_kl"): |
| 367 | run_suite(spec, backend=None, spec_path="<test>") |