tenseleyflow/sway / ed303dd

Browse files

sway(probes): A2 adapter_revert via sentence embeddings

Authored by espadonne
SHA
ed303dd9f10fceae4dab13c6c7e56637717ad1fa
Parents
c5cfd2e
Tree
fe04c0d

2 changed files

StatusFile+-
A src/dlm_sway/probes/adapter_revert.py 178 0
A tests/unit/test_probe_adapter_revert.py 170 0
src/dlm_sway/probes/adapter_revert.pyadded
@@ -0,0 +1,178 @@
1
+"""A2 AdapterRevert — does the fine-tuned model drift back to base under pressure?
2
+
3
+For each test case the user provides a prompt, a "gold" answer (the
4
+adapter's intended response), and one or more adversarial paraphrases of
5
+the prompt. We generate base-model and ft-model completions on every
6
+paraphrase and ask: does the ft output cluster semantically with the
7
+base's output (revert) or with the gold (adhere)?
8
+
9
+Signal: ``revert_rate`` = fraction of (case, paraphrase) pairs where
10
+``cos(ft, base) > cos(ft, gold)``. A healthy fine-tune holds below 25%.
11
+
12
+Needs sentence embeddings. Without the ``semsim`` extra installed the
13
+probe returns :attr:`Verdict.SKIP` with a pip hint — deterministic
14
+n-gram fallbacks don't carry semantic equivalence reliably enough to
15
+drive a revert decision, and we'd rather be honest than lossy.
16
+"""
17
+
18
+from __future__ import annotations
19
+
20
+from typing import Any, Literal
21
+
22
+from pydantic import BaseModel, ConfigDict, Field
23
+
24
+from dlm_sway.core.errors import BackendNotAvailableError
25
+from dlm_sway.core.result import ProbeResult, Verdict
26
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
27
+
28
+
29
class AdapterRevertCase(BaseModel):
    """A single prompt/gold/paraphrase bundle for the revert check."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Original task prompt. NOTE(review): the probe's run loop generates
    # only on the paraphrases; this field is kept for provenance.
    prompt: str
    # The adapter's intended answer — the "adhere" anchor for the cosine test.
    gold: str
    # Adversarial reframings of the prompt. Revert is measured under
    # reframing, never on the original prompt, so at least one is mandatory.
    paraphrases: list[str] = Field(default_factory=list, min_length=1)
40
+
41
+
42
class AdapterRevertSpec(ProbeSpec):
    """Config for the A2 adapter-revert probe."""

    kind: Literal["adapter_revert"] = "adapter_revert"
    cases: list[AdapterRevertCase] = Field(default_factory=list)
    max_new_tokens: int = 64
    # HF id of the sentence embedder. The default is ~80 MB and CPU-friendly.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    # Pairs where cos(base, gold) exceeds this cap are dropped: when base and
    # gold are already near-identical, revert and adherence cannot be told
    # apart, and counting such pairs would only inject noise into the rate.
    base_gold_similarity_cap: float = 0.75
    # The probe PASSes while revert_rate stays strictly below this threshold.
    assert_revert_rate_lt: float = 0.25
53
+
54
+
55
class AdapterRevertProbe(Probe):
    """Detects fine-tunes whose outputs drift back toward the base model.

    For every (case, paraphrase) pair the probe generates a base and a
    fine-tuned completion, embeds both plus the gold answer, and counts a
    revert when the ft output sits closer (by cosine) to the base output
    than to the gold. The revert rate over all non-trivial pairs drives
    the verdict.
    """

    kind = "adapter_revert"
    spec_cls = AdapterRevertSpec
    category = "adherence"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, AdapterRevertSpec)
        # Guard: an empty case list is a config error, not a pass.
        if not spec.cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no cases provided",
            )

        # Missing semsim extra → SKIP with the installer hint, per module policy.
        try:
            embed = _load_embedder(spec.embedding_model)
        except BackendNotAvailableError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )

        import numpy as np

        evaluated = 0
        revert_count = 0
        dropped_trivial = 0
        records: list[dict[str, Any]] = []
        for case in spec.cases:
            gold_vec = embed([case.gold])[0]
            for paraphrase in case.paraphrases:
                with ctx.backend.as_base() as base_view:
                    base_text = base_view.generate(
                        paraphrase, max_new_tokens=spec.max_new_tokens, seed=ctx.seed
                    )
                with ctx.backend.as_finetuned() as ft_view:
                    ft_text = ft_view.generate(
                        paraphrase, max_new_tokens=spec.max_new_tokens, seed=ctx.seed
                    )
                base_vec, ft_vec = embed([base_text, ft_text])
                # A pair whose base output already matches gold carries no
                # signal — revert and adherence are indistinguishable there.
                if _cosine(base_vec, gold_vec) > spec.base_gold_similarity_cap:
                    dropped_trivial += 1
                    continue
                cos_ft_base = _cosine(ft_vec, base_vec)
                cos_ft_gold = _cosine(ft_vec, gold_vec)
                reverted = cos_ft_base > cos_ft_gold
                evaluated += 1
                if reverted:
                    revert_count += 1
                records.append(
                    {
                        "prompt": paraphrase[:80],
                        "cos_ft_base": cos_ft_base,
                        "cos_ft_gold": cos_ft_gold,
                        "reverted": reverted,
                    }
                )

        # Everything was dropped as trivial: warn rather than guess a verdict.
        if evaluated == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                score=0.5,
                message=(
                    f"all {dropped_trivial} cases had base≈gold (> "
                    f"{spec.base_gold_similarity_cap}) — no separable signal"
                ),
                evidence={"dropped_trivial": dropped_trivial, "weight": spec.weight},
            )

        rate = revert_count / evaluated
        verdict = Verdict.PASS if rate < spec.assert_revert_rate_lt else Verdict.FAIL
        # Linear score: 1.0 at zero revert, 0.0 at/above the threshold.
        score = float(
            np.clip(1.0 - rate / max(spec.assert_revert_rate_lt, 1e-6), 0.0, 1.0)
        )

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=rate,
            evidence={
                "revert_rate": rate,
                "reverts": revert_count,
                "total": evaluated,
                "dropped_trivial": dropped_trivial,
                "per_case": records[:8],  # cap to keep JSON bounded
                "weight": spec.weight,
            },
            message=f"revert_rate={rate:.2%} (reverts={revert_count}/{evaluated}, dropped_trivial={dropped_trivial})",
        )
149
+
150
+
151
def _load_embedder(model_id: str):  # type: ignore[no-untyped-def]
    """Build a ``list[str] -> np.ndarray`` encoder for *model_id*.

    Raises :class:`BackendNotAvailableError` (with a ``semsim`` extra hint)
    when sentence-transformers is not importable.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise BackendNotAvailableError(
            "adapter_revert",
            extra="semsim",
            hint="adapter_revert relies on sentence embeddings.",
        ) from exc

    model = SentenceTransformer(model_id)

    def _encode(texts: list[str]):  # type: ignore[no-untyped-def]
        # Normalized numpy vectors keep downstream cosine math well-behaved.
        return model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    return _encode
167
+
168
+
169
+def _cosine(a: Any, b: Any) -> float:
170
+    import numpy as np
171
+
172
+    av = np.asarray(a, dtype=np.float64)
173
+    bv = np.asarray(b, dtype=np.float64)
174
+    na = float(np.linalg.norm(av))
175
+    nb = float(np.linalg.norm(bv))
176
+    if na == 0.0 or nb == 0.0:
177
+        return 0.0
178
+    return float(np.dot(av, bv) / (na * nb))
tests/unit/test_probe_adapter_revert.pyadded
@@ -0,0 +1,170 @@
1
+"""Tests for :mod:`dlm_sway.probes.adapter_revert`.
2
+
3
+We stub out the embedder so these tests don't need sentence-transformers
4
+installed. The ``probe.py`` SKIP path for the missing-extra case is
5
+covered separately by monkeypatching the importer.
6
+"""
7
+
8
+from __future__ import annotations
9
+
10
+from typing import Any
11
+
12
+import numpy as np
13
+import pytest
14
+
15
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
16
+from dlm_sway.core.result import Verdict
17
+from dlm_sway.probes.adapter_revert import AdapterRevertProbe
18
+from dlm_sway.probes.base import RunContext, build_probe
19
+
20
+
21
def _backend(*, ft_like_base: bool = False) -> DummyDifferentialBackend:
    """Differential backend with canned completions for prompts pp1/pp2.

    With ``ft_like_base=True`` the fine-tuned side echoes the base side,
    simulating a fully reverted adapter.
    """
    base_map = {
        "pp1": "cats are mammals",
        "pp2": "cats have fur",
    }
    divergent_map = {
        "pp1": "dolphins are mammals",
        "pp2": "dolphins are smart",
    }
    base = DummyResponses(generations=base_map)
    ft_map = dict(base.generations) if ft_like_base else divergent_map
    return DummyDifferentialBackend(base=base, ft=DummyResponses(generations=ft_map))
37
+
38
+
39
+def _stub_embedder(text_to_vec: dict[str, np.ndarray]):  # type: ignore[no-untyped-def]
40
+    def _encode(texts: list[str]):  # type: ignore[no-untyped-def]
41
+        return np.stack([text_to_vec[t] for t in texts])
42
+
43
+    return _encode
44
+
45
+
46
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Swap the real embedder loader for a dict-backed stub.

    Returns the mutable text→vector mapping; tests fill it in before
    calling ``probe.run()``.
    """
    mapping: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.probes.adapter_revert._load_embedder",
        lambda _model_id: _stub_embedder(mapping),  # type: ignore[arg-type]
    )
    return mapping
58
+
59
+
60
class TestAdapterRevert:
    """Behavioral coverage for the revert probe with a stubbed embedder."""

    def test_healthy_adapter_passes(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # gold and ft-outputs cluster together, base outputs cluster elsewhere.
        base_axis = np.array([1.0, 0.0])
        gold_axis = np.array([0.0, 1.0])
        monkeyed_embed.update(
            {
                "cats are mammals": base_axis,
                "cats have fur": base_axis,
                "dolphins are mammals": gold_axis,
                "dolphins are smart": gold_axis,
                "the answer is dolphins": gold_axis,  # gold
            }
        )

        payload = {
            "name": "rev",
            "kind": "adapter_revert",
            "cases": [
                {
                    "prompt": "anything",
                    "gold": "the answer is dolphins",
                    "paraphrases": ["pp1", "pp2"],
                }
            ],
            "assert_revert_rate_lt": 0.25,
        }
        probe, spec = build_probe(payload)
        result = probe.run(spec, RunContext(backend=_backend(ft_like_base=False)))
        assert result.verdict == Verdict.PASS
        assert result.raw == 0.0

    def test_reverting_adapter_fails(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # ft matches base (reverted), diverges from gold.
        monkeyed_embed.update(
            {
                "cats are mammals": np.array([1.0, 0.0]),
                "cats have fur": np.array([1.0, 0.0]),
                "the answer is dolphins": np.array([0.0, 1.0]),  # gold
            }
        )

        payload = {
            "name": "rev",
            "kind": "adapter_revert",
            "cases": [
                {
                    "prompt": "anything",
                    "gold": "the answer is dolphins",
                    "paraphrases": ["pp1", "pp2"],
                }
            ],
        }
        probe, spec = build_probe(payload)
        result = probe.run(spec, RunContext(backend=_backend(ft_like_base=True)))
        assert result.verdict == Verdict.FAIL
        assert result.raw == 1.0  # 100% revert

    def test_trivially_similar_cases_dropped(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # base and gold are identical → drop.
        shared = np.array([1.0, 0.0])
        other = np.array([0.0, 1.0])
        monkeyed_embed.update(
            {
                "cats are mammals": shared,
                "cats have fur": shared,
                "dolphins are mammals": other,
                "dolphins are smart": other,
                "cats are mammals too": shared,  # gold — matches base
            }
        )

        payload = {
            "name": "rev",
            "kind": "adapter_revert",
            "cases": [
                {
                    "prompt": "anything",
                    "gold": "cats are mammals too",
                    "paraphrases": ["pp1", "pp2"],
                }
            ],
        }
        probe, spec = build_probe(payload)
        result = probe.run(spec, RunContext(backend=_backend(ft_like_base=False)))
        # Both paraphrase pairs trivially similar → WARN (no separable signal).
        assert result.verdict == Verdict.WARN
        assert result.evidence["dropped_trivial"] == 2

    def test_no_cases_errors(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        probe, spec = build_probe({"name": "rev", "kind": "adapter_revert", "cases": []})
        result = probe.run(spec, RunContext(backend=_backend()))
        assert result.verdict == Verdict.ERROR
145
+
146
+
147
class TestMissingSemsim:
    """The probe must SKIP (not crash) when the semsim extra is absent."""

    def test_skip_when_sentence_transformers_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        from dlm_sway.core.errors import BackendNotAvailableError

        def _unavailable(_model_id: Any) -> Any:  # type: ignore[no-untyped-def]
            raise BackendNotAvailableError(
                "adapter_revert",
                extra="semsim",
                hint="adapter_revert relies on sentence embeddings.",
            )

        monkeypatch.setattr(
            "dlm_sway.probes.adapter_revert._load_embedder",
            _unavailable,  # type: ignore[arg-type]
        )

        probe = AdapterRevertProbe()
        spec = probe.spec_cls(
            name="rev",
            cases=[{"prompt": "x", "gold": "y", "paraphrases": ["pp1"]}],  # type: ignore[list-item]
        )
        result = probe.run(spec, RunContext(backend=_backend()))
        assert result.verdict == Verdict.SKIP
        assert "semsim" in result.message