tenseleyflow/sway / 67dd8d0

Browse files

probes/gradient_ghost: min-baseline ratio + 17 unit tests covering the verdict ladder (S25 P7)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
67dd8d053fb6c4588d003b584caa72084500f5ee
Parents
58b0322
Tree
26f735b

2 changed files

StatusFile+-
M src/dlm_sway/probes/gradient_ghost.py 26 12
A tests/unit/test_probe_gradient_ghost.py 367 0
src/dlm_sway/probes/gradient_ghost.pymodified
@@ -46,9 +46,14 @@ populated by the dlm autogen bridge from
4646
 
4747
 - ``min_steps_threshold = 50`` — below this is severely undertrained.
4848
 - ``undertrained_layer_ratio = 2.0`` — a layer's mean ``exp_avg_sq``
49
-  must be > 2× the global mean to count as "still has high gradient
50
-  variance." Multiplicative threshold (relative, not absolute) so
51
-  the probe is architecture-agnostic.
49
+  must be > 2× the **minimum** layer's mean to count as "still has
50
+  high gradient variance." We compare against the min (not the
51
+  global mean) because mean rises with outliers — under a global-
52
+  mean baseline a spiked layer's ratio asymptotically caps at
53
+  ``N/K`` (K spiking layers out of N), so no spiked layer can ever
54
+  cross ``ratio`` once K ≥ N/ratio. Min-baseline gives a stable "K layers are anomalously
55
+  high vs the calmest layer" signal regardless of how many layers
56
+  spike.
5257
 - ``layer_failure_frac = 0.3`` — WARN if more than 30% of layers
5358
   cross the per-layer threshold.
5459
 """
@@ -82,8 +87,9 @@ class GradientGhostSpec(ProbeSpec):
8287
     """``global_step`` below this → FAIL (severely undertrained)."""
8388
     undertrained_layer_ratio: float = Field(default=2.0, gt=1.0)
8489
     """A layer counts as 'high gradient variance' when its mean
85
-    ``exp_avg_sq`` exceeds ``ratio * global_mean``. Strictly > 1
86
-    (a value of 1 would always flag half the layers)."""
90
+    ``exp_avg_sq`` exceeds ``ratio * min_layer_mean``. The min-
91
+    baseline (rather than global mean) is robust to outliers — see
92
+    the module docstring for the asymptotic-cap reasoning."""
8793
     layer_failure_frac: float = Field(default=0.3, ge=0.0, le=1.0)
8894
     """WARN when more than this fraction of layers cross the
8995
     ``undertrained_layer_ratio`` threshold."""
@@ -188,6 +194,7 @@ class GradientGhostProbe(Probe):
188194
         global_mean = statistics.fmean(finite_means)
189195
         per_layer_means: dict[int, float] = {}
190196
         per_layer_undertrained: list[int] = []
197
+        baseline_min: float = 0.0
191198
 
192199
         if grouping is not None and global_mean > 0.0:
193200
             # Group finite per-param means by layer index.
@@ -201,18 +208,25 @@ class GradientGhostProbe(Probe):
201208
             for layer_idx, vals in buckets.items():
202209
                 if not vals:
203210
                     continue
204
-                layer_mean = statistics.fmean(vals)
205
-                per_layer_means[layer_idx] = layer_mean
206
-                if layer_mean > spec.undertrained_layer_ratio * global_mean:
207
-                    per_layer_undertrained.append(layer_idx)
211
+                per_layer_means[layer_idx] = statistics.fmean(vals)
212
+            # Min-baseline ratio (see module docstring on why we use
213
+            # min instead of global mean — global-mean ratio is
214
+            # asymptotically capped and can't catch the case where K
215
+            # layers all spike together).
216
+            if per_layer_means:
217
+                baseline_min = min(per_layer_means.values())
218
+                if baseline_min > 0.0:
219
+                    for layer_idx, mean in per_layer_means.items():
220
+                        if mean > spec.undertrained_layer_ratio * baseline_min:
221
+                            per_layer_undertrained.append(layer_idx)
208222
 
209223
         frac_undertrained = len(per_layer_undertrained) / layer_count if layer_count > 0 else 0.0
210224
 
211
-        # Top-3 worst layers (highest ratio) — useful evidence even
212
-        # when no layer crosses the threshold.
225
+        # Top-3 worst layers (highest ratio vs baseline_min) — useful
226
+        # evidence even when no layer crosses the threshold.
213227
         ranked_layers = sorted(per_layer_means.items(), key=lambda kv: -kv[1])[:3]
214228
         worst_layers = [
215
-            {"layer": idx, "ratio": (mean / global_mean) if global_mean > 0 else None}
229
+            {"layer": idx, "ratio": (mean / baseline_min) if baseline_min > 0 else None}
216230
             for idx, mean in ranked_layers
217231
         ]
218232
 
tests/unit/test_probe_gradient_ghost.pyadded
@@ -0,0 +1,367 @@
1
+"""Unit tests for the ``gradient_ghost`` probe (Sprint 25, F01-style).
2
+
3
+Builds synthetic ``training_state.pt`` + ``adapter_model.safetensors``
4
+fixtures so every verdict branch (PASS / FAIL / WARN / SKIP / ERROR)
5
+runs without needing a real dlm install. The end-to-end check against
6
+a real dlm-store fixture lives in
7
+``tests/integration/test_probe_gradient_ghost.py``.
8
+"""
9
+
10
+from __future__ import annotations
11
+
12
+from pathlib import Path
13
+
14
+import numpy as np
15
+import pytest
16
+
17
+# torch + safetensors ride the [hf] extra. Skip the whole module
18
+# when missing rather than fail collection — same idiom as
19
+# tests/unit/test_mlx_convert.py.
20
+torch = pytest.importorskip("torch", reason="needs the [hf] extra (torch)")
21
+safetensors_numpy = pytest.importorskip(
22
+    "safetensors.numpy", reason="needs the [hf] extra (safetensors)"
23
+)
24
+
25
+from dlm_sway.core.errors import (  # noqa: E402 — import-after-skip
26
+    BackendNotAvailableError,
27
+    MissingTrainingStateError,
28
+)
29
+from dlm_sway.core.result import Verdict  # noqa: E402
30
+from dlm_sway.probes._param_id_mapping import (  # noqa: E402
31
+    ParamMappingError,
32
+    map_param_ids_to_layers,
33
+)
34
+from dlm_sway.probes._training_state import (  # noqa: E402
35
+    TrainingStateError,
36
+    load_training_state,
37
+)
38
+from dlm_sway.probes.base import RunContext, build_probe  # noqa: E402
39
+from dlm_sway.probes.gradient_ghost import GradientGhostProbe  # noqa: E402
40
+
41
+
42
def _write_synthetic_safetensors(
    dst: Path,
    *,
    num_layers: int = 4,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    rank: int = 8,
    in_features: int = 64,
    out_features: int = 64,
) -> int:
    """Drop a PEFT-shaped ``adapter_model.safetensors`` fixture into
    ``dst`` and return the number of weight keys written (one key per
    expected optimizer-state param)."""
    # Each (layer, module) pair contributes the two LoRA factors.
    factor_shapes = (("lora_A", (rank, in_features)), ("lora_B", (out_features, rank)))
    weights = {
        f"base_model.model.model.layers.{layer}.self_attn.{mod}.{factor}.weight": np.zeros(
            shape, dtype=np.float32
        )
        for layer in range(num_layers)
        for mod in target_modules
        for factor, shape in factor_shapes
    }
    safetensors_numpy.save_file(weights, str(dst / "adapter_model.safetensors"))
    return len(weights)
62
+
63
+
64
+def _write_synthetic_training_state(
65
+    dst: Path,
66
+    *,
67
+    global_step: int,
68
+    num_params: int,
69
+    exp_avg_sq_per_param: list[float] | None = None,
70
+    nan_per_param: bool = False,
71
+) -> None:
72
+    """Write a minimal ``training_state.pt`` whose shape matches
73
+    dlm's contract.
74
+
75
+    ``exp_avg_sq_per_param`` lets a test plant per-param means (one
76
+    float per param-id) for the per-layer ratio branches.
77
+    ``nan_per_param=True`` sets every exp_avg_sq tensor to NaN
78
+    (proves the all-NaN FAIL branch).
79
+    """
80
+    if exp_avg_sq_per_param is None:
81
+        exp_avg_sq_per_param = [1.0] * num_params
82
+
83
+    state_dict: dict[int, dict[str, object]] = {}
84
+    for pid, sq_mean in enumerate(exp_avg_sq_per_param):
85
+        if nan_per_param:
86
+            tensor = torch.full((4,), float("nan"), dtype=torch.float32)
87
+        else:
88
+            tensor = torch.full((4,), float(sq_mean), dtype=torch.float32)
89
+        state_dict[pid] = {
90
+            "step": torch.tensor(float(global_step)),
91
+            "exp_avg": torch.zeros((4,), dtype=torch.float32),
92
+            "exp_avg_sq": tensor,
93
+        }
94
+
95
+    payload = {
96
+        "optimizer_state_dict": {
97
+            "state": state_dict,
98
+            "param_groups": [{"lr": 1e-4, "params": list(range(num_params))}],
99
+        },
100
+        "scheduler_state_dict": {},
101
+        "scaler_state_dict": None,
102
+        "torch_rng_state": torch.zeros(8, dtype=torch.uint8),
103
+        "cuda_rng_state": None,
104
+        "numpy_rng_state": None,
105
+        "python_random_state": None,
106
+        "global_step": global_step,
107
+        "epoch": float(global_step),
108
+        "best_val_loss": float("inf"),
109
+        "dlm_manifest_hash": None,
110
+        "base_model_revision": "deadbeef",
111
+        "pinned_versions": {"torch": "2.11.0"},
112
+        "use_qlora": False,
113
+    }
114
+    torch.save(payload, str(dst / "training_state.pt"))
115
+
116
+
117
+# === Tests ===
118
+
119
+
120
class TestProbeRegistry:
    def test_kind_registered(self) -> None:
        """The registry must hand back a GradientGhostProbe for this kind."""
        config = {"name": "x", "kind": "gradient_ghost", "adapter_path": "/nonexistent"}
        built, _spec = build_probe(config)
        assert isinstance(built, GradientGhostProbe)

    def test_needs_backend_false(self) -> None:
        """needs_backend=False enables the runner's skip-backend path."""
        assert GradientGhostProbe.needs_backend is False

    def test_category_calibration(self) -> None:
        """Category must match the sprint's classification."""
        assert GradientGhostProbe.category == "calibration"
135
+
136
+
137
class TestVerdictLadder:
    """One test per branch of the probe's verdict ladder."""

    @staticmethod
    def _run(adapter: Path, **overrides: object):
        """Build the gradient_ghost probe for *adapter* (plus any spec
        overrides) and return the result of a single run."""
        config = {"name": "gg", "kind": "gradient_ghost", "adapter_path": str(adapter)}
        config.update(overrides)
        probe, spec = build_probe(config)
        return probe.run(spec, RunContext())

    def test_pass_when_global_step_high_and_distribution_flat(self, tmp_path: Path) -> None:
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # Flat distribution — every param has the same exp_avg_sq.
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=[1.0] * num_keys,
        )

        result = self._run(adapter)
        assert result.verdict == Verdict.PASS
        assert result.evidence["global_step"] == 200
        assert result.evidence["frac_layers_undertrained"] == 0.0

    def test_fail_when_global_step_below_threshold(self, tmp_path: Path) -> None:
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys)

        result = self._run(adapter)
        assert result.verdict == Verdict.FAIL
        assert result.evidence["global_step"] == 2
        assert result.evidence["primary_signal"] == "global_step_below_threshold"
        assert "severely undertrained" in (result.message or "")

    def test_fail_when_all_exp_avg_sq_nan(self, tmp_path: Path) -> None:
        """A high global_step is not enough: when every per-param
        exp_avg_sq is NaN, a separate FAIL branch fires — training
        propagated nothing."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            nan_per_param=True,
        )

        result = self._run(adapter)
        assert result.verdict == Verdict.FAIL
        assert result.evidence["primary_signal"] == "all_optimizer_state_nan"
        assert result.evidence["num_nonfinite_exp_avg_sq"] == num_keys

    def test_warn_when_some_layers_high_but_under_threshold(self, tmp_path: Path) -> None:
        """A heavy-tailed exp_avg_sq distribution with fewer than
        layer_failure_frac of layers over the per-layer threshold
        lands on WARN."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        # 4 layers × 2 modules × 2 factors = 16 params.
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # Flat baseline; bump layer 0's params to 3× (above the 2×
        # threshold) but only 1 of 4 layers crosses (25%).
        magnitudes = [3.0 if pid < 4 else 1.0 for pid in range(num_keys)]  # First 4 params = layer 0
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=magnitudes,
        )

        result = self._run(adapter, layer_failure_frac=0.5)  # Need >50% to FAIL.
        assert result.verdict == Verdict.WARN
        assert result.evidence["num_layers_undertrained"] == 1
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.25)

    def test_fail_when_too_many_layers_high(self, tmp_path: Path) -> None:
        """More than layer_failure_frac of layers over the per-layer
        threshold tips the secondary signal to FAIL as well."""
        adapter = tmp_path / "adapter"
        adapter.mkdir()
        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
        # First 12 params = first 3 layers all bumped 3×; last layer flat.
        magnitudes = [3.0] * 12 + [1.0] * 4
        _write_synthetic_training_state(
            adapter,
            global_step=200,
            num_params=num_keys,
            exp_avg_sq_per_param=magnitudes,
        )

        result = self._run(adapter, layer_failure_frac=0.3)
        assert result.verdict == Verdict.FAIL
        assert result.evidence["frac_layers_undertrained"] == pytest.approx(0.75)

    def test_skip_when_training_state_missing(self, tmp_path: Path) -> None:
        """A missing training_state.pt is legitimate for non-dlm
        adapters, so the probe SKIPs rather than ERRORs."""
        adapter = tmp_path / "adapter-no-state"
        adapter.mkdir()
        # adapter_model.safetensors doesn't matter — probe SKIPs first.
        result = self._run(adapter)
        assert result.verdict == Verdict.SKIP
        assert "training_state.pt" in (result.message or "")
267
+
268
+
269
class TestParamIdMapping:
    """Edge cases for the layer-grouping helper — the probe runs above
    already exercise it indirectly."""

    def test_correct_layer_groupings(self, tmp_path: Path) -> None:
        adapter = tmp_path / "a"
        adapter.mkdir()
        key_count = _write_synthetic_safetensors(adapter, num_layers=3, target_modules=("q_proj",))
        # 3 layers × 1 module × 2 factors = 6 keys.
        assert key_count == 6
        grouping = map_param_ids_to_layers(adapter, num_params=key_count)
        assert grouping.num_layers == 3
        assert grouping.params_per_layer == 2
        expected = [0, 0, 1, 1, 2, 2]
        assert [grouping.layer_of[pid] for pid in range(6)] == expected

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        adapter = tmp_path / "empty"
        adapter.mkdir()
        with pytest.raises(ParamMappingError, match="missing"):
            map_param_ids_to_layers(adapter, num_params=10)

    def test_mismatched_param_count_raises(self, tmp_path: Path) -> None:
        adapter = tmp_path / "a"
        adapter.mkdir()
        key_count = _write_synthetic_safetensors(adapter, num_layers=2)
        # Pretend the optimizer has fewer params than safetensors keys.
        with pytest.raises(ParamMappingError, match="adapter / state mismatch"):
            map_param_ids_to_layers(adapter, num_params=key_count - 2)
297
+
298
+
299
class TestTrainingStateLoader:
    """Typed-error contract of the training-state loader."""

    def test_missing_file_raises_typed(self, tmp_path: Path) -> None:
        """An absent training_state.pt maps to MissingTrainingStateError."""
        with pytest.raises(MissingTrainingStateError):
            load_training_state(tmp_path)

    def test_corrupt_pickle_raises_typed(self, tmp_path: Path) -> None:
        """Garbage bytes surface as a TrainingStateError, not a raw torch error."""
        bogus = tmp_path / "training_state.pt"
        bogus.write_bytes(b"not a pickle")
        with pytest.raises(TrainingStateError, match="failed to torch.load"):
            load_training_state(tmp_path)

    def test_unexpected_top_level_shape_raises(self, tmp_path: Path) -> None:
        """A loadable payload without the expected keys is rejected."""
        torch.save({"foo": "bar"}, str(tmp_path / "training_state.pt"))
        with pytest.raises(TrainingStateError, match="missing 'optimizer_state_dict'"):
            load_training_state(tmp_path)
313
+
314
+
315
+class TestRunnerSkipsBackend:
316
+    """S25 P5 — runner contract when only no-backend probes scheduled."""
317
+
318
+    def test_runs_with_none_backend_when_only_pre_run_probes(self, tmp_path: Path) -> None:
319
+        """A spec containing only gradient_ghost runs with backend=None."""
320
+        from dlm_sway.core.model import ModelSpec
321
+        from dlm_sway.suite.runner import run as run_suite
322
+        from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec
323
+
324
+        adapter = tmp_path / "adapter"
325
+        adapter.mkdir()
326
+        num_keys = _write_synthetic_safetensors(adapter, num_layers=4)
327
+        _write_synthetic_training_state(adapter, global_step=2, num_params=num_keys)
328
+
329
+        spec = SwaySpec(
330
+            version=1,
331
+            models=SuiteModels(
332
+                base=ModelSpec(base="dummy", kind="dummy"),
333
+                ft=ModelSpec(base="dummy", kind="dummy", adapter=adapter),
334
+            ),
335
+            defaults=SuiteDefaults(seed=0),
336
+            suite=[
337
+                {
338
+                    "name": "gg",
339
+                    "kind": "gradient_ghost",
340
+                    "adapter_path": str(adapter),
341
+                }
342
+            ],
343
+        )
344
+        result = run_suite(spec, backend=None, spec_path="<test>")
345
+        assert len(result.probes) == 1
346
+        assert result.probes[0].verdict == Verdict.FAIL
347
+        assert result.backend_stats == {}  # No backend means no stats.
348
+
349
+    def test_raises_when_backend_required_but_none(self, tmp_path: Path) -> None:
350
+        """A spec with delta_kl + None backend → BackendNotAvailableError."""
351
+        from dlm_sway.core.model import ModelSpec
352
+        from dlm_sway.suite.runner import run as run_suite
353
+        from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec
354
+
355
+        spec = SwaySpec(
356
+            version=1,
357
+            models=SuiteModels(
358
+                base=ModelSpec(base="dummy", kind="dummy"),
359
+                ft=ModelSpec(base="dummy", kind="dummy"),
360
+            ),
361
+            defaults=SuiteDefaults(seed=0),
362
+            suite=[
363
+                {"name": "dk", "kind": "delta_kl", "prompts": ["x"]},
364
+            ],
365
+        )
366
+        with pytest.raises(BackendNotAvailableError, match="delta_kl"):
367
+            run_suite(spec, backend=None, spec_path="<test>")