tenseleyflow/sway / c15794c

Browse files

tests/integration: gradient_ghost real-store FAIL + synthetic converged PASS + runner skip-backend e2e (S25 P8)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
c15794cf28952551bad07c0460419b48e2ef882d
Parents
67dd8d0
Tree
4c3d091

1 changed file

StatusFile+-
A tests/integration/test_probe_gradient_ghost.py 191 0
tests/integration/test_probe_gradient_ghost.pyadded
@@ -0,0 +1,191 @@
1
+"""S25 — gradient_ghost integration tests.
2
+
3
+Two flavors:
4
+
5
+1. **Real-store (skipped on CI):** runs against a known-undertrained
6
+   adapter at ``~/.dlm/store/01KPPFAB2Z6DWCWY0QV702TSTX/`` if
7
+   present. This is the prove-the-value test the sprint DoD requires
8
+   on a real dlm-trained adapter. Skipped cleanly when the store is
9
+   absent so CI without local dlm install still passes.
10
+2. **Synthetic-converged (runs everywhere):** writes a fully-formed
11
+   converged training_state.pt + matching safetensors fixture and
12
+   asserts PASS. Pairs with the real-store FAIL case to give end-
13
+   to-end "FAIL on undertrained, PASS on converged" coverage in CI.
14
+
15
+Marked ``slow + online`` because building a synthetic converged
16
+training_state.pt requires torch-pickle round-tripping a real-shape
17
+optimizer state — heavier than a unit test should be.
18
+"""
19
+
20
+from __future__ import annotations
21
+
22
+from pathlib import Path
23
+
24
+import numpy as np
25
+import pytest
26
+
27
+torch = pytest.importorskip("torch", reason="needs the [hf] extra (torch)")
28
+safetensors_numpy = pytest.importorskip(
29
+    "safetensors.numpy", reason="needs the [hf] extra (safetensors)"
30
+)
31
+
32
+from dlm_sway.core.result import Verdict  # noqa: E402
33
+from dlm_sway.probes.base import RunContext, build_probe  # noqa: E402
34
+
35
+pytestmark = [pytest.mark.slow, pytest.mark.online]
36
+
37
+
38
+_REAL_STORE_PATH = (
39
+    Path.home() / ".dlm" / "store" / "01KPPFAB2Z6DWCWY0QV702TSTX" / "adapter" / "versions" / "v0001"
40
+)
41
+
42
+
43
+def test_real_undertrained_dlm_store_fails(tmp_path: Path) -> None:
44
+    """If a known dlm-trained undertrained adapter is on disk, the
45
+    probe must FAIL on it.
46
+
47
+    Skipped on machines without the local fixture (CI). The store
48
+    was the ground-truth artifact that drove the sprint design — it
49
+    was a real ``--max-steps 2`` smoke-test run.
50
+    """
51
+    if not (_REAL_STORE_PATH / "training_state.pt").exists():
52
+        pytest.skip(
53
+            f"no dlm store fixture at {_REAL_STORE_PATH} — skipping the "
54
+            "real-adapter prove-the-value test (synthetic test below "
55
+            "still runs)"
56
+        )
57
+
58
+    probe, spec = build_probe(
59
+        {
60
+            "name": "gg_real",
61
+            "kind": "gradient_ghost",
62
+            "adapter_path": str(_REAL_STORE_PATH),
63
+        }
64
+    )
65
+    result = probe.run(spec, RunContext())
66
+
67
+    assert result.verdict == Verdict.FAIL, (
68
+        f"expected FAIL on a known-undertrained dlm store, got {result.verdict}: {result.message}"
69
+    )
70
+    # The real fixture is global_step=2 — a clean primary-signal hit.
71
+    assert result.evidence["global_step"] < 50
72
+    assert result.evidence["primary_signal"] in (
73
+        "global_step_below_threshold",
74
+        "all_optimizer_state_nan",
75
+    )
76
+
77
+
78
+def _build_converged_fixture(adapter_dir: Path) -> int:
79
+    """Write a synthetic 'converged' adapter pair.
80
+
81
+    - safetensors with realistic per-layer LoRA tensor names
82
+    - training_state.pt with global_step=500 (well above threshold)
83
+      and a flat per-param exp_avg_sq distribution (no layer
84
+      crosses the per-layer ratio).
85
+    """
86
+    adapter_dir.mkdir(parents=True, exist_ok=True)
87
+    num_layers = 4
88
+    target_modules = ("q_proj", "v_proj")
89
+    rank = 8
90
+    in_features = 64
91
+
92
+    weights: dict[str, np.ndarray] = {}
93
+    for layer_idx in range(num_layers):
94
+        for mod in target_modules:
95
+            base = f"base_model.model.model.layers.{layer_idx}.self_attn.{mod}"
96
+            weights[f"{base}.lora_A.weight"] = np.zeros((rank, in_features), dtype=np.float32)
97
+            weights[f"{base}.lora_B.weight"] = np.zeros((in_features, rank), dtype=np.float32)
98
+    safetensors_numpy.save_file(weights, str(adapter_dir / "adapter_model.safetensors"))
99
+    num_keys = len(weights)
100
+
101
+    # Flat distribution: every param's exp_avg_sq is 0.1 (a small but
102
+    # finite value typical of a converged Adam state).
103
+    state_dict: dict[int, dict[str, object]] = {}
104
+    for pid in range(num_keys):
105
+        state_dict[pid] = {
106
+            "step": torch.tensor(500.0),
107
+            "exp_avg": torch.zeros((4,), dtype=torch.float32),
108
+            "exp_avg_sq": torch.full((4,), 0.1, dtype=torch.float32),
109
+        }
110
+
111
+    payload = {
112
+        "optimizer_state_dict": {
113
+            "state": state_dict,
114
+            "param_groups": [{"lr": 1e-4, "params": list(range(num_keys))}],
115
+        },
116
+        "scheduler_state_dict": {},
117
+        "scaler_state_dict": None,
118
+        "torch_rng_state": torch.zeros(8, dtype=torch.uint8),
119
+        "cuda_rng_state": None,
120
+        "numpy_rng_state": None,
121
+        "python_random_state": None,
122
+        "global_step": 500,
123
+        "epoch": 5.0,
124
+        "best_val_loss": 0.42,
125
+        "dlm_manifest_hash": None,
126
+        "base_model_revision": "synthetic-test-fixture",
127
+        "pinned_versions": {"torch": "2.11.0"},
128
+        "use_qlora": False,
129
+    }
130
+    torch.save(payload, str(adapter_dir / "training_state.pt"))
131
+    return num_keys
132
+
133
+
134
+def test_synthetic_converged_adapter_passes(tmp_path: Path) -> None:
135
+    """A hand-rolled converged training_state (global_step=500, flat
136
+    exp_avg_sq distribution) must PASS.
137
+
138
+    Together with the real-store FAIL test above, covers the
139
+    sprint's prove-the-value: 'undertrained → FAIL, converged → PASS'.
140
+    """
141
+    adapter_dir = tmp_path / "synthetic-converged"
142
+    _build_converged_fixture(adapter_dir)
143
+
144
+    probe, spec = build_probe(
145
+        {
146
+            "name": "gg_synth",
147
+            "kind": "gradient_ghost",
148
+            "adapter_path": str(adapter_dir),
149
+        }
150
+    )
151
+    result = probe.run(spec, RunContext())
152
+
153
+    assert result.verdict == Verdict.PASS, (
154
+        f"expected PASS on a synthetic converged adapter, got {result.verdict}: {result.message}"
155
+    )
156
+    assert result.evidence["global_step"] == 500
157
+    assert result.evidence["frac_layers_undertrained"] == 0.0
158
+    assert result.evidence["num_layers"] == 4
159
+
160
+
161
+def test_runner_skips_backend_for_pure_pre_run_suite(tmp_path: Path) -> None:
162
+    """End-to-end: a suite containing only gradient_ghost runs
163
+    successfully with backend=None. Confirms the S25 P5 runner
164
+    contract holds end-to-end (not just at the probe level)."""
165
+    from dlm_sway.core.model import ModelSpec
166
+    from dlm_sway.suite.runner import run as run_suite
167
+    from dlm_sway.suite.spec import SuiteDefaults, SuiteModels, SwaySpec
168
+
169
+    adapter_dir = tmp_path / "synthetic-converged"
170
+    _build_converged_fixture(adapter_dir)
171
+
172
+    spec = SwaySpec(
173
+        version=1,
174
+        models=SuiteModels(
175
+            base=ModelSpec(base="dummy", kind="dummy"),
176
+            ft=ModelSpec(base="dummy", kind="dummy", adapter=adapter_dir),
177
+        ),
178
+        defaults=SuiteDefaults(seed=0),
179
+        suite=[
180
+            {
181
+                "name": "gg",
182
+                "kind": "gradient_ghost",
183
+                "adapter_path": str(adapter_dir),
184
+            },
185
+        ],
186
+    )
187
+    result = run_suite(spec, backend=None, spec_path="<integration>")
188
+    assert len(result.probes) == 1
189
+    assert result.probes[0].verdict == Verdict.PASS
190
+    # No backend, no backend stats.
191
+    assert result.backend_stats == {}