1 """B3 PreferenceFlip — did DPO/ORPO actually flip the chosen/rejected ranking?
2
3 For each ``(prompt, chosen, rejected)`` triple, compute the margin
4
5 .. math::
6 m = \\log p(\\text{chosen} \\mid \\text{prompt}) - \\log p(\\text{rejected} \\mid \\text{prompt})
7
8 under both base and fine-tuned views. Interesting triples are the ones
9 where base got the sign *wrong* (``m_base < 0``); we fail if the
10 fine-tune doesn't flip a large enough fraction of them.
11
12 Triples come from either an inline ``triples:`` block in the spec or
13 from PREFERENCE sections in :attr:`RunContext.sections`. The probe
14 returns :attr:`Verdict.SKIP` when no triples are present — this is the
15 "no PREFERENCE sections in your document" case, graceful by design.
16 """
17
18 from __future__ import annotations
19
20 import statistics
21 from typing import Literal
22
23 from pydantic import BaseModel, ConfigDict, Field
24
25 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
26 from dlm_sway.core.stats import bootstrap_ci
27 from dlm_sway.probes._zscore import (
28 no_calibration_note,
29 score_from_z,
30 verdict_from_z,
31 z_score,
32 z_scores_by_rank,
33 )
34 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
35 from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank
36
37
class PreferenceTriple(BaseModel):
    """One ``(prompt, chosen, rejected)`` preference comparison."""

    # extra="forbid" rejects misspelled keys at validation time;
    # frozen=True makes instances immutable value objects.
    model_config = ConfigDict(extra="forbid", frozen=True)

    prompt: str  # conditioning text shared by both completions
    chosen: str  # completion the fine-tune is expected to prefer
    rejected: str  # completion the fine-tune is expected to disprefer
44
45
class PreferenceFlipSpec(ProbeSpec):
    """Spec for the ``preference_flip`` probe.

    Two pass criteria exist: the z-score criterion (``assert_z_gte``)
    is used when a null-adapter calibration is available, otherwise the
    raw flip-rate threshold (``assert_flip_rate_gte``) decides.
    """

    kind: Literal["preference_flip"] = "preference_flip"
    triples: list[PreferenceTriple] = Field(default_factory=list)
    """Inline triples. If empty, the probe pulls from PREFERENCE
    sections in ctx.sections; if neither is available the probe SKIPs."""
    assert_flip_rate_gte: float = 0.7
    """Fraction of *base-wrong* triples that must flip under ft."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline, when it
    exists. Preferred over the raw threshold."""
    # Minimum number of base-wrong triples needed before the probe makes
    # a PASS/FAIL call; below this it WARNs with a mean-margin-delta
    # fallback instead.
    min_triples_for_decision: int = 3
57
58
class PreferenceFlipProbe(Probe):
    """B3: measure whether the fine-tune flips base-wrong preference rankings.

    For each triple the chosen-vs-rejected log-prob margin is computed
    under both the base and fine-tuned views. Triples where the base
    margin is negative ("base-wrong") form the denominator; the probe
    scores the fraction of them whose margin turns positive under the
    fine-tune.
    """

    kind = "preference_flip"
    spec_cls = PreferenceFlipSpec
    category = "attribution"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> PreferenceFlipSpec | None:
        """Build the sentinel spec used to collect null-adapter statistics.

        On a random-init adapter the flip rate should be ~0.5 (chance),
        with a small std across seeds — a useful null distribution for
        user suites that configure a stricter threshold.
        """
        del ctx  # calibration triples are static; the context is unused
        return PreferenceFlipSpec(
            name="_calibration",
            kind="preference_flip",
            triples=[
                PreferenceTriple(prompt="The best pet is", chosen="a loyal dog", rejected="a rock"),
                PreferenceTriple(prompt="A good answer is", chosen="thoughtful", rejected="loud"),
                PreferenceTriple(prompt="The next step is", chosen="careful", rejected="reckless"),
            ],
            min_triples_for_decision=2,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the probe over inline or section-derived triples.

        Returns a ProbeResult that is SKIP when no triples are available,
        ERROR when every triple raised a ProbeError, WARN when there are
        too few base-wrong triples to decide (mean margin delta reported
        instead), and PASS/FAIL otherwise — by z-score against the null
        adapter when calibrated, by raw flip-rate threshold when not.
        """
        assert isinstance(spec, PreferenceFlipSpec)
        triples = list(spec.triples) or _triples_from_sections(ctx)
        if not triples:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no preference triples (inline or from sections)",
            )

        # NOTE(review): local import — presumably avoids an import cycle
        # with dlm_sway.core; confirm before hoisting to module level.
        from dlm_sway.core.errors import ProbeError

        base_margins: list[float] = []
        ft_margins: list[float] = []
        dropped_triples = 0
        dropped_reasons: list[str] = []
        for t in triples:
            # B14: a single bad triple (zero-token chosen / rejected,
            # tokenizer hiccup, OOM on one prompt) used to take the whole
            # batch down. Fence per triple so probes degrade gracefully:
            # drop the offending triple, count it, surface in evidence.
            try:
                with ctx.require_backend.as_base() as b:
                    base_margin = b.logprob_of(t.prompt, t.chosen) - b.logprob_of(
                        t.prompt, t.rejected
                    )
                with ctx.require_backend.as_finetuned() as f:
                    ft_margin = f.logprob_of(t.prompt, t.chosen) - f.logprob_of(
                        t.prompt, t.rejected
                    )
            except ProbeError as exc:
                dropped_triples += 1
                if len(dropped_reasons) < 5:  # cap evidence verbosity
                    dropped_reasons.append(f"{t.prompt[:40]!r}: {exc}")
                continue
            base_margins.append(base_margin)
            ft_margins.append(ft_margin)

        if not base_margins:
            # Every triple failed: nothing to score, surface as ERROR
            # with the drop accounting so the user can debug.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                evidence={
                    "dropped_triples": dropped_triples,
                    "dropped_reasons": dropped_reasons,
                    "weight": spec.weight,
                },
                message=(
                    f"every triple raised ProbeError ({dropped_triples} total); no usable margins"
                ),
            )

        # Interesting denominator: base got it wrong.
        base_wrong_idx = [i for i, m in enumerate(base_margins) if m < 0]
        flipped_idx = [i for i in base_wrong_idx if ft_margins[i] > 0]

        # max(..., 1) guards specs that set min_triples_for_decision <= 0:
        # without it, zero base-wrong triples would fall through to a 0/0
        # ZeroDivisionError in the flip_rate computation below.
        if len(base_wrong_idx) < max(spec.min_triples_for_decision, 1):
            # Not enough base-wrong triples to decide. Fall back to mean margin delta.
            per_triple_deltas = [
                ft - base for base, ft in zip(base_margins, ft_margins, strict=True)
            ]
            mean_delta = statistics.fmean(per_triple_deltas)
            # F13 — every other numeric probe's WARN path carries a CI
            # and a per-rank z when null_adapter is in the suite. Match
            # that shape so downstream consumers don't see inconsistent
            # fields by verdict.
            ci_95 = bootstrap_ci(per_triple_deltas, seed=ctx.seed)
            warn_stats = get_null_stats(ctx, spec.kind)
            z = z_score(mean_delta, warn_stats)
            z_by_rank = z_scores_by_rank(
                mean_delta, get_null_stats_by_rank(ctx, spec.kind), sign=+1
            )
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                # Linear squash of the delta into [0, 1] centered at 0.5.
                score=max(0.0, min(1.0, 0.5 + mean_delta / 4.0)),
                raw=mean_delta,
                z_score=z,
                base_value=statistics.fmean(base_margins),
                ft_value=statistics.fmean(ft_margins),
                evidence={
                    "base_wrong": len(base_wrong_idx),
                    "total": len(triples),
                    "mean_margin_delta": mean_delta,
                    "dropped_triples": dropped_triples,
                    "dropped_reasons": dropped_reasons,
                    "weight": spec.weight,
                    "z_by_rank": z_by_rank,
                    "raw_ci_95": list(ci_95) if ci_95 is not None else None,
                },
                message=(
                    f"only {len(base_wrong_idx)} base-wrong triples < "
                    f"{spec.min_triples_for_decision} required; reporting "
                    f"mean-margin-delta={mean_delta:+.3f}"
                ),
                ci_95=ci_95,
            )

        flip_rate = len(flipped_idx) / len(base_wrong_idx)

        stats = get_null_stats(ctx, spec.kind)
        z = z_score(flip_rate, stats)
        z_by_rank = z_scores_by_rank(flip_rate, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            # Calibrated path: decide against the null-adapter distribution.
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"flip_rate={flip_rate:.2%} ({len(flipped_idx)}/{len(base_wrong_idx)}), "
                f"z={z:+.2f}σ vs null"
            )
        else:
            # Uncalibrated path: raw threshold; 1e-6 floor avoids a
            # division by zero when assert_flip_rate_gte is 0.
            verdict = Verdict.PASS if flip_rate >= spec.assert_flip_rate_gte else Verdict.FAIL
            score = min(1.0, flip_rate / max(spec.assert_flip_rate_gte, 1e-6))
            message = (
                f"flip_rate={flip_rate:.2%} ({len(flipped_idx)}/{len(base_wrong_idx)} "
                f"base-wrong triples flipped by ft) {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=flip_rate,
            z_score=z,
            base_value=statistics.fmean(base_margins),
            ft_value=statistics.fmean(ft_margins),
            evidence={
                "flip_rate": flip_rate,
                "flipped": len(flipped_idx),
                "base_wrong": len(base_wrong_idx),
                "total": len(triples),
                "dropped_triples": dropped_triples,
                "dropped_reasons": dropped_reasons,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
            },
            message=message,
        )
228
229
def _triples_from_sections(ctx: RunContext) -> list[PreferenceTriple]:
    """Collect triples from the PREFERENCE sections of *ctx*, if any.

    Returns an empty list when the context carries no sections at all.
    """
    if ctx.sections is None:
        return []
    return [
        PreferenceTriple(prompt=pref.prompt, chosen=pref.chosen, rejected=pref.rejected)
        for section in ctx.sections
        if section.kind == "preference"
        for pref in section.preferences
    ]