| 1 | """B3 PreferenceFlip — did DPO/ORPO actually flip the chosen/rejected ranking? |
| 2 | |
| 3 | For each ``(prompt, chosen, rejected)`` triple, compute the margin |
| 4 | |
| 5 | .. math:: |
| 6 | m = \\log p(\\text{chosen} \\mid \\text{prompt}) - \\log p(\\text{rejected} \\mid \\text{prompt}) |
| 7 | |
| 8 | under both base and fine-tuned views. Interesting triples are the ones |
| 9 | where base got the sign *wrong* (``m_base < 0``); we fail if the |
| 10 | fine-tune doesn't flip a large enough fraction of them. |
| 11 | |
| 12 | Triples come from either an inline ``triples:`` block in the spec or |
| 13 | from PREFERENCE sections in :attr:`RunContext.sections`. The probe |
| 14 | returns :attr:`Verdict.SKIP` when no triples are present — this is the |
| 15 | "no PREFERENCE sections in your document" case, graceful by design. |
| 16 | """ |
| 17 | |
| 18 | from __future__ import annotations |
| 19 | |
| 20 | import statistics |
| 21 | from typing import Literal |
| 22 | |
| 23 | from pydantic import BaseModel, ConfigDict, Field |
| 24 | |
| 25 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 26 | from dlm_sway.core.stats import bootstrap_ci |
| 27 | from dlm_sway.probes._zscore import ( |
| 28 | no_calibration_note, |
| 29 | score_from_z, |
| 30 | verdict_from_z, |
| 31 | z_score, |
| 32 | z_scores_by_rank, |
| 33 | ) |
| 34 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 35 | from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank |
| 36 | |
| 37 | |
class PreferenceTriple(BaseModel):
    """One (prompt, chosen, rejected) preference comparison.

    ``frozen`` makes instances immutable and ``extra="forbid"`` makes
    spec typos fail loudly instead of being silently ignored.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Context fed to the model; both completions are scored against it.
    prompt: str
    # Completion the fine-tune is expected to prefer.
    chosen: str
    # Completion the fine-tune is expected to rank below ``chosen``.
    rejected: str
| 44 | |
| 45 | |
class PreferenceFlipSpec(ProbeSpec):
    """Configuration for the ``preference_flip`` probe (B3).

    Triples come either from the inline ``triples`` field or, when that is
    empty, from PREFERENCE sections on the run context; with neither the
    probe SKIPs.
    """

    kind: Literal["preference_flip"] = "preference_flip"
    triples: list[PreferenceTriple] = Field(default_factory=list)
    """Inline triples. If empty, the probe pulls from PREFERENCE
    sections in ctx.sections; if neither is available the probe SKIPs."""
    assert_flip_rate_gte: float = 0.7
    """Fraction of *base-wrong* triples that must flip under ft."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline, when it
    exists. Preferred over the raw threshold."""
    # Minimum number of base-wrong triples required for a PASS/FAIL
    # flip-rate decision; below this the probe WARNs and reports the
    # mean margin delta instead. NOTE(review): no lower bound is
    # enforced, so 0 is accepted — the probe guards against that.
    min_triples_for_decision: int = 3
| 57 | |
| 58 | |
class PreferenceFlipProbe(Probe):
    """B3: did DPO/ORPO fine-tuning actually flip chosen/rejected rankings?

    For each triple the probe computes the log-prob margin
    ``m = log p(chosen | prompt) - log p(rejected | prompt)`` under both
    the base and fine-tuned views. Triples where base got the sign wrong
    (``m_base < 0``) form the denominator; a triple counts as flipped
    when the fine-tuned margin is strictly positive.
    """

    kind = "preference_flip"
    spec_cls = PreferenceFlipSpec
    category = "attribution"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> PreferenceFlipSpec | None:
        """Return the sentinel spec used to build a null distribution.

        On a random-init adapter the flip rate should be ~0.5 (chance),
        with a small std across seeds — a useful null distribution for
        user suites that configure a stricter threshold.
        """
        del ctx  # sentinel triples are fixed; the run context is unused
        return PreferenceFlipSpec(
            name="_calibration",
            kind="preference_flip",
            triples=[
                PreferenceTriple(prompt="The best pet is", chosen="a loyal dog", rejected="a rock"),
                PreferenceTriple(prompt="A good answer is", chosen="thoughtful", rejected="loud"),
                PreferenceTriple(prompt="The next step is", chosen="careful", rejected="reckless"),
            ],
            min_triples_for_decision=2,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Evaluate the flip rate of base-wrong triples and build a result.

        Outcomes:
          * SKIP  — no triples, inline or from sections.
          * ERROR — every triple raised :class:`ProbeError`; no margins.
          * WARN  — fewer base-wrong triples than
                    ``min_triples_for_decision``; falls back to the mean
                    margin delta (with CI and z-stats, per F13).
          * PASS/FAIL — via z-score against the null adapter when
                    calibration exists, else via ``assert_flip_rate_gte``.
        """
        assert isinstance(spec, PreferenceFlipSpec)
        triples = list(spec.triples) or _triples_from_sections(ctx)
        if not triples:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message="no preference triples (inline or from sections)",
            )

        # Function-scoped import — NOTE(review): presumably kept local to
        # avoid an import cycle with dlm_sway.core.errors; confirm before
        # hoisting to module level.
        from dlm_sway.core.errors import ProbeError

        base_margins: list[float] = []
        ft_margins: list[float] = []
        dropped_triples = 0
        dropped_reasons: list[str] = []
        for t in triples:
            # B14: a single bad triple (zero-token chosen / rejected,
            # tokenizer hiccup, OOM on one prompt) used to take the whole
            # batch down. Fence per triple so probes degrade gracefully:
            # drop the offending triple, count it, surface in evidence.
            try:
                with ctx.require_backend.as_base() as b:
                    base_margin = b.logprob_of(t.prompt, t.chosen) - b.logprob_of(
                        t.prompt, t.rejected
                    )
                with ctx.require_backend.as_finetuned() as f:
                    ft_margin = f.logprob_of(t.prompt, t.chosen) - f.logprob_of(
                        t.prompt, t.rejected
                    )
            except ProbeError as exc:
                dropped_triples += 1
                if len(dropped_reasons) < 5:  # cap evidence verbosity
                    dropped_reasons.append(f"{t.prompt[:40]!r}: {exc}")
                continue
            # Only append after BOTH views succeeded, keeping the two
            # margin lists index-aligned.
            base_margins.append(base_margin)
            ft_margins.append(ft_margin)

        if not base_margins:
            # Every single triple failed: nothing to score.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                evidence={
                    "dropped_triples": dropped_triples,
                    "dropped_reasons": dropped_reasons,
                    "weight": spec.weight,
                },
                message=(
                    f"every triple raised ProbeError ({dropped_triples} total); no usable margins"
                ),
            )

        # Interesting denominator: base got it wrong.
        base_wrong_idx = [i for i, m in enumerate(base_margins) if m < 0]
        flipped_idx = [i for i in base_wrong_idx if ft_margins[i] > 0]

        if len(base_wrong_idx) < spec.min_triples_for_decision:
            # Not enough base-wrong triples to decide. Fall back to mean margin delta.
            per_triple_deltas = [
                ft - base for base, ft in zip(base_margins, ft_margins, strict=True)
            ]
            mean_delta = statistics.fmean(per_triple_deltas)
            # F13 — every other numeric probe's WARN path carries a CI
            # and a per-rank z when null_adapter is in the suite. Match
            # that shape so downstream consumers don't see inconsistent
            # fields by verdict.
            ci_95 = bootstrap_ci(per_triple_deltas, seed=ctx.seed)
            warn_stats = get_null_stats(ctx, spec.kind)
            z = z_score(mean_delta, warn_stats)
            z_by_rank = z_scores_by_rank(
                mean_delta, get_null_stats_by_rank(ctx, spec.kind), sign=+1
            )
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                # Map the unbounded delta into [0, 1] around a 0.5 neutral.
                score=max(0.0, min(1.0, 0.5 + mean_delta / 4.0)),
                raw=mean_delta,
                z_score=z,
                base_value=statistics.fmean(base_margins),
                ft_value=statistics.fmean(ft_margins),
                evidence={
                    "base_wrong": len(base_wrong_idx),
                    "total": len(triples),
                    "mean_margin_delta": mean_delta,
                    "dropped_triples": dropped_triples,
                    "dropped_reasons": dropped_reasons,
                    "weight": spec.weight,
                    "z_by_rank": z_by_rank,
                    "raw_ci_95": list(ci_95) if ci_95 is not None else None,
                },
                message=(
                    f"only {len(base_wrong_idx)} base-wrong triples < "
                    f"{spec.min_triples_for_decision} required; reporting "
                    f"mean-margin-delta={mean_delta:+.3f}"
                ),
                ci_95=ci_95,
            )

        # ``min_triples_for_decision`` carries no lower bound, so a spec
        # configured with 0 can reach this point with an empty denominator.
        # Nothing base got wrong means nothing to flip: treat that as a
        # vacuous full success instead of raising ZeroDivisionError.
        if base_wrong_idx:
            flip_rate = len(flipped_idx) / len(base_wrong_idx)
        else:
            flip_rate = 1.0

        stats = get_null_stats(ctx, spec.kind)
        z = z_score(flip_rate, stats)
        z_by_rank = z_scores_by_rank(flip_rate, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            # Calibrated path: verdict and score both derive from the z.
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"flip_rate={flip_rate:.2%} ({len(flipped_idx)}/{len(base_wrong_idx)}), "
                f"z={z:+.2f}σ vs null"
            )
        else:
            # Uncalibrated fallback: raw threshold on the flip rate.
            verdict = Verdict.PASS if flip_rate >= spec.assert_flip_rate_gte else Verdict.FAIL
            score = min(1.0, flip_rate / max(spec.assert_flip_rate_gte, 1e-6))
            message = (
                f"flip_rate={flip_rate:.2%} ({len(flipped_idx)}/{len(base_wrong_idx)} "
                f"base-wrong triples flipped by ft) {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=flip_rate,
            z_score=z,
            base_value=statistics.fmean(base_margins),
            ft_value=statistics.fmean(ft_margins),
            evidence={
                "flip_rate": flip_rate,
                "flipped": len(flipped_idx),
                "base_wrong": len(base_wrong_idx),
                "total": len(triples),
                "dropped_triples": dropped_triples,
                "dropped_reasons": dropped_reasons,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
            },
            message=message,
        )
| 228 | |
| 229 | |
| 230 | def _triples_from_sections(ctx: RunContext) -> list[PreferenceTriple]: |
| 231 | if ctx.sections is None: |
| 232 | return [] |
| 233 | out: list[PreferenceTriple] = [] |
| 234 | for s in ctx.sections: |
| 235 | if s.kind != "preference": |
| 236 | continue |
| 237 | for p in s.preferences: |
| 238 | out.append(PreferenceTriple(prompt=p.prompt, chosen=p.chosen, rejected=p.rejected)) |
| 239 | return out |