1 """Tests for :mod:`dlm_sway.probes.preference_flip`."""
2
3 from __future__ import annotations
4
5 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6 from dlm_sway.core.result import Verdict
7 from dlm_sway.core.sections import Section, SectionPreference
8 from dlm_sway.probes.base import RunContext, build_probe
9
10
def _backend(pairs: list[tuple[str, str, str, float, float]]) -> DummyDifferentialBackend:
    """Build a dummy backend from (prompt, chosen, rejected, base_margin, ft_margin) tuples.

    Each margin is split in half: ``+margin / 2`` is assigned to the chosen
    completion and ``-margin / 2`` to the rejected one, so
    ``logprob_of(chosen) - logprob_of(rejected)`` equals the requested margin
    for each model.
    """
    base_lp: dict[tuple[str, str], float] = {}
    ft_lp: dict[tuple[str, str], float] = {}
    for prompt, chosen, rejected, base_m, ft_m in pairs:
        base_lp[(prompt, chosen)] = base_m / 2
        base_lp[(prompt, rejected)] = -base_m / 2
        ft_lp[(prompt, chosen)] = ft_m / 2
        ft_lp[(prompt, rejected)] = -ft_m / 2
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_lp),
        ft=DummyResponses(logprobs=ft_lp),
    )


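def _margin_arithmetic_sketch() -> None:
    """Minimal sanity sketch of ``_backend``'s arithmetic; illustrative, run by hand.

    Assumes the ``as_base()`` / ``as_finetuned()`` views expose
    ``logprob_of(prompt, completion)`` backed by the ``logprobs`` mapping,
    as the monkeypatching tests further down rely on.
    """
    backend = _backend([("p", "good", "bad", -2.0, 2.0)])
    with backend.as_base() as view:
        # (-2.0 / 2) - (2.0 / 2) == -2.0: the requested base margin.
        assert view.logprob_of("p", "good") - view.logprob_of("p", "bad") == -2.0
    with backend.as_finetuned() as view:
        # (2.0 / 2) - (-2.0 / 2) == 2.0: the requested fine-tuned margin.
        assert view.logprob_of("p", "good") - view.logprob_of("p", "bad") == 2.0

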
def test_pass_when_base_wrong_flipped() -> None:
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, 2.0),  # base wrong, ft flips
            ("p2", "good2", "bad2", -1.5, 1.0),  # base wrong, ft flips
            ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
            ("p4", "good4", "bad4", 1.0, 2.0),  # base already right (no contribution)
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
            ("p4", "good4", "bad4"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS
    assert result.raw == 1.0  # 3/3 flipped


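def _expected_flip_rate(margins: list[tuple[float, float]]) -> float:
    """Reference sketch of the flip-rate rule the assertions above imply.

    Inferred from the tests, not imported from the probe: a triple counts
    toward the denominator only when its base margin is negative ("base
    wrong"), and it counts as flipped when its fine-tuned margin is positive.
    E.g. for the PASS test above, [(-2.0, 2.0), (-1.5, 1.0), (-0.5, 0.8),
    (1.0, 2.0)] yields 3/3 == 1.0.
    """
    base_wrong = [(b, f) for b, f in margins if b < 0]
    # The probe WARNs when too few triples are base-wrong; this sketch simply
    # assumes at least one.
    return sum(1 for _b, f in base_wrong if f > 0) / len(base_wrong)

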
def test_fail_when_base_wrong_not_flipped() -> None:
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, -1.5),  # base wrong, ft still wrong
            ("p2", "good2", "bad2", -1.5, -1.0),  # base wrong, ft still wrong
            ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.FAIL
    assert result.raw is not None
    assert result.raw < 0.7


def test_skip_when_no_triples_anywhere() -> None:
    probe, spec = build_probe({"name": "pf", "kind": "preference_flip"})
    backend = _backend([])
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.SKIP


def test_warn_when_too_few_base_wrong() -> None:
    """With only one base-wrong triple, below ``min_triples_for_decision``,
    there is no meaningful flip rate, so the probe falls back to WARN."""
    backend = _backend(
        [
            ("p1", "good1", "bad1", 1.0, 2.0),  # base right
            ("p2", "good2", "bad2", 0.5, 1.0),  # base right
            ("p3", "good3", "bad3", -0.5, 0.5),  # base wrong
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.WARN


def test_warn_branch_score_formula_pinned() -> None:
    """C10: pin the WARN-branch numerical formula so a refactor notices.

    Formula: ``score = clip(0.5 + mean_delta / 4.0, 0, 1)`` where
    ``mean_delta`` is the average of ``(ft_margin - base_margin)`` over
    every triple. With deltas [+1.0, +0.5, +1.0] the mean is 0.8333…,
    so the score is 0.5 + 0.8333…/4 = 0.70833….
    """
    import math

    backend = _backend(
        [
            ("p1", "good1", "bad1", 1.0, 2.0),  # delta=+1.0 (base right)
            ("p2", "good2", "bad2", 0.5, 1.0),  # delta=+0.5 (base right)
            ("p3", "good3", "bad3", -0.5, 0.5),  # delta=+1.0 (base wrong)
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)

    assert result.verdict == Verdict.WARN
    expected_mean_delta = (1.0 + 0.5 + 1.0) / 3.0
    expected_score = 0.5 + expected_mean_delta / 4.0
    assert result.raw is not None
    assert math.isclose(result.raw, expected_mean_delta, rel_tol=1e-9)
    assert result.score is not None
    assert math.isclose(result.score, expected_score, rel_tol=1e-9)

    # Evidence mirrors the raw metric so report consumers can render it.
    assert math.isclose(result.evidence["mean_margin_delta"], expected_mean_delta, rel_tol=1e-9)
    assert result.evidence["base_wrong"] == 1
    assert result.evidence["total"] == 3


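def _warn_score(mean_delta: float) -> float:
    """Reference sketch of the WARN-branch score pinned above; the name and
    standalone form are illustrative, not the probe's actual helper.

    ``clip(0.5 + mean_delta / 4.0, 0, 1)`` written out: with
    ``mean_delta == 2.5 / 3`` this gives ``0.70833…``.
    """
    return min(1.0, max(0.0, 0.5 + mean_delta / 4.0))

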
def test_one_bad_triple_does_not_kill_the_batch() -> None:
    """B14: a triple that raises ProbeError is dropped, not propagated.

    The remaining triples still produce a verdict; the dropped count
    surfaces in evidence so a user can see what got skipped.
    """
    from dlm_sway.core.errors import ProbeError

    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, 2.0),
            ("p2", "good2", "bad2", -1.5, 1.0),
            ("p3", "good3", "bad3", -0.5, 0.8),
        ]
    )

    # Wrap the backend's logprob_of so any query for the second triple's prompt raises.
    raising = {"p2"}
    original_as_base = backend.as_base
    original_as_finetuned = backend.as_finetuned

    def _raising_view(view_cm):
        """Wrap a view-producing context manager so selected prompts raise."""
        from contextlib import contextmanager

        @contextmanager
        def _wrap():
            with view_cm() as view:
                orig = view.logprob_of

                def fenced(prompt, completion):
                    # Simulate a backend failure for prompts in ``raising``;
                    # everything else is delegated to the real dummy lookup.
                    if prompt in raising:
                        raise ProbeError("logprob_of", f"simulated failure on {prompt!r}")
                    return orig(prompt, completion)

                view.logprob_of = fenced  # type: ignore[method-assign]
                yield view

        return _wrap

    backend.as_base = _raising_view(original_as_base)  # type: ignore[method-assign]
    backend.as_finetuned = _raising_view(original_as_finetuned)  # type: ignore[method-assign]

    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [("p1", "good1", "bad1"), ("p2", "good2", "bad2"), ("p3", "good3", "bad3")]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 2,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)

    assert result.verdict == Verdict.PASS  # the two surviving triples both flipped
    assert result.evidence["dropped_triples"] == 1
    assert any("p2" in reason for reason in result.evidence["dropped_reasons"])


def test_all_triples_failing_yields_error() -> None:
    """When every triple raises, the probe routes to ERROR with an explanation."""
    from contextlib import contextmanager

    from dlm_sway.core.errors import ProbeError

    backend = _backend([("p1", "g", "b", 0.0, 0.0)])
    inner_as_base = backend.as_base  # capture before monkeypatching

    @contextmanager
    def _always_raise():
        with inner_as_base() as view:

            def _raises(*_a, **_k):
                raise ProbeError("logprob_of", "always")

            view.logprob_of = _raises  # type: ignore[method-assign]
            yield view

    backend.as_base = _always_raise  # type: ignore[method-assign]
    backend.as_finetuned = _always_raise  # type: ignore[method-assign]

    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": [{"prompt": "p1", "chosen": "g", "rejected": "b"}],
        }
    )
    result = probe.run(spec, RunContext(backend=backend))
    assert result.verdict == Verdict.ERROR
    assert result.evidence["dropped_triples"] == 1


def test_triples_pulled_from_sections() -> None:
    pref_section = Section(
        id="p1",
        kind="preference",
        content="...",
        preferences=(
            SectionPreference(prompt="q1", chosen="good", rejected="bad"),
            SectionPreference(prompt="q2", chosen="good2", rejected="bad2"),
            SectionPreference(prompt="q3", chosen="good3", rejected="bad3"),
        ),
    )
    backend = _backend(
        [
            ("q1", "good", "bad", -1.0, 1.0),
            ("q2", "good2", "bad2", -1.0, 1.0),
            ("q3", "good3", "bad3", -1.0, 1.0),
        ]
    )
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend, sections=(pref_section,))
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS
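

def _gather_triples_sketch(spec_triples, sections):
    """One plausible triple-sourcing rule consistent with the tests above.

    Inferred, not imported from the probe: explicitly configured triples are
    used when present; otherwise preferences come from the run context's
    sections (see ``test_triples_pulled_from_sections``); with neither, the
    probe has nothing to evaluate and SKIPs.
    """
    if spec_triples:
        return list(spec_triples)
    return [
        {"prompt": p.prompt, "chosen": p.chosen, "rejected": p.rejected}
        for section in sections
        for p in getattr(section, "preferences", ()) or ()
    ]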