Python · 7260 bytes Raw Blame History
"""Tests for null-adapter calibration.

Covers: the dummy backend's ``as_null_adapter`` yields a plausibly noisy
view; ``NullAdapterProbe`` populates ``ctx.null_stats`` so that downstream
probes pick it up end-to-end; and the SKIP path taken when the backend does
not implement ``NullCalibratedBackend``.
"""
7
8 from __future__ import annotations
9
10 import numpy as np
11
12 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
13 from dlm_sway.core.result import Verdict
14 from dlm_sway.core.scoring import NullCalibratedBackend
15 from dlm_sway.probes.base import RunContext, build_probe
16 from dlm_sway.suite.runner import run as run_suite
17 from dlm_sway.suite.spec import SwaySpec
18
19
def _diverging_backend() -> DummyDifferentialBackend:
    """Return a dummy differential backend built from fresh response tables.

    NOTE(review): despite the name, both the base and fine-tuned sides are
    default-constructed ``DummyResponses()``; any base/ft divergence must come
    from ``DummyDifferentialBackend`` itself — confirm this is intended.
    """
    return DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
24
25
class TestProtocolConformance:
    """The dummy backend must satisfy the ``NullCalibratedBackend`` protocol."""

    def test_dummy_is_null_calibrated(self) -> None:
        backend = _diverging_backend()
        assert isinstance(backend, NullCalibratedBackend)
29
30
class TestAsNullAdapter:
    """Behaviour of the dummy backend's ``as_null_adapter`` view."""

    def test_yields_perturbed_view(self) -> None:
        """The null-adapter view must not reproduce the base distribution."""
        backend = _diverging_backend()
        with backend.as_base() as base:
            base_dist = base.next_token_dist("hello")
        with backend.as_null_adapter(seed=0) as null:
            null_dist = null.next_token_dist("hello")
        # Only the presence of a perturbation is asserted here; this test
        # does not check how large (or how bounded) the perturbation is.
        assert not np.allclose(base_dist.logprobs, null_dist.logprobs)

    def test_different_seeds_yield_different_views(self) -> None:
        """Distinct seeds must produce distinct null-adapter distributions."""
        backend = _diverging_backend()
        with backend.as_null_adapter(seed=1) as v1:
            d1 = v1.next_token_dist("hello")
        with backend.as_null_adapter(seed=2) as v2:
            d2 = v2.next_token_dist("hello")
        assert not np.allclose(d1.logprobs, d2.logprobs)

    def test_view_exclusion_enforced(self) -> None:
        """Opening the base view while a null-adapter view is active raises RuntimeError."""
        import pytest

        backend = _diverging_backend()
        with backend.as_null_adapter(seed=0), pytest.raises(RuntimeError):
            with backend.as_base():
                pass
56
57
class TestProbe:
    """``null_adapter`` probe: stats population, kind selection and SKIP path."""

    def test_populates_null_stats(self) -> None:
        """An explicit ``calibrate_kinds`` list drives calibration regardless of suite order."""
        backend = _diverging_backend()
        config = {
            "name": "null",
            "kind": "null_adapter",
            "runs": 3,
            "calibrate_kinds": ["delta_kl"],
        }
        probe, spec = build_probe(config)
        result = probe.run(spec, RunContext(backend=backend))

        assert result.verdict == Verdict.PASS
        null_stats = result.evidence["null_stats"]
        assert "delta_kl" in null_stats
        kl = null_stats["delta_kl"]
        assert kl["n"] == 3.0
        # Seeded perturbations are expected to yield a non-zero spread.
        assert kl["std"] > 0.0

    def test_auto_populates_from_downstream_kinds(self) -> None:
        """With no ``calibrate_kinds``, the probe falls back to ``ctx.downstream_kinds``."""
        backend = _diverging_backend()
        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
        downstream = ("delta_kl", "prompt_collapse")
        ctx = RunContext(backend=backend, downstream_kinds=downstream)
        result = probe.run(spec, ctx)

        assert result.verdict == Verdict.PASS
        null_stats = result.evidence["null_stats"]
        # Each downstream kind that opts in must receive stats.
        for kind in downstream:
            assert kind in null_stats

    def test_empty_calibrate_kinds_with_no_downstream_is_noop(self) -> None:
        """Nothing to calibrate: the probe still PASSes, reporting empty stats."""
        backend = _diverging_backend()
        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
        # The context deliberately carries no downstream_kinds.
        result = probe.run(spec, RunContext(backend=backend))

        assert result.verdict == Verdict.PASS
        evidence = result.evidence
        assert evidence["null_stats"] == {}
        assert evidence["calibrated_kinds"] == []

    def test_unregistered_kind_is_silently_skipped(self) -> None:
        """Unknown kinds in ``calibrate_kinds`` are ignored; known ones still calibrate."""
        backend = _diverging_backend()
        config = {
            "name": "null",
            "kind": "null_adapter",
            "runs": 2,
            "calibrate_kinds": ["delta_kl", "nonexistent_kind"],
        }
        probe, spec = build_probe(config)
        null_stats = probe.run(spec, RunContext(backend=backend)).evidence["null_stats"]

        assert "delta_kl" in null_stats
        assert "nonexistent_kind" not in null_stats

    def test_opt_out_probe_is_reported_as_skipped(self) -> None:
        """A kind whose ``calibrate_spec`` returns None is listed in ``skipped_kinds``."""
        backend = _diverging_backend()
        config = {
            "name": "null",
            "kind": "null_adapter",
            "runs": 2,
            # adapter_revert inherits the base calibrate_spec, which returns
            # None, so it is expected to opt out of calibration.
            "calibrate_kinds": ["adapter_revert", "delta_kl"],
        }
        probe, spec = build_probe(config)
        evidence = probe.run(spec, RunContext(backend=backend)).evidence

        assert "delta_kl" in evidence["null_stats"]
        skipped = [entry["kind"] for entry in evidence["skipped_kinds"]]
        assert "adapter_revert" in skipped

    def test_runner_threads_null_stats_to_subsequent_probes(self) -> None:
        """End-to-end: a leading null_adapter lets a later delta_kl take the z-score path."""
        backend = _diverging_backend()
        null_entry = {
            "name": "null",
            "kind": "null_adapter",
            "runs": 3,
        }
        dk_entry = {
            "name": "dk",
            "kind": "delta_kl",
            "prompts": ["p1", "p2"],
            "assert_z_gte": -10.0,  # permissive: the verdict itself is not under test
        }
        raw_spec = SwaySpec.model_validate(
            {
                "version": 1,
                "models": {
                    "base": {"base": "b"},
                    "ft": {"base": "b", "adapter": "/tmp/a"},
                },
                "suite": [null_entry, dk_entry],
            }
        )

        suite_result = run_suite(raw_spec, backend)
        assert len(suite_result.probes) == 2
        null_result, dk_result = suite_result.probes

        assert null_result.verdict == Verdict.PASS
        # null_stats was available, so delta_kl should report a z-score.
        assert dk_result.z_score is not None, (
            "delta_kl should have z-scored against null baseline, got "
            f"evidence={dk_result.evidence}, message={dk_result.message}"
        )

    def test_skip_when_backend_not_null_calibrated(self) -> None:
        """A backend that is not a ``NullCalibratedBackend`` makes the probe SKIP."""

        class _Bare:
            # Minimal backend: plain views only, no null-adapter view.
            def as_base(self):  # noqa: ANN202
                raise NotImplementedError

            def as_finetuned(self):  # noqa: ANN202
                raise NotImplementedError

        probe, spec = build_probe({"name": "null", "kind": "null_adapter"})
        result = probe.run(spec, RunContext(backend=_Bare()))  # type: ignore[arg-type]

        assert result.verdict == Verdict.SKIP
        assert "NullCalibratedBackend" in result.message
182 assert "NullCalibratedBackend" in result.message