tenseleyflow/sway / e05932b

Browse files

sway(probes): B3 preference_flip on chosen/rejected margin inversion

Authored by espadonne
SHA
e05932b5467c0565a6cf919c3c2ca472af4ca942
Parents
887bb35
Tree
d34264d

2 changed files

StatusFile+-
A src/dlm_sway/probes/preference_flip.py 140 0
A tests/unit/test_probe_preference_flip.py 161 0
src/dlm_sway/probes/preference_flip.pyadded
@@ -0,0 +1,140 @@
1
+"""B3 PreferenceFlip — did DPO/ORPO actually flip the chosen/rejected ranking?
2
+
3
+For each ``(prompt, chosen, rejected)`` triple, compute the margin
4
+
5
+.. math::
6
+    m = \\log p(\\text{chosen} \\mid \\text{prompt}) - \\log p(\\text{rejected} \\mid \\text{prompt})
7
+
8
+under both base and fine-tuned views. Interesting triples are the ones
9
+where base got the sign *wrong* (``m_base < 0``); we fail if the
10
+fine-tune doesn't flip a large enough fraction of them.
11
+
12
+Triples come from either an inline ``triples:`` block in the spec or
13
+from PREFERENCE sections in :attr:`RunContext.sections`. The probe
14
+returns :attr:`Verdict.SKIP` when no triples are present — this is the
15
+"no PREFERENCE sections in your document" case, graceful by design.
16
+"""
17
+
18
+from __future__ import annotations
19
+
20
+import statistics
21
+from typing import Literal
22
+
23
+from pydantic import BaseModel, ConfigDict, Field
24
+
25
+from dlm_sway.core.result import ProbeResult, Verdict
26
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
27
+
28
+
29
# One preference example: a prompt plus the preferred ("chosen") and the
# dispreferred ("rejected") continuation. Instances are immutable and
# reject unknown keys at validation time.
class PreferenceTriple(BaseModel):
    model_config = ConfigDict(frozen=True, extra="forbid")

    # Text the model is conditioned on.
    prompt: str
    # Continuation that should receive the higher log-probability.
    chosen: str
    # Continuation that should receive the lower log-probability.
    rejected: str
35
+
36
+
37
# Spec for the ``preference_flip`` probe: where the triples come from and
# the thresholds that turn the measured flip rate into a verdict.
class PreferenceFlipSpec(ProbeSpec):
    # Discriminator value selecting this probe kind.
    kind: Literal["preference_flip"] = "preference_flip"

    triples: list[PreferenceTriple] = Field(default_factory=list)
    """Inline triples. If empty, the probe pulls from PREFERENCE
    sections in ctx.sections; if neither is available the probe SKIPs."""

    assert_flip_rate_gte: float = 0.7
    """Fraction of *base-wrong* triples that must flip under ft."""

    # Minimum number of base-wrong triples needed before the probe issues a
    # PASS/FAIL; below this it reports WARN with the mean margin delta.
    min_triples_for_decision: int = 3
45
+
46
+
47
class PreferenceFlipProbe(Probe):
    """Probe that measures how many base-wrong preference triples the fine-tune flips.

    For every triple the margin ``logprob(chosen) - logprob(rejected)`` is
    computed under the base view and under the fine-tuned view. Triples whose
    base margin is negative form the denominator; a triple counts as flipped
    when its fine-tuned margin is strictly positive.
    """

    kind = "preference_flip"
    spec_cls = PreferenceFlipSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score the triples under both views and turn the flip rate into a verdict.

        Args:
            spec: Must be a :class:`PreferenceFlipSpec`.
            ctx: Run context supplying the backend and, optionally, sections
                to pull triples from when the spec has none inline.

        Returns:
            * ``SKIP`` when no triples are available (inline or from sections).
            * ``WARN`` when fewer base-wrong triples exist than required for a
              decision; ``raw`` is then the mean per-triple margin delta.
            * ``PASS`` / ``FAIL`` otherwise, depending on whether the flip rate
              reaches ``spec.assert_flip_rate_gte``; ``raw`` is the flip rate.
        """
        # Narrow the declared ProbeSpec to this probe's own spec type.
        assert isinstance(spec, PreferenceFlipSpec)

        triples = list(spec.triples) or _triples_from_sections(ctx)
        if not triples:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no preference triples (inline or from sections)",
            )

        # Score every triple under one view before switching to the other, so
        # each view is entered once rather than once per triple.
        with ctx.backend.as_base() as base_view:
            base_margins: list[float] = [
                base_view.logprob_of(t.prompt, t.chosen)
                - base_view.logprob_of(t.prompt, t.rejected)
                for t in triples
            ]
        with ctx.backend.as_finetuned() as ft_view:
            ft_margins: list[float] = [
                ft_view.logprob_of(t.prompt, t.chosen)
                - ft_view.logprob_of(t.prompt, t.rejected)
                for t in triples
            ]

        base_mean = statistics.fmean(base_margins)
        ft_mean = statistics.fmean(ft_margins)

        # Interesting denominator: base got it wrong.
        base_wrong_idx = [i for i, m in enumerate(base_margins) if m < 0]

        # A flip rate needs at least one base-wrong triple. Flooring the
        # requirement at 1 keeps a zero/negative ``min_triples_for_decision``
        # from reaching the division below with an empty denominator.
        required = max(1, spec.min_triples_for_decision)
        if len(base_wrong_idx) < required:
            # Not enough base-wrong triples to decide. Fall back to mean margin delta.
            mean_delta = statistics.fmean(
                (ft - base) for base, ft in zip(base_margins, ft_margins, strict=True)
            )
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                # Map the delta onto [0, 1]: 0.5 means no change, clamped at +/-2.
                score=max(0.0, min(1.0, 0.5 + mean_delta / 4.0)),
                raw=mean_delta,
                base_value=base_mean,
                ft_value=ft_mean,
                evidence={
                    "base_wrong": len(base_wrong_idx),
                    "total": len(triples),
                    "mean_margin_delta": mean_delta,
                    "weight": spec.weight,
                },
                message=(
                    f"only {len(base_wrong_idx)} base-wrong triples < "
                    f"{required} required; reporting mean-margin-delta={mean_delta:+.3f}"
                ),
            )

        flipped_idx = [i for i in base_wrong_idx if ft_margins[i] > 0]
        flip_rate = len(flipped_idx) / len(base_wrong_idx)
        verdict = Verdict.PASS if flip_rate >= spec.assert_flip_rate_gte else Verdict.FAIL
        # Score is the flip rate relative to the threshold, capped at 1; the
        # epsilon guards against a zero threshold.
        score = min(1.0, flip_rate / max(spec.assert_flip_rate_gte, 1e-6))
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=flip_rate,
            base_value=base_mean,
            ft_value=ft_mean,
            evidence={
                "flip_rate": flip_rate,
                "flipped": len(flipped_idx),
                "base_wrong": len(base_wrong_idx),
                "total": len(triples),
                "weight": spec.weight,
            },
            message=(
                f"flip_rate={flip_rate:.2%} ({len(flipped_idx)}/{len(base_wrong_idx)} "
                f"base-wrong triples flipped by ft)"
            ),
        )
129
+
130
+
131
def _triples_from_sections(ctx: RunContext) -> list[PreferenceTriple]:
    """Collect one triple per preference found in ``"preference"``-kind sections.

    Returns an empty list when ``ctx.sections`` is ``None`` or when no
    section of that kind contributes a preference.
    """
    sections = ctx.sections
    if sections is None:
        return []
    return [
        PreferenceTriple(prompt=pref.prompt, chosen=pref.chosen, rejected=pref.rejected)
        for section in sections
        if section.kind == "preference"
        for pref in section.preferences
    ]
tests/unit/test_probe_preference_flip.pyadded
@@ -0,0 +1,161 @@
1
+"""Tests for :mod:`dlm_sway.probes.preference_flip`."""
2
+
3
+from __future__ import annotations
4
+
5
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6
+from dlm_sway.core.result import Verdict
7
+from dlm_sway.core.sections import Section, SectionPreference
8
+from dlm_sway.probes.base import RunContext, build_probe
9
+
10
+
11
def _backend(pairs: list[tuple[str, str, str, float, float]]) -> DummyDifferentialBackend:
    """Build a dummy backend whose log-probs realise the requested margins.

    Each entry of ``pairs`` is ``(prompt, chosen, rejected, base_margin,
    ft_margin)``. For each view the chosen completion gets ``+margin / 2``
    and the rejected one ``-margin / 2``, so
    ``logprob_of(chosen) - logprob_of(rejected)`` equals the margin.
    """
    base_table: dict[tuple[str, str], float] = {}
    ft_table: dict[tuple[str, str], float] = {}
    for prompt, chosen, rejected, base_margin, ft_margin in pairs:
        for table, margin in ((base_table, base_margin), (ft_table, ft_margin)):
            table[(prompt, chosen)] = margin / 2
            table[(prompt, rejected)] = -margin / 2
    base_responses = DummyResponses(logprobs=base_table)
    ft_responses = DummyResponses(logprobs=ft_table)
    return DummyDifferentialBackend(base=base_responses, ft=ft_responses)
29
+
30
+
31
def test_pass_when_base_wrong_flipped() -> None:
    """All three base-wrong triples flip under ft, so the probe passes."""
    # (prompt, chosen, rejected, base_margin, ft_margin)
    cases = [
        ("p1", "good1", "bad1", -2.0, 2.0),  # base wrong, ft flips
        ("p2", "good2", "bad2", -1.5, 1.0),  # base wrong, ft flips
        ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
        ("p4", "good4", "bad4", 1.0, 2.0),  # base already right (no contribution)
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": [
                {"prompt": prompt, "chosen": chosen, "rejected": rejected}
                for prompt, chosen, rejected, _base_m, _ft_m in cases
            ],
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=_backend(cases)))
    assert result.verdict == Verdict.PASS
    assert result.raw == 1.0  # 3/3 flipped
62
+
63
+
64
def test_fail_when_base_wrong_not_flipped() -> None:
    """Only one of three base-wrong triples flips, which is below the 0.7 bar."""
    # (prompt, chosen, rejected, base_margin, ft_margin)
    cases = [
        ("p1", "good1", "bad1", -2.0, -1.5),  # base wrong, ft still wrong
        ("p2", "good2", "bad2", -1.5, -1.0),  # base wrong, ft still wrong
        ("p3", "good3", "bad3", -0.5, 0.8),  # base wrong, ft flips
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": [
                {"prompt": prompt, "chosen": chosen, "rejected": rejected}
                for prompt, chosen, rejected, _base_m, _ft_m in cases
            ],
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=_backend(cases)))
    assert result.verdict == Verdict.FAIL
    assert result.raw is not None
    assert result.raw < 0.7
94
+
95
+
96
def test_skip_when_no_triples_anywhere() -> None:
    """With no inline triples and no sections, the probe reports SKIP."""
    empty_backend = _backend([])
    probe, spec = build_probe({"name": "pf", "kind": "preference_flip"})
    result = probe.run(spec, RunContext(backend=empty_backend))
    assert result.verdict == Verdict.SKIP
102
+
103
+
104
def test_warn_when_too_few_base_wrong() -> None:
    """A single base-wrong triple is below the minimum of three, so WARN."""
    # (prompt, chosen, rejected, base_margin, ft_margin)
    cases = [
        ("p1", "good1", "bad1", 1.0, 2.0),  # base right
        ("p2", "good2", "bad2", 0.5, 1.0),  # base right
        ("p3", "good3", "bad3", -0.5, 0.5),  # base wrong
    ]
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "triples": [
                {"prompt": prompt, "chosen": chosen, "rejected": rejected}
                for prompt, chosen, rejected, _base_m, _ft_m in cases
            ],
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=_backend(cases)))
    assert result.verdict == Verdict.WARN
131
+
132
+
133
def test_triples_pulled_from_sections() -> None:
    """With no inline triples, the probe reads them from a preference section."""
    # (prompt, chosen, rejected); every triple is base-wrong and flipped by ft.
    rows = [
        ("q1", "good", "bad"),
        ("q2", "good2", "bad2"),
        ("q3", "good3", "bad3"),
    ]
    pref_section = Section(
        id="p1",
        kind="preference",
        content="...",
        preferences=tuple(
            SectionPreference(prompt=prompt, chosen=chosen, rejected=rejected)
            for prompt, chosen, rejected in rows
        ),
    )
    backend = _backend([(prompt, chosen, rejected, -1.0, 1.0) for prompt, chosen, rejected in rows])
    probe, spec = build_probe(
        {
            "name": "pf",
            "kind": "preference_flip",
            "assert_flip_rate_gte": 0.7,
            "min_triples_for_decision": 3,
        }
    )
    result = probe.run(spec, RunContext(backend=backend, sections=(pref_section,)))
    assert result.verdict == Verdict.PASS