tenseleyflow/sway / 887bb35

Browse files

sway(probes): B2 paraphrase_invariance with intent-aware pass rule

Authored by espadonne
SHA
887bb35a0cee6e21f88d877d0657e40c26a5c3d4
Parents
022829a
Tree
99280a8

2 changed files

Status | File | + | -
A src/dlm_sway/probes/paraphrase_invariance.py 148 0
A tests/unit/test_probe_paraphrase_invariance.py 91 0
src/dlm_sway/probes/paraphrase_invariance.pyadded
@@ -0,0 +1,148 @@
1
+"""B2 ParaphraseInvariance — memorization vs generalization, per case.
2
+
3
+For each ``(prompt, gold, paraphrases)`` test case:
4
+
5
+- ``verbatim_lift``:  Δ-per-token = logprob_ft(prompt, gold) - logprob_base(prompt, gold)
6
+- ``paraphrase_lift``: mean Δ-per-token over the paraphrased prompts
7
+
8
+A model that memorized the exact prompt has high ``verbatim_lift`` but
9
+near-zero ``paraphrase_lift``. A model that learned the underlying
10
+*pattern* has both values positive and close to each other.
11
+
12
+We report:
13
+
14
+- ``generalization_ratio = paraphrase_lift / verbatim_lift`` (reported as 0.0 when ``|verbatim_lift|`` is ~0)
15
+- ``verbatim_score``: whether the adapter significantly moved the
16
+  verbatim-prompt logprob (sanity check)
17
+
18
+The pass criterion depends on the stated intent: by default we require
19
+both high verbatim lift and high generalization ratio. If the spec's
20
+``intent`` is ``"memorize"``, the ratio requirement inverts — we *want*
21
+verbatim >> paraphrase.
22
+"""
23
+
24
+from __future__ import annotations
25
+
26
+import statistics
27
+from typing import Literal
28
+
29
+from pydantic import BaseModel, ConfigDict, Field
30
+
31
+from dlm_sway.core.result import ProbeResult, Verdict
32
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
33
+
34
+Intent = Literal["generalize", "memorize", "both"]
35
+
36
+
37
class ParaphraseCase(BaseModel):
    """One paraphrase-invariance case.

    Holds a verbatim prompt, the gold completion that is scored under every
    prompt, and one or more paraphrased prompts.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # The verbatim prompt the adapter was (presumably) trained on.
    prompt: str
    # The completion whose logprob is measured under each prompt.
    gold: str
    # NOTE(review): the empty-list default conflicts with min_length=1 —
    # pydantic only checks constraints on defaults when validate_default is
    # set, so an omitted field yields an invalid (empty) case that later
    # crashes statistics.fmean. Confirm callers always supply paraphrases.
    paraphrases: list[str] = Field(default_factory=list, min_length=1)
45
+
46
+
47
class ParaphraseInvarianceSpec(ProbeSpec):
    """Spec for the B2 paraphrase-invariance probe."""

    kind: Literal["paraphrase_invariance"] = "paraphrase_invariance"
    # Cases to evaluate; run() returns an ERROR verdict when this is empty.
    cases: list[ParaphraseCase] = Field(default_factory=list)
    # "generalize" (default) and "both" require a high generalization ratio;
    # "memorize" inverts the ratio requirement (verbatim >> paraphrase wanted).
    intent: Intent = "generalize"
    # Minimum mean per-token verbatim lift required for a PASS (any intent).
    min_verbatim_lift: float = 0.2
    # Lower bound on generalization_ratio when intent is "generalize"/"both".
    min_generalization_ratio: float = 0.5
    # Upper bound on generalization_ratio when intent is "memorize".
    max_generalization_ratio_if_memorize: float = 0.5
54
+
55
+
56
class ParaphraseInvarianceProbe(Probe):
    """Measure verbatim vs paraphrase logprob lift between base and ft views."""

    kind = "paraphrase_invariance"
    spec_cls = ParaphraseInvarianceSpec
    category = "attribution"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, ParaphraseInvarianceSpec)
        if not spec.cases:
            # Nothing to measure: surface an ERROR rather than a vacuous PASS.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no cases provided",
            )

        verbatim_lifts: list[float] = []
        paraphrase_lifts: list[float] = []
        case_rows: list[dict[str, float | str]] = []

        for tc in spec.cases:
            # Per-token normalization uses the gold length, so the same
            # divisor applies to every prompt in this case.
            n_tok = max(_token_estimate(tc.gold), 1)

            # Score the verbatim prompt and every paraphrase under both views.
            with ctx.backend.as_base() as base_view:
                base_verb = base_view.logprob_of(tc.prompt, tc.gold) / n_tok
                base_pars = [
                    base_view.logprob_of(p, tc.gold) / n_tok for p in tc.paraphrases
                ]
            with ctx.backend.as_finetuned() as ft_view:
                ft_verb = ft_view.logprob_of(tc.prompt, tc.gold) / n_tok
                ft_pars = [
                    ft_view.logprob_of(p, tc.gold) / n_tok for p in tc.paraphrases
                ]

            lift_verb = ft_verb - base_verb
            deltas = [ft - base for base, ft in zip(base_pars, ft_pars, strict=True)]
            lift_par = statistics.fmean(deltas)

            verbatim_lifts.append(lift_verb)
            paraphrase_lifts.append(lift_par)
            case_rows.append(
                {
                    "prompt": tc.prompt[:80],
                    "verbatim_lift": lift_verb,
                    "paraphrase_lift": lift_par,
                }
            )

        mean_verb = statistics.fmean(verbatim_lifts)
        mean_par = statistics.fmean(paraphrase_lifts)
        # Guard a ~zero denominator; the ratio is reported as 0.0 in that case.
        ratio = 0.0 if abs(mean_verb) <= 1e-9 else mean_par / mean_verb

        verdict, score, msg = _decide(spec, mean_verb, mean_par, ratio)

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=ratio,
            base_value=mean_verb,
            ft_value=mean_par,
            evidence={
                "verbatim_lift_mean": mean_verb,
                "paraphrase_lift_mean": mean_par,
                "generalization_ratio": ratio,
                "intent": spec.intent,
                "per_case": case_rows[:8],
                "weight": spec.weight,
            },
            message=msg,
        )
123
+
124
+
125
def _decide(
    spec: ParaphraseInvarianceSpec, verb: float, par: float, ratio: float
) -> tuple[Verdict, float, str]:
    """Turn lift statistics into (verdict, score, message) per the spec's intent."""

    def clamp01(x: float) -> float:
        # Saturate a score component into [0, 1].
        return min(1.0, max(0.0, x))

    base_msg = f"verb={verb:+.3f}, para={par:+.3f}, ratio={ratio:.2f}"

    if spec.intent == "memorize":
        # Memorization is desired: verbatim lift must clear the bar while the
        # paraphrase/verbatim ratio stays low.
        ok = verb >= spec.min_verbatim_lift and ratio <= spec.max_generalization_ratio_if_memorize
        score = clamp01(verb / max(spec.min_verbatim_lift, 1e-6))
        return (Verdict.PASS if ok else Verdict.FAIL), score, f"{base_msg} — intent=memorize"

    # "generalize" and "both" share one rule: both thresholds must be met.
    ok = verb >= spec.min_verbatim_lift and ratio >= spec.min_generalization_ratio
    gen_part = clamp01(ratio / max(spec.min_generalization_ratio, 1e-6))
    verb_part = clamp01(verb / max(spec.min_verbatim_lift, 1e-6))
    score = 0.5 * gen_part + 0.5 * verb_part
    return (Verdict.PASS if ok else Verdict.FAIL), score, f"{base_msg} — intent={spec.intent}"
145
+
146
+
147
+def _token_estimate(s: str) -> int:
148
+    return max(1, len(s) // 4)
tests/unit/test_probe_paraphrase_invariance.pyadded
@@ -0,0 +1,91 @@
1
+"""Tests for :mod:`dlm_sway.probes.paraphrase_invariance`."""
2
+
3
+from __future__ import annotations
4
+
5
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6
+from dlm_sway.core.result import Verdict
7
+from dlm_sway.probes.base import RunContext, build_probe
8
+
9
+
10
def _backend(*, par_lift_fraction: float, verb_lift: float = 10.0) -> DummyDifferentialBackend:
    """Build a dummy backend with tunable verbatim/paraphrase lifts.

    The ft view adds ``verb_lift`` nats to the verbatim (Q, A) logprob and
    ``par_lift_fraction * verb_lift`` nats to each paraphrase logprob.
    """
    floor = -20.0
    prompts = ("Q", "Q_par1", "Q_par2")
    # Base view: a flat logprob for the verbatim prompt and both paraphrases.
    base = DummyResponses(logprobs={(p, "A"): floor for p in prompts})
    par_bonus = par_lift_fraction * verb_lift
    ft = DummyResponses(
        logprobs={
            ("Q", "A"): floor + verb_lift,
            ("Q_par1", "A"): floor + par_bonus,
            ("Q_par2", "A"): floor + par_bonus,
        }
    )
    return DummyDifferentialBackend(base=base, ft=ft)
31
+
32
+
33
def test_pass_when_generalizing() -> None:
    # Strong lift on both verbatim and paraphrased prompts → healthy PASS.
    ctx = RunContext(backend=_backend(par_lift_fraction=0.9))
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "generalize",
            "min_verbatim_lift": 0.05,
            "min_generalization_ratio": 0.5,
            "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1", "Q_par2"]}],
        }
    )
    result = probe.run(spec, ctx)
    assert result.verdict == Verdict.PASS
    assert result.raw is not None
    assert result.raw >= 0.5
51
+
52
+
53
def test_fails_when_only_memorized_but_intent_generalize() -> None:
    # Verbatim-only lift (zero paraphrase lift) must FAIL under "generalize".
    ctx = RunContext(backend=_backend(par_lift_fraction=0.0))
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "generalize",
            "min_verbatim_lift": 0.05,
            "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}],
        }
    )
    assert probe.run(spec, ctx).verdict == Verdict.FAIL
67
+
68
+
69
def test_passes_memorize_intent_when_only_memorized() -> None:
    # The same memorized-only backend passes once the intent flips to "memorize".
    ctx = RunContext(backend=_backend(par_lift_fraction=0.0))
    probe, spec = build_probe(
        {
            "name": "pi",
            "kind": "paraphrase_invariance",
            "intent": "memorize",
            "min_verbatim_lift": 0.05,
            "max_generalization_ratio_if_memorize": 0.3,
            "cases": [{"prompt": "Q", "gold": "A", "paraphrases": ["Q_par1"]}],
        }
    )
    assert probe.run(spec, ctx).verdict == Verdict.PASS
84
+
85
+
86
def test_error_on_empty_cases() -> None:
    # An empty case list is a configuration error, not a PASS or FAIL.
    probe, spec = build_probe({"name": "pi", "kind": "paraphrase_invariance", "cases": []})
    ctx = RunContext(backend=_backend(par_lift_fraction=0.9))
    assert probe.run(spec, ctx).verdict == Verdict.ERROR