Python · 13667 bytes Raw Blame History
1 """Tests for null-adapter calibration.
2
3 Covers: dummy backend ``as_null_adapter`` yields a plausibly noisy
4 view; ``NullAdapterProbe`` populates ``ctx.null_stats`` in a way
5 downstream probes pick up end-to-end; missing-capability SKIP path.
6 """
7
8 from __future__ import annotations
9
10 import numpy as np
11 import pytest
12
13 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
14 from dlm_sway.core.result import Verdict
15 from dlm_sway.core.scoring import NullCalibratedBackend
16 from dlm_sway.probes.base import RunContext, build_probe
17 from dlm_sway.suite.runner import run as run_suite
18 from dlm_sway.suite.spec import SwaySpec
19
20
def _diverging_backend() -> DummyDifferentialBackend:
    """Return a fresh dummy differential backend for the tests in this module.

    Both the ``base`` and ``ft`` sides are default-constructed
    ``DummyResponses`` instances.

    NOTE(review): the name suggests base and ft diverge, yet both are built
    identically here — presumably any divergence comes from the backend
    itself; confirm against ``DummyDifferentialBackend``.
    """
    return DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
25
26
class TestProtocolConformance:
    """Checks that the dummy backend satisfies the null-calibration protocol."""

    def test_dummy_is_null_calibrated(self) -> None:
        """The dummy backend passes an ``isinstance`` check against the protocol."""
        candidate = _diverging_backend()
        assert isinstance(candidate, NullCalibratedBackend)
30
31
class TestAsNullAdapter:
    """Checks on the view yielded by the dummy backend's ``as_null_adapter``."""

    def test_yields_perturbed_view(self) -> None:
        """The null-adapter view's next-token logprobs differ from the base view's."""
        backend = _diverging_backend()
        with backend.as_base() as base:
            base_dist = base.next_token_dist("hello")
        with backend.as_null_adapter(seed=0) as null:
            null_dist = null.next_token_dist("hello")
        # Only the *presence* of a perturbation is asserted here; this test
        # does not put an upper bound on its magnitude.
        assert not np.allclose(base_dist.logprobs, null_dist.logprobs)

    def test_different_seeds_yield_different_views(self) -> None:
        """Two distinct seeds produce distinct next-token logprobs for one prompt."""
        backend = _diverging_backend()
        with backend.as_null_adapter(seed=1) as v1:
            d1 = v1.next_token_dist("hello")
        with backend.as_null_adapter(seed=2) as v2:
            d2 = v2.next_token_dist("hello")
        assert not np.allclose(d1.logprobs, d2.logprobs)

    def test_view_exclusion_enforced(self) -> None:
        """Opening ``as_base`` while a null-adapter view is active raises RuntimeError."""
        backend = _diverging_backend()
        # ``pytest`` is imported at module level; no local import needed.
        with backend.as_null_adapter(seed=0), pytest.raises(RuntimeError):
            with backend.as_base():
                pass
57
58
class TestProbe:
    """Behavioural tests for the ``null_adapter`` probe.

    Each test builds the probe through ``build_probe`` (or runs it via the
    suite runner) against the dummy backend and inspects the returned
    verdict and ``evidence`` mapping.
    """

    def test_populates_null_stats(self) -> None:
        """Explicit `calibrate_kinds` calibrates regardless of suite order."""
        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 3,
                "calibrate_kinds": ["delta_kl"],
            }
        )
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.PASS
        stats = result.evidence["null_stats"]
        assert "delta_kl" in stats
        assert stats["delta_kl"]["n"] == 3.0
        assert stats["delta_kl"]["std"] > 0.0  # seeded perturbations produce variance

    def test_auto_populates_from_downstream_kinds(self) -> None:
        """When `calibrate_kinds` is empty, falls back to `ctx.downstream_kinds`."""
        backend = _diverging_backend()
        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
        ctx = RunContext(
            backend=backend,
            # paraphrase_invariance opts in (stable mean_verb under null);
            # prompt_collapse opts out (half_life is undefined under null).
            downstream_kinds=("delta_kl", "prompt_collapse", "paraphrase_invariance"),
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.PASS
        stats = result.evidence["null_stats"]
        # Downstream numeric kinds that opt in get stats.
        assert "delta_kl" in stats
        assert "paraphrase_invariance" in stats
        # prompt_collapse opts out — it shows up in skipped_kinds instead.
        assert "prompt_collapse" not in stats
        skipped = {s["kind"] for s in result.evidence["skipped_kinds"]}
        assert "prompt_collapse" in skipped

    def test_empty_calibrate_kinds_with_no_downstream_is_noop(self) -> None:
        """No kinds, no calibration — probe still PASSes with empty stats."""
        backend = _diverging_backend()
        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
        ctx = RunContext(backend=backend)  # no downstream_kinds
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.PASS
        assert result.evidence["null_stats"] == {}
        assert result.evidence["calibrated_kinds"] == []

    def test_unregistered_kind_is_silently_skipped(self) -> None:
        """An unknown kind is left out of ``null_stats``; known kinds still calibrate."""
        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 2,
                "calibrate_kinds": ["delta_kl", "nonexistent_kind"],
            }
        )
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        assert "delta_kl" in result.evidence["null_stats"]
        assert "nonexistent_kind" not in result.evidence["null_stats"]

    def test_opt_out_probe_is_reported_as_skipped(self) -> None:
        """A kind whose calibrate_spec returns None surfaces in skipped_kinds."""
        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 2,
                # adapter_revert.calibrate_spec returns None by default
                # (inherits from base), so we expect it to opt out.
                "calibrate_kinds": ["adapter_revert", "delta_kl"],
            }
        )
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        assert "delta_kl" in result.evidence["null_stats"]
        skipped = {s["kind"] for s in result.evidence["skipped_kinds"]}
        assert "adapter_revert" in skipped

    def test_runner_threads_null_stats_to_subsequent_probes(self) -> None:
        """End-to-end: null_adapter first → delta_kl picks up z-score path."""
        backend = _diverging_backend()
        raw_spec = SwaySpec.model_validate(
            {
                "version": 1,
                "models": {"base": {"base": "b"}, "ft": {"base": "b", "adapter": "/tmp/a"}},
                "suite": [
                    {
                        "name": "null",
                        "kind": "null_adapter",
                        "runs": 3,
                    },
                    {
                        "name": "dk",
                        "kind": "delta_kl",
                        "prompts": ["p1", "p2"],
                        "assert_z_gte": -10.0,  # permissive so we pass regardless
                    },
                ],
            }
        )
        result = run_suite(raw_spec, backend)
        assert len(result.probes) == 2
        null_result = result.probes[0]
        dk_result = result.probes[1]
        assert null_result.verdict == Verdict.PASS
        # The delta_kl probe should have computed a z_score because null_stats was present.
        assert dk_result.z_score is not None, (
            "delta_kl should have z-scored against null baseline, got "
            f"evidence={dk_result.evidence}, message={dk_result.message}"
        )

    def test_runner_threaded_null_stats_are_immutable(self) -> None:
        """B21: a probe shouldn't be able to mutate the stats other probes consume."""
        from types import MappingProxyType

        from dlm_sway.probes import delta_kl as dk_mod

        backend = _diverging_backend()
        raw_spec = SwaySpec.model_validate(
            {
                "version": 1,
                "models": {"base": {"base": "b"}, "ft": {"base": "b", "adapter": "/tmp/a"}},
                "suite": [
                    {"name": "null", "kind": "null_adapter", "runs": 2},
                    {
                        "name": "dk",
                        "kind": "delta_kl",
                        "prompts": ["q1"],
                        "assert_z_gte": -100.0,
                    },
                ],
            }
        )

        captured: dict[str, object] = {}
        original_run = dk_mod.DeltaKLProbe.run

        def _capturing_run(self, spec, ctx):
            # Record the stats object the runner hands to delta_kl, then
            # delegate to the real implementation.
            captured["null_stats"] = ctx.null_stats
            return original_run(self, spec, ctx)

        # The context manager undoes the patch on exit, even if the suite
        # run raises.
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(dk_mod.DeltaKLProbe, "run", _capturing_run)
            run_suite(raw_spec, backend)

        # Fail with a readable message (not a KeyError) if the patched
        # probe was never invoked.
        assert "null_stats" in captured, "delta_kl probe was never run by the suite"
        stats = captured["null_stats"]
        assert isinstance(stats, MappingProxyType), (
            f"expected MappingProxyType, got {type(stats).__name__}"
        )
        with pytest.raises(TypeError):
            stats["bogus"] = {"mean": 0.0, "std": 1.0, "n": 1.0}  # type: ignore[index]

    def test_cache_hit_short_circuits_calibration(self, tmp_path, monkeypatch) -> None:
        """A cached stats blob is loaded without re-running any probes."""
        # Point the cache root at a per-test temp dir.
        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))

        class _IdBackend(DummyDifferentialBackend):
            def cache_identity(self) -> str:
                return "test:id-backend"

        backend = _IdBackend(base=DummyResponses(), ft=DummyResponses())

        # First call: populates the cache.
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 2,
                "calibrate_kinds": ["delta_kl"],
            }
        )
        ctx = RunContext(backend=backend)
        r1 = probe.run(spec, ctx)
        assert r1.evidence["from_cache"] is False

        # Second call: same params, same identity → cache hit.
        r2 = probe.run(spec, ctx)
        assert r2.evidence["from_cache"] is True
        assert "delta_kl" in r2.evidence["null_stats"]

    def test_cache_disabled_forces_recompute(self, tmp_path, monkeypatch) -> None:
        """``cache=false`` bypasses the cache even if a prior run populated it."""
        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))

        class _IdBackend(DummyDifferentialBackend):
            def cache_identity(self) -> str:
                return "test:id-backend-2"

        backend = _IdBackend(base=DummyResponses(), ft=DummyResponses())
        probe, populating_spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 2,
                "calibrate_kinds": ["delta_kl"],
            }
        )
        # Prior run with default caching, so a blob may already exist.
        probe.run(populating_spec, RunContext(backend=backend))

        _, fresh_spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 2,
                "calibrate_kinds": ["delta_kl"],
                "cache": False,
            }
        )
        r = probe.run(fresh_spec, RunContext(backend=backend))
        assert r.evidence["from_cache"] is False

    def test_degenerate_calibration_flagged_and_refused(self) -> None:
        """F02 (Audit 03): identical raws or runs≤1 → ``degenerate: 1.0``
        in the stats dict, and the downstream z-score computation
        refuses instead of firing on a 1e-6 floor.

        Pre-F02 this test asserted ``std ≥ 1e-6`` + ``z is not None``,
        which is exactly the contract that produced the audit's
        +290,766σ observation on a leakage probe under ``runs: 1``.
        The fix flips both assertions.
        """
        from dlm_sway.probes._zscore import z_score

        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 1,  # single seed → degenerate by construction
                "calibrate_kinds": ["delta_kl"],
            }
        )
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.PASS
        stats = result.evidence["null_stats"]["delta_kl"]
        # Std floor is still 1e-6 (preserved for valid-but-tight
        # multi-seed nulls). What changed is the explicit
        # ``degenerate`` flag on the stats dict — ``runs: 1`` → True.
        assert stats["std"] == 1e-6
        assert stats["degenerate"] >= 0.5
        # Downstream z_score now refuses rather than emit runaway values.
        assert z_score(0.5, stats) is None

    def test_per_kind_stats_published(self) -> None:
        """Every calibrating kind gets its own (mean, std, n) triple."""
        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "null",
                "kind": "null_adapter",
                "runs": 3,
                "calibrate_kinds": ["delta_kl", "paraphrase_invariance"],
            }
        )
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        stats = result.evidence["null_stats"]
        for kind in ("delta_kl", "paraphrase_invariance"):
            assert kind in stats, f"missing {kind} in published stats"
            s = stats[kind]
            assert "mean" in s
            assert "std" in s
            assert "n" in s
            assert s["std"] >= 1e-6

    def test_skip_when_backend_not_null_calibrated(self) -> None:
        """A backend without ``as_null_adapter`` yields SKIP naming the protocol."""

        class _Bare:
            # Exposes only the base/finetuned views; deliberately has no
            # ``as_null_adapter``.
            def as_base(self):  # noqa: ANN202
                raise NotImplementedError

            def as_finetuned(self):  # noqa: ANN202
                raise NotImplementedError

        probe, spec = build_probe({"name": "null", "kind": "null_adapter"})
        ctx = RunContext(backend=_Bare())  # type: ignore[arg-type]
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "NullCalibratedBackend" in result.message