Python · 8077 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.probes.delta_kl`."""
2
3 from __future__ import annotations
4
5 import numpy as np
6
7 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
8 from dlm_sway.core.result import Verdict
9 from dlm_sway.core.scoring import TokenDist
10 from dlm_sway.probes.base import RunContext, build_probe
11
12
def _diverging_backend() -> DummyDifferentialBackend:
    """Return a dummy backend whose base and ft views disagree on ``q1``/``q2``.

    The base view concentrates mass on the first listed token of each prompt
    (0.9 on token 1 for ``q1``, 0.8 on token 5 for ``q2``); the ft view spreads
    it out (roughly uniform on ``q1``, tilted toward token 6 on ``q2``). Both
    views share the same token ids per prompt and a ``vocab_size`` of 100.
    """
    ids_by_prompt = {"q1": [1, 2, 3], "q2": [5, 6]}
    base_probs = {"q1": [0.9, 0.05, 0.05], "q2": [0.8, 0.2]}
    ft_probs = {"q1": [0.3, 0.35, 0.35], "q2": [0.4, 0.6]}

    def _responses(probs_by_prompt):
        # One TokenDist per prompt, in insertion order (q1 then q2).
        return DummyResponses(
            token_dists={
                prompt: TokenDist(
                    token_ids=np.array(ids_by_prompt[prompt], dtype=np.int64),
                    logprobs=np.log(np.array(probs, dtype=np.float32)),
                    vocab_size=100,
                )
                for prompt, probs in probs_by_prompt.items()
            }
        )

    base_view = _responses(base_probs)
    ft_view = _responses(ft_probs)
    return DummyDifferentialBackend(base=base_view, ft=ft_view)
44
45
def _identical_backend() -> DummyDifferentialBackend:
    """Return a dummy backend whose base and ft views answer ``q1`` identically.

    A single ``TokenDist`` (probabilities 0.5/0.3/0.2 over tokens 1-3,
    ``vocab_size`` 100) is shared by both views, each of which gets its own
    ``DummyResponses`` wrapper.
    """
    probs = np.array([0.5, 0.3, 0.2], dtype=np.float32)
    shared = TokenDist(
        token_ids=np.array([1, 2, 3], dtype=np.int64),
        logprobs=np.log(probs),
        vocab_size=100,
    )
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists={"q1": shared}),
        ft=DummyResponses(token_dists={"q1": shared}),
    )
55
56
class TestDeltaKL:
    """Verdict, raw-score and evidence checks for the ``delta_kl`` probe."""

    @staticmethod
    def _build(**fields):
        """Build a ``delta_kl`` probe named ``dk`` with the given spec fields.

        Returns whatever ``build_probe`` returns (unpacked by callers as
        ``probe, spec``). Extra fields are appended after ``name``/``kind`` in
        the order given.
        """
        return build_probe({"name": "dk", "kind": "delta_kl", **fields})

    def test_passes_when_distributions_diverge(self) -> None:
        probe, spec = self._build(prompts=["q1", "q2"], assert_mean_gte=0.01)
        outcome = probe.run(spec, RunContext(backend=_diverging_backend()))

        assert outcome.verdict == Verdict.PASS
        assert outcome.raw is not None
        assert outcome.raw > 0.01
        # Evidence should account for both prompts individually.
        assert outcome.evidence["num_prompts"] == 2
        assert len(outcome.evidence["per_prompt"]) == 2

    def test_fails_when_distributions_identical(self) -> None:
        probe, spec = self._build(prompts=["q1"], assert_mean_gte=0.01)
        outcome = probe.run(spec, RunContext(backend=_identical_backend()))

        assert outcome.verdict == Verdict.FAIL
        assert outcome.raw == 0.0

    def test_z_score_path_when_null_stats_present(self) -> None:
        probe, spec = self._build(prompts=["q1"], assert_z_gte=2.0)
        null_stats = {"delta_kl": {"mean": 0.01, "std": 0.01, "n": 3.0}}
        run_ctx = RunContext(backend=_diverging_backend(), null_stats=null_stats)
        outcome = probe.run(spec, run_ctx)

        assert outcome.z_score is not None
        # The diverging fixture is expected to sit far above the null
        # distribution (μ=0.01, σ=0.01), so the z-score should clear 2.0.
        assert outcome.z_score > 2.0
        assert outcome.verdict == Verdict.PASS

    def test_error_on_empty_prompts(self) -> None:
        probe, spec = self._build(prompts=[])
        outcome = probe.run(spec, RunContext(backend=_identical_backend()))

        assert outcome.verdict == Verdict.ERROR

    def test_kl_kind_available(self) -> None:
        probe, spec = self._build(
            prompts=["q1"],
            divergence="kl",
            assert_mean_gte=0.0,
        )
        outcome = probe.run(spec, RunContext(backend=_diverging_backend()))

        assert outcome.evidence["divergence_kind"] == "kl"
125
126
class TestB1NanLogprobsRouteToError:
    """S01 regression: NaN logprobs must NEVER produce a passing z-score.

    Historically this scenario passed at +11639σ. The class pins two things:

    1. Unit level: ``probe.run()`` raises ``ProbeError`` when ``_divergence``
       encounters NaN, so the probe itself surfaces the failure.
    2. Integration level: the suite runner turns that ``ProbeError`` into
       ``Verdict.ERROR`` — the product contract is that a broken model never
       yields a silent PASS.
    """

    @staticmethod
    def _nan_backend() -> DummyDifferentialBackend:
        """Return a backend whose ft ``TokenDist`` for ``q1`` is all-NaN.

        The base view is a well-formed 0.9/0.1 distribution over tokens 1-2.
        """
        token_ids = [1, 2]
        healthy = TokenDist(
            token_ids=np.array(token_ids, dtype=np.int64),
            logprobs=np.log(np.array([0.9, 0.1], dtype=np.float32)),
            vocab_size=100,
        )
        broken = TokenDist(
            token_ids=np.array(token_ids, dtype=np.int64),
            logprobs=np.full(2, np.nan, dtype=np.float32),
            vocab_size=100,
        )
        return DummyDifferentialBackend(
            base=DummyResponses(token_dists={"q1": healthy}),
            ft=DummyResponses(token_dists={"q1": broken}),
        )

    def test_probe_raises_probe_error_on_nan_logprobs(self) -> None:
        import pytest

        from dlm_sway.core.errors import ProbeError

        config = {
            "name": "dk",
            "kind": "delta_kl",
            "prompts": ["q1"],
            "assert_mean_gte": 0.001,
        }
        probe, spec = build_probe(config)
        run_ctx = RunContext(backend=self._nan_backend())
        with pytest.raises(ProbeError, match="non-finite"):
            probe.run(spec, run_ctx)

    def test_runner_converts_nan_probe_error_to_verdict_error(self) -> None:
        """Integration: the suite runner catches the ProbeError and reports
        ERROR rather than a bogus PASS. This is the product-level invariant."""
        from dlm_sway.suite.runner import run as run_suite
        from dlm_sway.suite.spec import SwaySpec

        probe_entry = {
            "name": "dk",
            "kind": "delta_kl",
            "prompts": ["q1"],
            "assert_mean_gte": 0.001,
        }
        model_entries = {
            "base": {"base": "b"},
            "ft": {"base": "b", "adapter": "/tmp/a"},
        }
        spec = SwaySpec.model_validate(
            {"version": 1, "models": model_entries, "suite": [probe_entry]}
        )

        backend = self._nan_backend()
        # Seed a well-formed "preflight" entry on both views so the backend
        # preflight does not short-circuit before the delta_kl probe runs.
        # Each view receives its own TokenDist instance.
        for view in (backend._base_r, backend._ft_r):
            view.token_dists["preflight"] = TokenDist(
                token_ids=np.array([1, 2], dtype=np.int64),
                logprobs=np.log(np.array([0.5, 0.5], dtype=np.float32)),
                vocab_size=100,
            )

        suite_result = run_suite(spec, backend)

        # The delta_kl probe result must carry verdict ERROR and name the cause.
        delta_kl_probe = next(
            entry for entry in suite_result.probes if entry.kind == "delta_kl"
        )
        assert delta_kl_probe.verdict == Verdict.ERROR
        assert "non-finite" in delta_kl_probe.message.lower()
        # Nothing anywhere in the suite may report PASS.
        assert all(entry.verdict != Verdict.PASS for entry in suite_result.probes)