"""Tests for :mod:`dlm_sway.probes.preference_flip`."""

from __future__ import annotations

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
from dlm_sway.core.result import Verdict
from dlm_sway.core.sections import Section, SectionPreference
from dlm_sway.probes.base import RunContext, build_probe

def _backend(pairs: list[tuple[str, str, str, float, float]]) -> DummyDifferentialBackend:
    """Build a dummy differential backend from per-triple margin specs.

    ``pairs`` holds (prompt, chosen, rejected, base_margin, ft_margin)
    tuples.  Each margin is split symmetrically — +m/2 assigned to the
    chosen completion and -m/2 to the rejected one — which is enough to
    make logprob_of(chosen) - logprob_of(rejected) equal the requested
    margin for each model.
    """
    base_scores: dict[tuple[str, str], float] = {}
    ft_scores: dict[tuple[str, str], float] = {}
    for prompt, chosen, rejected, base_margin, ft_margin in pairs:
        half_base = base_margin / 2
        half_ft = ft_margin / 2
        base_scores[(prompt, chosen)] = half_base
        base_scores[(prompt, rejected)] = -half_base
        ft_scores[(prompt, chosen)] = half_ft
        ft_scores[(prompt, rejected)] = -half_ft
    return DummyDifferentialBackend(
        base=DummyResponses(logprobs=base_scores),
        ft=DummyResponses(logprobs=ft_scores),
    )


def test_pass_when_base_wrong_flipped() -> None:
    """PASS when the fine-tuned model flips enough base-wrong preferences.

    Three triples are wrong under the base model and all three flip under
    the fine-tuned model (flip rate 3/3 = 1.0 >= 0.7); the fourth triple is
    already right under base, so it contributes nothing to the rate.
    """
    cases = [
        ("p1", "good1", "bad1", -2.0, 2.0),  # base wrong, ft flips
        ("p2", "good2", "bad2", -1.5, 1.0),  # base wrong, ft flips
        ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
        ("p4", "good4", "bad4", 1.0, 2.0),  # base already right (no contribution)
    ]
    backend = _backend(cases)
    # Derive the probe triples from the same case table rather than
    # re-typing every (prompt, chosen, rejected) string with dummy
    # zero padding, as the original did.
    triples = [
        {"prompt": prompt, "chosen": chosen, "rejected": rejected}
        for prompt, chosen, rejected, _base, _ft in cases
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=backend)
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS
    assert result.raw == 1.0  # 3/3 flipped


def test_fail_when_base_wrong_not_flipped() -> None:
    """FAIL when the fine-tuned model flips too few base-wrong preferences.

    Only one of three base-wrong triples flips, so the flip rate (1/3)
    falls below the 0.7 threshold.
    """
    backend = _backend(
        [
            ("p1", "good1", "bad1", -2.0, -1.5),  # base wrong, ft still wrong
            ("p2", "good2", "bad2", -1.5, -1.0),  # base wrong, ft still wrong
            ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
        ]
    )
    triples = []
    for prompt, chosen, rejected in (
        ("p1", "good1", "bad1"),
        ("p2", "good2", "bad2"),
        ("p3", "good3", "bad3"),
    ):
        triples.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=backend))
    assert result.verdict == Verdict.FAIL
    assert result.raw is not None
    assert result.raw < 0.7


def test_skip_when_no_triples_anywhere() -> None:
    """SKIP when neither the spec nor the context supplies any triples."""
    probe, spec = build_probe({"name": "pf", "kind": "preference_flip"})
    ctx = RunContext(backend=_backend([]))
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.SKIP


def test_warn_when_too_few_base_wrong() -> None:
    """WARN when fewer base-wrong triples exist than the decision minimum.

    Only one of three triples is wrong under the base model, which is
    below ``min_triples_for_decision``.
    """
    cases = [
        ("p1", "good1", "bad1", 1.0, 2.0),  # base right
        ("p2", "good2", "bad2", 0.5, 1.0),  # base right
        ("p3", "good3", "bad3", -0.5, 0.5),  # base wrong
    ]
    backend = _backend(cases)
    triples = [
        {"prompt": prompt, "chosen": chosen, "rejected": rejected}
        for prompt, chosen, rejected, *_margins in cases
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": triples,
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=backend))
    assert result.verdict == Verdict.WARN


def test_triples_pulled_from_sections() -> None:
    """Triples are sourced from context sections when the spec has none.

    The probe spec carries no ``triples``; all three preference pairs come
    from the preference section on the run context, and every one flips,
    so the probe passes.
    """
    section = Section(
        id="p1",
        kind="preference",
        content="...",
        preferences=(
            SectionPreference(prompt="q1", chosen="good", rejected="bad"),
            SectionPreference(prompt="q2", chosen="good2", rejected="bad2"),
            SectionPreference(prompt="q3", chosen="good3", rejected="bad3"),
        ),
    )
    margins = [
        ("q1", "good", "bad", -1.0, 1.0),
        ("q2", "good2", "bad2", -1.0, 1.0),
        ("q3", "good3", "bad3", -1.0, 1.0),
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    ctx = RunContext(backend=_backend(margins), sections=(section,))
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS