| 1 | """Tests for :mod:`dlm_sway.probes.preference_flip`.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 6 | from dlm_sway.core.result import Verdict |
| 7 | from dlm_sway.core.sections import Section, SectionPreference |
| 8 | from dlm_sway.probes.base import RunContext, build_probe |
| 9 | |
| 10 | |
| 11 | def _backend(pairs: list[tuple[str, str, str, float, float]]) -> DummyDifferentialBackend: |
| 12 | """``pairs`` = list of (prompt, chosen, rejected, base_margin, ft_margin). |
| 13 | |
| 14 | We distribute the margin half to the chosen and half (negative) to |
| 15 | the rejected, which is enough to make logprob_of(chosen)-logprob_of(rejected) |
| 16 | equal the requested margin. |
| 17 | """ |
    base_lp: dict[tuple[str, str], float] = {}
    ft_lp: dict[tuple[str, str], float] = {}
    for prompt, chosen, rejected, base_m, ft_m in pairs:
        base_lp[(prompt, chosen)] = base_m / 2
        base_lp[(prompt, rejected)] = -base_m / 2
        ft_lp[(prompt, chosen)] = ft_m / 2
        ft_lp[(prompt, rejected)] = -ft_m / 2
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_lp),
        ft=DummyResponses(logprobs=ft_lp),
    )


def test_pass_when_base_wrong_flipped() -> None:
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, 2.0),  # base wrong, ft flips
            ("p2", "good2", "bad2", -1.5, 1.0),  # base wrong, ft flips
            ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
            ("p4", "good4", "bad4", 1.0, 2.0),  # base already right (no contribution)
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
            ("p4", "good4", "bad4"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS
    assert result.raw == 1.0  # 3/3 flipped


def test_fail_when_base_wrong_not_flipped() -> None:
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, -1.5),  # base wrong, ft still wrong
            ("p2", "good2", "bad2", -1.5, -1.0),  # base wrong, ft still wrong
            ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.FAIL
    assert result.raw is not None
    assert result.raw < 0.7


def test_skip_when_no_triples_anywhere() -> None:
    probe, spec = build_probe({"name": "pf", "kind": "preference_flip"})
    backend = _backend([])
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.SKIP


def test_warn_when_too_few_base_wrong() -> None:
    backend = _backend(
        [
            ("p1", "good1", "bad1", 1.0, 2.0),  # base right
            ("p2", "good2", "bad2", 0.5, 1.0),  # base right
            ("p3", "good3", "bad3", -0.5, 0.5),  # base wrong
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.WARN


def test_warn_branch_score_formula_pinned() -> None:
| 134 | """C10: pin the WARN-branch numerical formula so a refactor notices. |
| 135 | |
| 136 | Formula: ``score = clip(0.5 + mean_delta / 4.0, 0, 1)`` where |
| 137 | ``mean_delta`` is the average of ``(ft_margin - base_margin)`` over |
| 138 | every triple. With deltas [+1.0, +0.5, +1.0] the mean is 0.8333… |
| 139 | so the score is 0.5 + 0.8333.../4 = 0.70833…. |
| 140 | """ |
    backend = _backend(
        [
            ("p1", "good1", "bad1", 1.0, 2.0),  # delta=+1.0 (base right)
            ("p2", "good2", "bad2", 0.5, 1.0),  # delta=+0.5 (base right)
            ("p3", "good3", "bad3", -0.5, 0.5),  # delta=+1.0 (base wrong)
        ]
    )
    triples = [
        {"prompt": p, "chosen": c, "rejected": r}
        for p, c, r in [
            ("p1", "good1", "bad1"),
            ("p2", "good2", "bad2"),
            ("p3", "good3", "bad3"),
        ]
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)

    assert result.verdict == Verdict.WARN
    expected_mean_delta = (1.0 + 0.5 + 1.0) / 3.0
    expected_score = 0.5 + expected_mean_delta / 4.0
    assert result.raw is not None
    assert math.isclose(result.raw, expected_mean_delta, rel_tol=1e-9)
    assert result.score is not None
    assert math.isclose(result.score, expected_score, rel_tol=1e-9)

    # Evidence mirrors the raw metric so report consumers can render it.
    assert math.isclose(result.evidence["mean_margin_delta"], expected_mean_delta, rel_tol=1e-9)
    assert result.evidence["base_wrong"] == 1
    assert result.evidence["total"] == 3


def test_one_bad_triple_does_not_kill_the_batch() -> None:
| 184 | """B14: a triple that raises ProbeError is dropped, not propagated. |
| 185 | |
| 186 | The remaining triples still produce a verdict; the dropped count |
| 187 | surfaces in evidence so a user can see what got skipped. |
| 188 | """ |
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, 2.0),
            ("p2", "good2", "bad2", -1.5, 1.0),
            ("p3", "good3", "bad3", -0.5, 0.8),
        ]
    )

    # Wrap the backend's logprob_of so the second triple raises.
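    # ``_raising_view`` wraps a view factory (``as_base``/``as_finetuned``):
    # the returned view answers normally except for prompts in ``raising``,
    # which raise ProbeError to simulate a per-triple backend fault.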
    raising = {"p2"}
    original_as_base = backend.as_base
    original_as_finetuned = backend.as_finetuned

    def _raising_view(view_cm):
        @contextmanager
        def _wrap():
            with view_cm() as view:
                orig = view.logprob_of

                def fenced(prompt, completion):
                    if prompt in raising:
                        raise ProbeError("logprob_of", f"simulated failure on {prompt!r}")
                    return orig(prompt, completion)

                view.logprob_of = fenced  # type: ignore[method-assign]
                yield view

        return _wrap

    backend.as_base = _raising_view(original_as_base)  # type: ignore[method-assign]
    backend.as_finetuned = _raising_view(original_as_finetuned)  # type: ignore[method-assign]

    triples = [
| 226 | {"prompt": p, "chosen": c, "rejected": r} |
| 227 | for p, c, r in [("p1", "good1", "bad1"), ("p2", "good2", "bad2"), ("p3", "good3", "bad3")] |
| 228 | ] |
| 229 | probe, spec = build_probe( |
| 230 | { |
| 231 | "name": "pf", |
| 232 | "kind": "preference_flip", |
| 233 | "triples": triples, |
| 234 | "assert_flip_rate_gte": 0.7, |
| 235 | "min_triples_for_decision": 2, |
| 236 | } |
| 237 | ) |
| 238 | ctx = RunContext(backend=backend) |
| 239 | result = probe.run(spec, ctx) |
| 240 | |
| 241 | assert result.verdict == Verdict.PASS # the two surviving triples both flipped |
| 242 | assert result.evidence["dropped_triples"] == 1 |
| 243 | assert any("p2" in reason for reason in result.evidence["dropped_reasons"]) |
| 244 | |
| 245 | |
| 246 | def test_all_triples_failing_yields_error() -> None: |
| 247 | """When every triple raises, the probe routes to ERROR with an explanation.""" |
    backend = _backend([("p1", "g", "b", 0.0, 0.0)])
    inner_as_base = backend.as_base  # capture before monkeypatching

    @contextmanager
    def _always_raise():
        with inner_as_base() as view:

            def _raises(*_a, **_k):
                raise ProbeError("logprob_of", "always")

            view.logprob_of = _raises  # type: ignore[method-assign]
            yield view

    backend.as_base = _always_raise  # type: ignore[method-assign]
    backend.as_finetuned = _always_raise  # type: ignore[method-assign]
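    # Both views now raise unconditionally, so every triple is dropped and
    # the probe has nothing left to score.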

    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": [{"prompt": "p1", "chosen": "g", "rejected": "b"}],
        }
    )
    result = probe.run(spec, RunContext(backend=backend))
    assert result.verdict == Verdict.ERROR
    assert result.evidence["dropped_triples"] == 1


def test_triples_pulled_from_sections() -> None:
    pref_section = Section(
        id="p1",
        kind="preference",
        content="...",
        preferences=(
            SectionPreference(prompt="q1", chosen="good", rejected="bad"),
            SectionPreference(prompt="q2", chosen="good2", rejected="bad2"),
            SectionPreference(prompt="q3", chosen="good3", rejected="bad3"),
        ),
    )
    backend = _backend(
        [
            ("q1", "good", "bad", -1.0, 1.0),
            ("q2", "good2", "bad2", -1.0, 1.0),
            ("q3", "good3", "bad3", -1.0, 1.0),
        ]
    )
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend, sections=(pref_section,))
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS