tenseleyflow/sway / 0652fdb

Browse files

tests: delta_kl NaN-routes-to-error at both probe and runner levels (B1 regression)

Authored by espadonne
SHA
0652fdb487fc3921fc6df404271e9734a33a2c30
Parents
8b61fb7
Tree
52fdc23

1 changed file

StatusFile+-
M tests/unit/test_probe_delta_kl.py 99 0
tests/unit/test_probe_delta_kl.pymodified
@@ -122,3 +122,102 @@ class TestDeltaKL:
122122
         ctx = RunContext(backend=_diverging_backend())
123123
         result = probe.run(spec, ctx)
124124
         assert result.evidence["divergence_kind"] == "kl"
125
+
126
+
127
+class TestB1NanLogprobsRouteToError:
128
+    """S01 regression: NaN logprobs must NEVER produce a passing z-score.
129
+
130
+    The historical bug made this pass at +11639σ. Two pins here:
131
+
132
+    1. ``probe.run()`` raises ``ProbeError`` when ``_divergence`` sees NaN
133
+       (unit-level: the probe surfaces the failure).
134
+    2. When routed through the suite runner, the ProbeError turns into
135
+       ``Verdict.ERROR`` (integration-level: the product contract — no
136
+       silent PASS on broken models).
137
+    """
138
+
139
+    @staticmethod
140
+    def _nan_backend() -> DummyDifferentialBackend:
141
+        """Backend whose ft view has NaN-laden TokenDist."""
142
+        import math
143
+
144
+        base = DummyResponses(
145
+            token_dists={
146
+                "q1": TokenDist(
147
+                    token_ids=np.array([1, 2], dtype=np.int64),
148
+                    logprobs=np.log(np.array([0.9, 0.1], dtype=np.float32)),
149
+                    vocab_size=100,
150
+                )
151
+            }
152
+        )
153
+        ft = DummyResponses(
154
+            token_dists={
155
+                "q1": TokenDist(
156
+                    token_ids=np.array([1, 2], dtype=np.int64),
157
+                    logprobs=np.array([math.nan, math.nan], dtype=np.float32),
158
+                    vocab_size=100,
159
+                )
160
+            }
161
+        )
162
+        return DummyDifferentialBackend(base=base, ft=ft)
163
+
164
+    def test_probe_raises_probe_error_on_nan_logprobs(self) -> None:
165
+        import pytest
166
+
167
+        from dlm_sway.core.errors import ProbeError
168
+
169
+        probe, spec = build_probe(
170
+            {
171
+                "name": "dk",
172
+                "kind": "delta_kl",
173
+                "prompts": ["q1"],
174
+                "assert_mean_gte": 0.001,
175
+            }
176
+        )
177
+        ctx = RunContext(backend=self._nan_backend())
178
+        with pytest.raises(ProbeError, match="non-finite"):
179
+            probe.run(spec, ctx)
180
+
181
+    def test_runner_converts_nan_probe_error_to_verdict_error(self) -> None:
182
+        """Integration: the suite runner catches the ProbeError and emits
183
+        ERROR, not a bogus PASS. This is the product-level invariant."""
184
+        from dlm_sway.suite.runner import run as run_suite
185
+        from dlm_sway.suite.spec import SwaySpec
186
+
187
+        spec = SwaySpec.model_validate(
188
+            {
189
+                "version": 1,
190
+                "models": {
191
+                    "base": {"base": "b"},
192
+                    "ft": {"base": "b", "adapter": "/tmp/a"},
193
+                },
194
+                "suite": [
195
+                    {
196
+                        "name": "dk",
197
+                        "kind": "delta_kl",
198
+                        "prompts": ["q1"],
199
+                        "assert_mean_gte": 0.001,
200
+                    }
201
+                ],
202
+            }
203
+        )
204
+        # Pre-seed the preflight prompt so the backend preflight doesn't
205
+        # short-circuit before the real delta_kl probe runs.
206
+        backend = self._nan_backend()
207
+        backend._base_r.token_dists["preflight"] = TokenDist(
208
+            token_ids=np.array([1, 2], dtype=np.int64),
209
+            logprobs=np.log(np.array([0.5, 0.5], dtype=np.float32)),
210
+            vocab_size=100,
211
+        )
212
+        backend._ft_r.token_dists["preflight"] = TokenDist(
213
+            token_ids=np.array([1, 2], dtype=np.int64),
214
+            logprobs=np.log(np.array([0.5, 0.5], dtype=np.float32)),
215
+            vocab_size=100,
216
+        )
217
+        result = run_suite(spec, backend)
218
+        # Exactly one probe (delta_kl), verdict ERROR.
219
+        delta_kl_probe = next(r for r in result.probes if r.kind == "delta_kl")
220
+        assert delta_kl_probe.verdict == Verdict.ERROR
221
+        assert "non-finite" in delta_kl_probe.message.lower()
222
+        # No PASS in the entire suite.
223
+        assert not any(r.verdict == Verdict.PASS for r in result.probes)