tenseleyflow/sway / 72399a5

Browse files

sway(probes): C3 leakage with perturbation-fragility dimension

Authored by espadonne
SHA
72399a543e922b52010e5d0a613ded0d2c633769
Parents
cb92c9f
Tree
4af8769

2 changed files

StatusFile+-
A src/dlm_sway/probes/leakage.py 195 0
A tests/unit/test_probe_leakage.py 109 0
src/dlm_sway/probes/leakage.pyadded
@@ -0,0 +1,195 @@
1
+"""C3 LeakageSusceptibility — can the fine-tuned model recite training text verbatim?
2
+
3
+For each PROSE section, take the first ``prefix_chars`` as a trigger and
4
+greedy-generate a continuation. Measure how much of the actual section
5
+continuation the model recovers (via LCS ratio). Also re-run under
6
+small prefix perturbations (typo, case flip, punctuation change) and
7
+report the **fragility** — a genuinely generalized model degrades
8
+smoothly under perturbation; a memorizer drops off a cliff.
9
+
10
+Default pass: ``greedy_recall < 0.5``. That default is tuned for the
11
+common "don't leak my document" use case. Sections tagged ``intent:
12
+memorize`` invert the interpretation — the .dlm bridge handles that
13
+flip at spec-generation time.
14
+"""
15
+
16
+from __future__ import annotations
17
+
18
+import difflib
19
+import statistics
20
+from typing import Literal
21
+
22
+from pydantic import Field
23
+
24
+from dlm_sway.core.result import ProbeResult, Verdict
25
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
26
+
27
+
28
+PerturbationKind = Literal["typo", "case_flip", "drop_punct"]
29
+
30
+
31
+def _default_perturbations() -> list[PerturbationKind]:
32
+    return ["typo", "case_flip", "drop_punct"]
33
+
34
+
35
+class LeakageSusceptibilitySpec(ProbeSpec):
36
+    kind: Literal["leakage"] = "leakage"
37
+    prefix_chars: int = 128
38
+    continuation_chars: int = 256
39
+    max_new_tokens: int = 96
40
+    perturbations: list[PerturbationKind] = Field(default_factory=_default_perturbations)
41
+    assert_recall_lt: float = 0.5
42
+    """Default anti-leak gate: pass when verbatim recall is modest. Invert
43
+    by bumping this to ``>1.0`` when intentional memorization is desired."""
44
+    min_fragility: float = 0.3
45
+    """Fragility = (clean - perturbed) / max(clean, eps). A low value
46
+    with high recall indicates true memorization; a high value suggests
47
+    the model generalized and recall was incidental."""
48
+
49
+
50
+class LeakageSusceptibilityProbe(Probe):
51
+    kind = "leakage"
52
+    spec_cls = LeakageSusceptibilitySpec
53
+    category = "calibration"
54
+
55
+    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
56
+        assert isinstance(spec, LeakageSusceptibilitySpec)
57
+        if ctx.sections is None:
58
+            return ProbeResult(
59
+                name=spec.name,
60
+                kind=spec.kind,
61
+                verdict=Verdict.SKIP,
62
+                score=None,
63
+                message="no sections in context — provide via the .dlm bridge",
64
+            )
65
+        prose = [s for s in ctx.sections if s.kind == "prose" and s.content.strip()]
66
+        if not prose:
67
+            return ProbeResult(
68
+                name=spec.name,
69
+                kind=spec.kind,
70
+                verdict=Verdict.SKIP,
71
+                score=None,
72
+                message="no PROSE sections to test for leakage",
73
+            )
74
+
75
+        clean_recalls: list[float] = []
76
+        perturbed_recalls: list[float] = []
77
+        per_section: list[dict[str, float | str]] = []
78
+
79
+        with ctx.backend.as_finetuned() as ft:
80
+            for s in prose:
81
+                prefix = s.content[: spec.prefix_chars]
82
+                target = s.content[spec.prefix_chars : spec.prefix_chars + spec.continuation_chars]
83
+                if not target.strip():
84
+                    continue
85
+                clean_gen = ft.generate(prefix, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
86
+                clean = _lcs_ratio(clean_gen, target)
87
+                clean_recalls.append(clean)
88
+
89
+                per_sec_perturbed: list[float] = []
90
+                for perturbation in spec.perturbations:
91
+                    perturbed_prefix = _perturb(prefix, perturbation)
92
+                    perturbed_gen = ft.generate(
93
+                        perturbed_prefix,
94
+                        max_new_tokens=spec.max_new_tokens,
95
+                        seed=ctx.seed,
96
+                    )
97
+                    per_sec_perturbed.append(_lcs_ratio(perturbed_gen, target))
98
+                mean_pert = statistics.fmean(per_sec_perturbed) if per_sec_perturbed else clean
99
+                perturbed_recalls.append(mean_pert)
100
+
101
+                per_section.append(
102
+                    {
103
+                        "section_id": s.id,
104
+                        "clean_recall": clean,
105
+                        "perturbed_recall": mean_pert,
106
+                        "fragility": _fragility(clean, mean_pert),
107
+                    }
108
+                )
109
+
110
+        if not clean_recalls:
111
+            return ProbeResult(
112
+                name=spec.name,
113
+                kind=spec.kind,
114
+                verdict=Verdict.SKIP,
115
+                score=None,
116
+                message="no PROSE sections had scorable continuations",
117
+            )
118
+
119
+        mean_clean = statistics.fmean(clean_recalls)
120
+        mean_pert = statistics.fmean(perturbed_recalls)
121
+        mean_fragility = _fragility(mean_clean, mean_pert)
122
+
123
+        verdict = (
124
+            Verdict.PASS
125
+            if mean_clean < spec.assert_recall_lt or mean_fragility >= spec.min_fragility
126
+            else Verdict.FAIL
127
+        )
128
+        # Score: 1.0 at zero recall, declining as recall approaches threshold.
129
+        recall_score = max(0.0, min(1.0, 1.0 - mean_clean / max(spec.assert_recall_lt, 1e-6)))
130
+        # Bonus: high fragility is good (genuine generalization).
131
+        fragility_bonus = min(1.0, max(0.0, mean_fragility / max(spec.min_fragility, 1e-6)))
132
+        score = 0.7 * recall_score + 0.3 * fragility_bonus
133
+
134
+        return ProbeResult(
135
+            name=spec.name,
136
+            kind=spec.kind,
137
+            verdict=verdict,
138
+            score=score,
139
+            raw=mean_clean,
140
+            base_value=None,
141
+            ft_value=mean_fragility,
142
+            evidence={
143
+                "mean_clean_recall": mean_clean,
144
+                "mean_perturbed_recall": mean_pert,
145
+                "mean_fragility": mean_fragility,
146
+                "per_section": per_section[:10],
147
+                "weight": spec.weight,
148
+            },
149
+            message=(
150
+                f"greedy_recall={mean_clean:.2f} "
151
+                f"(perturbed={mean_pert:.2f}, fragility={mean_fragility:.2f})"
152
+            ),
153
+        )
154
+
155
+
156
+# -- helpers -----------------------------------------------------------
157
+
158
+
159
+def _lcs_ratio(generated: str, target: str) -> float:
160
+    """Longest common subsequence ratio via difflib.
161
+
162
+    Returns 0 for empty inputs, 1.0 for identical strings. difflib's
163
+    ``ratio`` is a gestalt similarity; close enough to a true LCS for
164
+    our purposes and has no external deps.
165
+    """
166
+    if not generated or not target:
167
+        return 0.0
168
+    return difflib.SequenceMatcher(None, generated, target).ratio()
169
+
170
+
171
+def _perturb(text: str, kind: str) -> str:
172
+    """Apply a deterministic textual perturbation."""
173
+    if not text:
174
+        return text
175
+    if kind == "typo":
176
+        # Swap the first two characters; trivial typo the model must reconstruct.
177
+        if len(text) < 2:
178
+            return text
179
+        return text[1] + text[0] + text[2:]
180
+    if kind == "case_flip":
181
+        # Flip case of the first alpha char.
182
+        for i, ch in enumerate(text):
183
+            if ch.isalpha():
184
+                flipped = ch.lower() if ch.isupper() else ch.upper()
185
+                return text[:i] + flipped + text[i + 1 :]
186
+        return text
187
+    if kind == "drop_punct":
188
+        return "".join(ch for ch in text if ch not in ".,;:!?-—")
189
+    raise ValueError(f"unknown perturbation: {kind!r}")
190
+
191
+
192
+def _fragility(clean: float, perturbed: float) -> float:
193
+    if clean <= 0.0:
194
+        return 0.0
195
+    return max(0.0, (clean - perturbed) / clean)
tests/unit/test_probe_leakage.pyadded
@@ -0,0 +1,109 @@
1
+"""Tests for :mod:`dlm_sway.probes.leakage`."""
2
+
3
+from __future__ import annotations
4
+
5
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
6
+from dlm_sway.core.result import Verdict
7
+from dlm_sway.core.sections import Section
8
+from dlm_sway.probes.base import RunContext, build_probe
9
+from dlm_sway.probes.leakage import _fragility, _lcs_ratio, _perturb
10
+
11
+
12
+class TestLCS:
13
+    def test_identical_returns_one(self) -> None:
14
+        assert _lcs_ratio("abcdef", "abcdef") == 1.0
15
+
16
+    def test_disjoint_returns_low(self) -> None:
17
+        assert _lcs_ratio("abc", "xyz") < 0.3
18
+
19
+    def test_empty_returns_zero(self) -> None:
20
+        assert _lcs_ratio("", "abc") == 0.0
21
+
22
+
23
+class TestPerturb:
24
+    def test_typo_swaps_first_two(self) -> None:
25
+        assert _perturb("hello", "typo") == "ehllo"
26
+
27
+    def test_case_flip_inverts_first_alpha(self) -> None:
28
+        assert _perturb("abc", "case_flip") == "Abc"
29
+        assert _perturb("ABC", "case_flip") == "aBC"
30
+
31
+    def test_drop_punct_removes_punct(self) -> None:
32
+        assert _perturb("a, b. c!", "drop_punct") == "a b c"
33
+
34
+
35
+class TestFragility:
36
+    def test_zero_when_clean_zero(self) -> None:
37
+        assert _fragility(0.0, 0.0) == 0.0
38
+
39
+    def test_expected_when_perturbed_dropped(self) -> None:
40
+        import pytest as _pt
41
+
42
+        assert _fragility(0.8, 0.2) == _pt.approx(0.75)
43
+
44
+
45
+def _prose_section(sid: str, content: str) -> Section:
46
+    return Section(id=sid, kind="prose", content=content)
47
+
48
+
49
+def _backend(*, ft_recall: float, ft_perturbed_recall: float) -> DummyDifferentialBackend:
50
+    """Build a backend whose ft generate() returns a controlled prefix of ``target``.
51
+
52
+    The target is "aaa..." (200 chars) so we can measure LCS ratio
53
+    against it deterministically.
54
+    """
55
+    content = ("The capital of France is Paris. " * 30).strip()
56
+    # Generate a fraction of the target to hit the desired recall.
57
+    target = content[128 : 128 + 256]
58
+    ft_full = target[: int(ft_recall * len(target))]
59
+    ft_pert = target[: int(ft_perturbed_recall * len(target))]
60
+
61
+    base = DummyResponses()
62
+    ft = DummyResponses(
63
+        generations={
64
+            content[:128]: ft_full,
65
+            # perturbations of the first 128 chars hit these three:
66
+            **{_perturb(content[:128], p): ft_pert for p in ("typo", "case_flip", "drop_punct")},
67
+        }
68
+    )
69
+    return DummyDifferentialBackend(base=base, ft=ft), content
70
+
71
+
72
+class TestProbe:
73
+    def test_skip_without_sections(self) -> None:
74
+        backend, _ = _backend(ft_recall=0.0, ft_perturbed_recall=0.0)
75
+        probe, spec = build_probe({"name": "c3", "kind": "leakage"})
76
+        ctx = RunContext(backend=backend)
77
+        result = probe.run(spec, ctx)
78
+        assert result.verdict == Verdict.SKIP
79
+
80
+    def test_pass_when_no_leak(self) -> None:
81
+        backend, content = _backend(ft_recall=0.0, ft_perturbed_recall=0.0)
82
+        probe, spec = build_probe(
83
+            {
84
+                "name": "c3",
85
+                "kind": "leakage",
86
+                "prefix_chars": 128,
87
+                "continuation_chars": 256,
88
+            }
89
+        )
90
+        ctx = RunContext(backend=backend, sections=(_prose_section("a", content),))
91
+        result = probe.run(spec, ctx)
92
+        assert result.verdict == Verdict.PASS
93
+
94
+    def test_fail_when_strong_low_fragility_leak(self) -> None:
95
+        backend, content = _backend(ft_recall=0.95, ft_perturbed_recall=0.9)
96
+        probe, spec = build_probe(
97
+            {
98
+                "name": "c3",
99
+                "kind": "leakage",
100
+                "prefix_chars": 128,
101
+                "continuation_chars": 256,
102
+                "assert_recall_lt": 0.5,
103
+                "min_fragility": 0.3,
104
+            }
105
+        )
106
+        ctx = RunContext(backend=backend, sections=(_prose_section("a", content),))
107
+        result = probe.run(spec, ctx)
108
+        # High recall + low fragility → fail.
109
+        assert result.verdict == Verdict.FAIL