tenseleyflow/sway / a50a7c2

Browse files

tests/preference_flip: cover one-bad-triple and all-fail paths (B14)

Authored by espadonne
SHA
a50a7c277d2a1e14ddbb65ee61e76bfa137b0c8c
Parents
6c5ef8e
Tree
6466947

1 changed file

Status File + -
M tests/unit/test_probe_preference_flip.py 97 0
tests/unit/test_probe_preference_flip.py modified
@@ -180,6 +180,103 @@ def test_warn_branch_score_formula_pinned() -> None:
180180
     assert result.evidence["total"] == 3
181181
 
182182
 
183
+def test_one_bad_triple_does_not_kill_the_batch() -> None:
184
+    """B14: a triple that raises ProbeError is dropped, not propagated.
185
+
186
+    The remaining triples still produce a verdict; the dropped count
187
+    surfaces in evidence so a user can see what got skipped.
188
+    """
189
+    from dlm_sway.core.errors import ProbeError
190
+
191
+    backend = _backend(
192
+        [
193
+            ("p1", "good1", "bad1", -2.0, 2.0),
194
+            ("p2", "good2", "bad2", -1.5, 1.0),
195
+            ("p3", "good3", "bad3", -0.5, 0.8),
196
+        ]
197
+    )
198
+
199
+    # Wrap the backend's logprob_of so the second triple raises.
200
+    raising = {"p2"}
201
+    original_as_base = backend.as_base
202
+    original_as_finetuned = backend.as_finetuned
203
+
204
+    def _raising_view(view_cm):
205
+        from contextlib import contextmanager
206
+
207
+        @contextmanager
208
+        def _wrap():
209
+            with view_cm() as view:
210
+                orig = view.logprob_of
211
+
212
+                def fenced(prompt, completion):
213
+                    if prompt in raising:
214
+                        raise ProbeError("logprob_of", f"simulated failure on {prompt!r}")
215
+                    return orig(prompt, completion)
216
+
217
+                view.logprob_of = fenced  # type: ignore[method-assign]
218
+                yield view
219
+
220
+        return _wrap
221
+
222
+    backend.as_base = _raising_view(original_as_base)  # type: ignore[method-assign]
223
+    backend.as_finetuned = _raising_view(original_as_finetuned)  # type: ignore[method-assign]
224
+
225
+    triples = [
226
+        {"prompt": p, "chosen": c, "rejected": r}
227
+        for p, c, r in [("p1", "good1", "bad1"), ("p2", "good2", "bad2"), ("p3", "good3", "bad3")]
228
+    ]
229
+    probe, spec = build_probe(
230
+        {
231
+            "name": "pf",
232
+            "kind": "preference_flip",
233
+            "triples": triples,
234
+            "assert_flip_rate_gte": 0.7,
235
+            "min_triples_for_decision": 2,
236
+        }
237
+    )
238
+    ctx = RunContext(backend=backend)
239
+    result = probe.run(spec, ctx)
240
+
241
+    assert result.verdict == Verdict.PASS  # the two surviving triples both flipped
242
+    assert result.evidence["dropped_triples"] == 1
243
+    assert any("p2" in reason for reason in result.evidence["dropped_reasons"])
244
+
245
+
246
+def test_all_triples_failing_yields_error() -> None:
247
+    """When every triple raises, the probe routes to ERROR with an explanation."""
248
+    from contextlib import contextmanager
249
+
250
+    from dlm_sway.core.errors import ProbeError
251
+
252
+    backend = _backend([("p1", "g", "b", 0.0, 0.0)])
253
+    inner_as_base = backend.as_base  # capture before monkeypatching
254
+
255
+    @contextmanager
256
+    def _always_raise():
257
+        with inner_as_base() as view:
258
+
259
+            def _raises(*_a, **_k):
260
+                raise ProbeError("logprob_of", "always")
261
+
262
+            view.logprob_of = _raises  # type: ignore[method-assign]
263
+            yield view
264
+
265
+    backend.as_base = _always_raise  # type: ignore[method-assign]
266
+    backend.as_finetuned = _always_raise  # type: ignore[method-assign]
267
+
268
+    probe, spec = build_probe(
269
+        {
270
+            "name": "pf",
271
+            "kind": "preference_flip",
272
+            "triples": [{"prompt": "p1", "chosen": "g", "rejected": "b"}],
273
+        }
274
+    )
275
+    result = probe.run(spec, RunContext(backend=backend))
276
+    assert result.verdict == Verdict.ERROR
277
+    assert result.evidence["dropped_triples"] == 1
278
+
279
+
183280
 def test_triples_pulled_from_sections() -> None:
184281
     pref_section = Section(
185282
         id="p1",