@@ -0,0 +1,367 @@ |
| 1 | +"""Unit tests for the ``gradient_ghost`` probe (Sprint 25, F01-style). |
| 2 | + |
| 3 | +Builds synthetic ``training_state.pt`` + ``adapter_model.safetensors`` |
| 4 | +fixtures so every verdict branch (PASS / FAIL / WARN / SKIP / ERROR) |
| 5 | +runs without needing a real dlm install. The end-to-end check against |
| 6 | +a real dlm-store fixture lives in |
| 7 | +``tests/integration/test_probe_gradient_ghost.py``. |
| 8 | +""" |
| 9 | + |
| 10 | +from __future__ import annotations |
| 11 | + |
| 12 | +from pathlib import Path |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import pytest |
| 16 | + |
| 17 | +# torch + safetensors ride the [hf] extra. Skip the whole module |
| 18 | +# when missing rather than fail collection — same idiom as |
| 19 | +# tests/unit/test_mlx_convert.py. |
| 20 | +torch = pytest.importorskip("torch", reason="needs the [hf] extra (torch)") |
| 21 | +safetensors_numpy = pytest.importorskip( |
| 22 | + "safetensors.numpy", reason="needs the [hf] extra (safetensors)" |
| 23 | +) |
| 24 | + |
| 25 | +from dlm_sway.core.errors import ( # noqa: E402 — import-after-skip |
| 26 | + BackendNotAvailableError, |
| 27 | + MissingTrainingStateError, |
| 28 | +) |
| 29 | +from dlm_sway.core.result import Verdict # noqa: E402 |
| 30 | +from dlm_sway.probes._param_id_mapping import ( # noqa: E402 |
| 31 | + ParamMappingError, |
| 32 | + map_param_ids_to_layers, |
| 33 | +) |
| 34 | +from dlm_sway.probes._training_state import ( # noqa: E402 |
| 35 | + TrainingStateError, |
| 36 | + load_training_state, |
| 37 | +) |
| 38 | +from dlm_sway.probes.base import RunContext, build_probe # noqa: E402 |
| 39 | +from dlm_sway.probes.gradient_ghost import GradientGhostProbe # noqa: E402 |
| 40 | + |
| 41 | + |
def _write_synthetic_safetensors(
    dst: Path,
    *,
    num_layers: int = 4,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    rank: int = 8,
    in_features: int = 64,
    out_features: int = 64,
) -> int:
    """Create a PEFT-shaped ``adapter_model.safetensors`` under *dst*.

    One ``lora_A``/``lora_B`` pair is written per (layer, module)
    combination. Returns the total number of weight keys, which is
    the expected number of optimizer-state params.
    """
    a_shape = (rank, in_features)
    b_shape = (out_features, rank)
    tensors: dict[str, np.ndarray] = {}
    for layer_no in range(num_layers):
        for module_name in target_modules:
            prefix = f"base_model.model.model.layers.{layer_no}.self_attn.{module_name}"
            tensors[f"{prefix}.lora_A.weight"] = np.zeros(a_shape, dtype=np.float32)
            tensors[f"{prefix}.lora_B.weight"] = np.zeros(b_shape, dtype=np.float32)
    safetensors_numpy.save_file(tensors, str(dst / "adapter_model.safetensors"))
    return len(tensors)
| 62 | + |
| 63 | + |
| 64 | +def _write_synthetic_training_state( |
| 65 | + dst: Path, |
| 66 | + *, |
| 67 | + global_step: int, |
| 68 | + num_params: int, |
| 69 | + exp_avg_sq_per_param: list[float] | None = None, |
| 70 | + nan_per_param: bool = False, |
| 71 | +) -> None: |
| 72 | + """Write a minimal ``training_state.pt`` whose shape matches |
| 73 | + dlm's contract. |
| 74 | + |
| 75 | + ``exp_avg_sq_per_param`` lets a test plant per-param means (one |
| 76 | + float per param-id) for the per-layer ratio branches. |
| 77 | + ``nan_per_param=True`` sets every exp_avg_sq tensor to NaN |
| 78 | + (proves the all-NaN FAIL branch). |
| 79 | + """ |
| 80 | + if exp_avg_sq_per_param is None: |
| 81 | + exp_avg_sq_per_param = [1.0] * num_params |
| 82 | + |
| 83 | + state_dict: dict[int, dict[str, object]] = {} |
| 84 | + for pid, sq_mean in enumerate(exp_avg_sq_per_param): |
| 85 | + if nan_per_param: |
| 86 | + tensor = torch.full((4,), float("nan"), dtype=torch.float32) |
| 87 | + else: |
| 88 | + tensor = torch.full((4,), float(sq_mean), dtype=torch.float32) |
| 89 | + state_dict[pid] = { |
| 90 | + "step": torch.tensor(float(global_step)), |
| 91 | + "exp_avg": torch.zeros((4,), dtype=torch.float32), |
| 92 | + "exp_avg_sq": tensor, |
| 93 | + } |
| 94 | + |
| 95 | + payload = { |
| 96 | + "optimizer_state_dict": { |
| 97 | + "state": state_dict, |
| 98 | + "param_groups": [{"lr": 1e-4, "params": list(range(num_params))}], |
| 99 | + }, |
| 100 | + "scheduler_state_dict": {}, |
| 101 | + "scaler_state_dict": None, |
| 102 | + "torch_rng_state": torch.zeros(8, dtype=torch.uint8), |
| 103 | + "cuda_rng_state": None, |
| 104 | + "numpy_rng_state": None, |
| 105 | + "python_random_state": None, |
| 106 | + "global_step": global_step, |
| 107 | + "epoch": float(global_step), |
| 108 | + "best_val_loss": float("inf"), |
| 109 | + "dlm_manifest_hash": None, |
| 110 | + "base_model_revision": "deadbeef", |
| 111 | + "pinned_versions": {"torch": "2.11.0"}, |
| 112 | + "use_qlora": False, |
| 113 | + } |
| 114 | + torch.save(payload, str(dst / "training_state.pt")) |
| 115 | + |
| 116 | + |
| 117 | +# === Tests === |
| 118 | + |
| 119 | + |
class TestProbeRegistry:
    def test_kind_registered(self) -> None:
        """build_probe must resolve this kind to our probe class."""
        cfg = {"name": "x", "kind": "gradient_ghost", "adapter_path": "/nonexistent"}
        instance, _spec = build_probe(cfg)
        assert isinstance(instance, GradientGhostProbe)

    def test_needs_backend_false(self) -> None:
        """needs_backend=False enables the runner's skip-backend path."""
        flag = GradientGhostProbe.needs_backend
        assert flag is False

    def test_category_calibration(self) -> None:
        """Category must match the sprint's classification."""
        category = GradientGhostProbe.category
        assert category == "calibration"
| 135 | + |
| 136 | + |
class TestVerdictLadder:
    """One test per branch of the probe's verdict ladder."""

    @staticmethod
    def _run(config: dict) -> object:
        """Build the probe from *config* and execute it once."""
        probe, spec = build_probe(config)
        return probe.run(spec, RunContext())

    def test_pass_when_global_step_high_and_distribution_flat(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # Perfectly flat distribution: every param shares one exp_avg_sq.
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=[1.0] * key_count,
        )

        result = self._run(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter_dir)}
        )
        assert result.verdict == Verdict.PASS
        assert result.evidence["global_step"] == 200
        assert result.evidence["frac_layers_undertrained"] == 0.0

    def test_fail_when_global_step_below_threshold(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        _write_synthetic_training_state(adapter_dir, global_step=2, num_params=key_count)

        result = self._run(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter_dir)}
        )
        assert result.verdict == Verdict.FAIL
        assert result.evidence["global_step"] == 2
        assert result.evidence["primary_signal"] == "global_step_below_threshold"
        assert "severely undertrained" in (result.message or "")

    def test_fail_when_all_exp_avg_sq_nan(self, tmp_path: Path) -> None:
        """Even with global_step >= threshold, every NaN per-param
        triggers a separate FAIL branch — training propagated nothing."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            nan_per_param=True,
        )

        result = self._run(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter_dir)}
        )
        assert result.verdict == Verdict.FAIL
        assert result.evidence["primary_signal"] == "all_optimizer_state_nan"
        assert result.evidence["num_nonfinite_exp_avg_sq"] == key_count

    def test_warn_when_some_layers_high_but_under_threshold(self, tmp_path: Path) -> None:
        """A heavy-tailed exp_avg_sq distribution where < layer_failure_frac
        of layers cross the per-layer threshold → WARN."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        # 4 layers × 2 modules × 2 factors = 16 params.
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # Flat baseline; layer 0 (the first 4 param-ids) bumped to 3×,
        # above the 2× threshold, but only 1 of 4 layers crosses (25%).
        magnitudes = [3.0] * 4 + [1.0] * (key_count - 4)
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=magnitudes,
        )

        result = self._run(
            {
                "name": "gg",
                "kind": "gradient_ghost",
                "adapter_path": str(adapter_dir),
                "layer_failure_frac": 0.5,  # Need >50% to FAIL.
            }
        )
        assert result.verdict == Verdict.WARN
        assert result.evidence["num_layers_undertrained"] == 1
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.25)

    def test_fail_when_too_many_layers_high(self, tmp_path: Path) -> None:
        """When more than layer_failure_frac of layers cross the
        per-layer threshold, secondary signal also FAILs."""
        adapter_dir = tmp_path / "adapter"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=4)
        # First 12 params = first 3 layers all bumped 3×; last layer flat.
        magnitudes = [3.0] * 12 + [1.0] * 4
        _write_synthetic_training_state(
            adapter_dir,
            global_step=200,
            num_params=key_count,
            exp_avg_sq_per_param=magnitudes,
        )

        result = self._run(
            {
                "name": "gg",
                "kind": "gradient_ghost",
                "adapter_path": str(adapter_dir),
                "layer_failure_frac": 0.3,
            }
        )
        assert result.verdict == Verdict.FAIL
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.75)

    def test_skip_when_training_state_missing(self, tmp_path: Path) -> None:
        """No training_state.pt → SKIP (legitimate for non-dlm
        adapters), not ERROR."""
        adapter_dir = tmp_path / "adapter-no-state"
        adapter_dir.mkdir()
        # adapter_model.safetensors doesn't matter — probe SKIPs first.
        result = self._run(
            {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter_dir)}
        )
        assert result.verdict == Verdict.SKIP
        assert "training_state.pt" in (result.message or "")
| 267 | + |
| 268 | + |
class TestParamIdMapping:
    """Direct edge-case coverage for the layer-grouping helper; the
    probe tests above already exercise it indirectly."""

    def test_correct_layer_groupings(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "a"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(
            adapter_dir, num_layers=3, target_modules=("q_proj",)
        )
        # 3 layers × 1 module × 2 factors = 6 keys.
        assert key_count == 6
        grouping = map_param_ids_to_layers(adapter_dir, num_params=key_count)
        assert (grouping.num_layers, grouping.params_per_layer) == (3, 2)
        assert [grouping.layer_of[pid] for pid in range(6)] == [0, 0, 1, 1, 2, 2]

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        empty_dir = tmp_path / "empty"
        empty_dir.mkdir()
        with pytest.raises(ParamMappingError, match="missing"):
            map_param_ids_to_layers(empty_dir, num_params=10)

    def test_mismatched_param_count_raises(self, tmp_path: Path) -> None:
        adapter_dir = tmp_path / "a"
        adapter_dir.mkdir()
        key_count = _write_synthetic_safetensors(adapter_dir, num_layers=2)
        # Pretend the optimizer has fewer params than safetensors keys.
        with pytest.raises(ParamMappingError, match="adapter / state mismatch"):
            map_param_ids_to_layers(adapter_dir, num_params=key_count - 2)
| 297 | + |
| 298 | + |
class TestTrainingStateLoader:
    def test_missing_file_raises_typed(self, tmp_path: Path) -> None:
        """An absent training_state.pt surfaces as the typed error."""
        with pytest.raises(MissingTrainingStateError):
            load_training_state(tmp_path)

    def test_corrupt_pickle_raises_typed(self, tmp_path: Path) -> None:
        """Garbage bytes must map to TrainingStateError, not leak a raw torch error."""
        target = tmp_path / "training_state.pt"
        target.write_bytes(b"not a pickle")
        with pytest.raises(TrainingStateError, match="failed to torch.load"):
            load_training_state(tmp_path)

    def test_unexpected_top_level_shape_raises(self, tmp_path: Path) -> None:
        """A loadable payload without the expected key is also typed."""
        torch.save({"foo": "bar"}, str(tmp_path / "training_state.pt"))
        with pytest.raises(TrainingStateError, match="missing 'optimizer_state_dict'"):
            load_training_state(tmp_path)
| 313 | + |
| 314 | + |
| 315 | +class TestRunnerSkipsBackend: |
| 316 | + """S25 P5 — runner contract when only no-backend probes scheduled.""" |
| 317 | + |
| 318 | + def test_runs_with_none_backend_when_only_pre_run_probes(self, tmp_path: Path) -> None: |
| 319 | + """A spec containing only gradient_ghost runs with backend=None.""" |
| 320 | + from dlm_sway.core.model import ModelSpec |
| 321 | + from dlm_sway.suite.runner import run as run_suite |
| 322 | + from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec |
| 323 | + |
| 324 | + adapter = tmp_path / "adapter" |
| 325 | + adapter.mkdir() |
| 326 | + num_keys = _write_synthetic_safetensors(adapter, num_layers=4) |
| 327 | + _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys) |
| 328 | + |
| 329 | + spec = SwaySpec( |
| 330 | + version=1, |
| 331 | + models=SuiteModels( |
| 332 | + base=ModelSpec(base="dummy", kind="dummy"), |
| 333 | + ft=ModelSpec(base="dummy", kind="dummy", adapter=adapter), |
| 334 | + ), |
| 335 | + defaults=SuiteDefaults(seed=0), |
| 336 | + suite=[ |
| 337 | + { |
| 338 | + "name": "gg", |
| 339 | + "kind": "gradient_ghost", |
| 340 | + "adapter_path": str(adapter), |
| 341 | + } |
| 342 | + ], |
| 343 | + ) |
| 344 | + result = run_suite(spec, backend=None, spec_path="<test>") |
| 345 | + assert len(result.probes) == 1 |
| 346 | + assert result.probes[0].verdict == Verdict.FAIL |
| 347 | + assert result.backend_stats == {} # No backend means no stats. |
| 348 | + |
| 349 | + def test_raises_when_backend_required_but_none(self, tmp_path: Path) -> None: |
| 350 | + """A spec with delta_kl + None backend → BackendNotAvailableError.""" |
| 351 | + from dlm_sway.core.model import ModelSpec |
| 352 | + from dlm_sway.suite.runner import run as run_suite |
| 353 | + from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec |
| 354 | + |
| 355 | + spec = SwaySpec( |
| 356 | + version=1, |
| 357 | + models=SuiteModels( |
| 358 | + base=ModelSpec(base="dummy", kind="dummy"), |
| 359 | + ft=ModelSpec(base="dummy", kind="dummy"), |
| 360 | + ), |
| 361 | + defaults=SuiteDefaults(seed=0), |
| 362 | + suite=[ |
| 363 | + {"name": "dk", "kind": "delta_kl", "prompts": ["x"]}, |
| 364 | + ], |
| 365 | + ) |
| 366 | + with pytest.raises(BackendNotAvailableError, match="delta_kl"): |
| 367 | + run_suite(spec, backend=None, spec_path="<test>") |