Python · 5473 bytes Raw Blame History
"""B2 ParaphraseInvariance — memorization vs generalization, per case.

For each ``(prompt, gold, paraphrases)`` test case:

- ``verbatim_lift``: Δ-per-token = (logprob_ft(prompt, gold) - logprob_base(prompt, gold)) / n_tokens
- ``paraphrase_lift``: the same Δ-per-token, averaged over the paraphrased prompts

``n_tokens`` is a rough estimate of the gold continuation's length (about four
characters per token, never less than 1), not a real tokenizer count.

A model that memorized the exact prompt has high ``verbatim_lift`` but
near-zero ``paraphrase_lift``. A model that learned the underlying
*pattern* has both values positive and close to each other.

We report the per-case lifts, their means across cases, and:

- ``generalization_ratio = mean(paraphrase_lift) / mean(verbatim_lift)``
  (reported as 0.0 when the mean verbatim lift is effectively zero)
- a verbatim sanity check: the mean verbatim lift must reach
  ``min_verbatim_lift``, i.e. the adapter must have noticeably moved the
  verbatim-prompt logprob at all

The pass criterion depends on the stated intent. By default (``"generalize"``,
and likewise ``"both"``) we require both the verbatim check and a
generalization ratio of at least ``min_generalization_ratio``. If the spec's
``intent`` is ``"memorize"``, the ratio requirement inverts — we *want*
verbatim >> paraphrase, so the ratio must be at most
``max_generalization_ratio_if_memorize``.
"""
23
24 from __future__ import annotations
25
26 import statistics
27 from typing import Literal
28
29 from pydantic import BaseModel, ConfigDict, Field
30
31 from dlm_sway.core.result import ProbeResult, Verdict
32 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
33
# What the adapter under test is supposed to have learned. The pass rule
# treats "both" exactly like "generalize"; only "memorize" inverts it.
Intent = Literal[
    "generalize",
    "memorize",
    "both",
]
35
36
class ParaphraseCase(BaseModel):
    """One paraphrase-invariance case."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Verbatim prompt the adapter may have been trained on.
    prompt: str
    # Continuation whose logprob is measured after the prompt and after
    # every paraphrase.
    gold: str
    # Reworded prompts; required and non-empty because the probe averages
    # over them. There is deliberately no default: pydantic does not validate
    # default values, so ``default_factory=list`` would silently bypass
    # ``min_length=1`` and hand the probe an empty list.
    paraphrases: list[str] = Field(min_length=1)
45
46
class ParaphraseInvarianceSpec(ProbeSpec):
    # Discriminator that routes this spec to ParaphraseInvarianceProbe.
    kind: Literal["paraphrase_invariance"] = "paraphrase_invariance"
    # Cases to score; the probe reports ERROR when this list is empty.
    cases: list[ParaphraseCase] = Field(default_factory=list)
    # "generalize"/"both" require a high ratio; "memorize" requires a low one.
    intent: Intent = Field(default="generalize")
    # Minimum mean per-token verbatim lift, required for every intent.
    min_verbatim_lift: float = Field(default=0.2)
    # Lower bound on paraphrase/verbatim lift ratio (generalize/both).
    min_generalization_ratio: float = Field(default=0.5)
    # Upper bound on paraphrase/verbatim lift ratio (memorize).
    max_generalization_ratio_if_memorize: float = Field(default=0.5)
54
55
class ParaphraseInvarianceProbe(Probe):
    """Compare adapter lift on verbatim prompts against their paraphrases.

    For every case the gold continuation is scored by the base and the
    fine-tuned backend, once after the verbatim prompt and once after each
    paraphrase. The per-token differences (lifts) are aggregated and turned
    into a verdict by ``_decide``.
    """

    kind = "paraphrase_invariance"
    spec_cls = ParaphraseInvarianceSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Score all cases and return one aggregated ``ProbeResult``.

        Returns an ERROR result, without calling the backend, when there are
        no cases or when any case has no paraphrases. Otherwise ``raw`` is
        the generalization ratio, ``base_value``/``ft_value`` are the mean
        verbatim/paraphrase lifts, and the verdict comes from ``_decide``.
        """
        assert isinstance(spec, ParaphraseInvarianceSpec)
        if not spec.cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no cases provided",
            )

        # A case without paraphrases has no paraphrase lift (fmean of an empty
        # sequence raises StatisticsError). Reject such cases before doing
        # any model work instead of crashing halfway through the run.
        empty_cases = [i for i, case in enumerate(spec.cases) if not case.paraphrases]
        if empty_cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=f"case(s) without paraphrases at indices {empty_cases}",
            )

        verbatim_lifts: list[float] = []
        paraphrase_lifts: list[float] = []
        per_case: list[dict[str, float | str]] = []

        for case in spec.cases:
            # Base and fine-tuned logprobs are divided by the same rough
            # length estimate, so each lift is a per-token difference.
            tokens = max(_token_estimate(case.gold), 1)
            with ctx.backend.as_base() as b:
                lp_base_verb = b.logprob_of(case.prompt, case.gold) / tokens
                lp_base_par = [b.logprob_of(p, case.gold) / tokens for p in case.paraphrases]
            with ctx.backend.as_finetuned() as f:
                lp_ft_verb = f.logprob_of(case.prompt, case.gold) / tokens
                lp_ft_par = [f.logprob_of(p, case.gold) / tokens for p in case.paraphrases]

            verb_lift = lp_ft_verb - lp_base_verb
            par_lift = statistics.fmean(
                ft - base for base, ft in zip(lp_base_par, lp_ft_par, strict=True)
            )
            verbatim_lifts.append(verb_lift)
            paraphrase_lifts.append(par_lift)
            per_case.append(
                {
                    "prompt": case.prompt[:80],
                    "verbatim_lift": verb_lift,
                    "paraphrase_lift": par_lift,
                }
            )

        mean_verb = statistics.fmean(verbatim_lifts)
        mean_par = statistics.fmean(paraphrase_lifts)
        # Ratio of means; avoid dividing by a near-zero verbatim lift.
        ratio = mean_par / mean_verb if abs(mean_verb) > 1e-9 else 0.0

        verdict, score, msg = _decide(spec, mean_verb, mean_par, ratio)

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=ratio,
            base_value=mean_verb,
            ft_value=mean_par,
            evidence={
                "verbatim_lift_mean": mean_verb,
                "paraphrase_lift_mean": mean_par,
                "generalization_ratio": ratio,
                "intent": spec.intent,
                # Cap the evidence payload at the first 8 cases.
                "per_case": per_case[:8],
                "weight": spec.weight,
            },
            message=msg,
        )
123
124
def _decide(
    spec: ParaphraseInvarianceSpec, verb: float, par: float, ratio: float
) -> tuple[Verdict, float, str]:
    """Apply the intent-aware pass rule and return ``(verdict, score, message)``.

    Args:
        spec: Supplies ``intent`` and the pass thresholds.
        verb: Mean verbatim lift across cases.
        par: Mean paraphrase lift across cases (only echoed in the message).
        ratio: Generalization ratio as computed by the caller.

    Returns:
        ``Verdict.PASS`` or ``Verdict.FAIL``, a score in ``[0, 1]``, and a
        one-line summary.

    With ``intent="memorize"`` the case passes when ``verb`` reaches
    ``min_verbatim_lift`` and ``ratio`` is at most
    ``max_generalization_ratio_if_memorize``. Its score reflects only the
    verbatim lift. Every other intent ("generalize", "both") passes when
    ``verb`` reaches ``min_verbatim_lift`` and ``ratio`` reaches
    ``min_generalization_ratio``. Its score averages the normalised
    generalization and verbatim components.
    """
    base_msg = f"verb={verb:+.3f}, para={par:+.3f}, ratio={ratio:.2f}"
    # Guard the divisor so a zero/negative threshold cannot divide by zero.
    verb_component = min(1.0, max(0.0, verb / max(spec.min_verbatim_lift, 1e-6)))

    if spec.intent == "memorize":
        passed = (
            verb >= spec.min_verbatim_lift
            and ratio <= spec.max_generalization_ratio_if_memorize
        )
        verd = Verdict.PASS if passed else Verdict.FAIL
        return verd, verb_component, f"{base_msg} — intent=memorize"

    # Default: generalize (or "both").
    passed = verb >= spec.min_verbatim_lift and ratio >= spec.min_generalization_ratio
    verd = Verdict.PASS if passed else Verdict.FAIL
    # The ratio only measures generalization when the verbatim lift is
    # positive. If the adapter lowered both logprobs, par / verb is
    # negative / negative, i.e. positive, and must not earn credit.
    if verb > 0:
        gen_component = min(1.0, max(0.0, ratio / max(spec.min_generalization_ratio, 1e-6)))
    else:
        gen_component = 0.0
    score = 0.5 * gen_component + 0.5 * verb_component
    return verd, score, f"{base_msg} — intent={spec.intent}"
145
146
147 def _token_estimate(s: str) -> int:
148 return max(1, len(s) // 4)