sway(probes): A2 adapter_revert via sentence embeddings
- SHA: ed303dd9f10fceae4dab13c6c7e56637717ad1fa
- Parent: c5cfd2e
- Tree: fe04c0d

| Status | File | + | - |
|---|---|---|---|
| A | src/dlm_sway/probes/adapter_revert.py | 178 | 0 |
| A | tests/unit/test_probe_adapter_revert.py | 170 | 0 |
src/dlm_sway/probes/adapter_revert.py (added) @@ -0,0 +1,178 @@
| 1 | +"""A2 AdapterRevert — does the fine-tuned model drift back to base under pressure? | |
| 2 | + | |
| 3 | +For each test case the user provides a prompt, a "gold" answer (the | |
| 4 | +adapter's intended response), and one or more adversarial paraphrases of | |
| 5 | +the prompt. We generate base-model and ft-model completions on every | |
| 6 | +paraphrase and ask: does the ft output cluster semantically with the | |
| 7 | +base's output (revert) or with the gold (adhere)? | |
| 8 | + | |
| 9 | +Signal: ``revert_rate`` = fraction of (case, paraphrase) pairs where | |
| 10 | +``cos(ft, base) > cos(ft, gold)``. A healthy fine-tune holds below 25%. | |
| 11 | + | |
| 12 | +Needs sentence embeddings. Without the ``semsim`` extra installed the | |
| 13 | +probe returns :attr:`Verdict.SKIP` with a pip hint — deterministic | |
| 14 | +n-gram fallbacks don't carry semantic equivalence reliably enough to | |
| 15 | +drive a revert decision, and we'd rather be honest than lossy. | |
| 16 | +""" | |
| 17 | + | |
| 18 | +from __future__ import annotations | |
| 19 | + | |
| 20 | +from typing import Any, Literal | |
| 21 | + | |
| 22 | +from pydantic import BaseModel, ConfigDict, Field | |
| 23 | + | |
| 24 | +from dlm_sway.core.errors import BackendNotAvailableError | |
| 25 | +from dlm_sway.core.result import ProbeResult, Verdict | |
| 26 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext | |
| 27 | + | |
| 28 | + | |
class AdapterRevertCase(BaseModel):
    """One revert test case: a prompt, its gold answer, and adversarial paraphrases."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    prompt: str
    gold: str
    """What the adapter is supposed to produce."""
    # NOTE: previously `Field(default_factory=list, min_length=1)` — pydantic
    # does not validate defaults unless `validate_default=True`, so an omitted
    # `paraphrases` silently produced an invalid empty case that generated zero
    # (case, paraphrase) pairs. The field is now required, matching the
    # documented contract below.
    paraphrases: list[str] = Field(min_length=1)
    """At least one paraphrase is required — revert is observed under
    reframing, not on the original prompt."""
| 40 | + | |
| 41 | + | |
class AdapterRevertSpec(ProbeSpec):
    """Configuration for the A2 adapter-revert probe."""

    # Discriminator used by the probe registry to route this spec.
    kind: Literal["adapter_revert"] = "adapter_revert"
    cases: list[AdapterRevertCase] = Field(default_factory=list)
    # Generation budget applied to both base and ft completions per paraphrase.
    max_new_tokens: int = 64
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    """HF id of the embedder. Default is ~80 MB, CPU-friendly."""
    base_gold_similarity_cap: float = 0.75
    """Skip pairs where base and gold are trivially similar — those
    can't distinguish revert from adherence, and including them would
    inflate the revert rate with noise."""
    # PASS requires revert_rate strictly below this threshold; at or above → FAIL.
    assert_revert_rate_lt: float = 0.25
| 53 | + | |
| 54 | + | |
class AdapterRevertProbe(Probe):
    """A2 probe: does the fine-tuned model drift back to base under paraphrase?

    For every (case, paraphrase) pair, generate base and ft completions and
    compare embedding cosines: ``cos(ft, base) > cos(ft, gold)`` counts as a
    revert. Pairs where the base output is already near-gold are dropped —
    they cannot separate revert from adherence.
    """

    kind = "adapter_revert"
    spec_cls = AdapterRevertSpec
    category = "adherence"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Compute ``revert_rate`` and verdict for *spec*.

        Returns ERROR when there are no cases, SKIP when the embedding
        backend is unavailable, WARN when every pair was dropped as
        trivially similar, otherwise PASS/FAIL against
        ``spec.assert_revert_rate_lt``.
        """
        assert isinstance(spec, AdapterRevertSpec)
        if not spec.cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no cases provided",
            )

        try:
            embed = _load_embedder(spec.embedding_model)
        except BackendNotAvailableError as exc:
            # Missing `semsim` extra — skip honestly rather than fall back
            # to lossy n-gram similarity (see module docstring).
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )

        total = 0
        reverts = 0
        dropped_trivial = 0
        per_case: list[dict[str, Any]] = []
        for case in spec.cases:
            gold_vec = embed([case.gold])[0]
            for pp in case.paraphrases:
                with ctx.backend.as_base() as bv:
                    base_gen = bv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                base_vec = embed([base_gen])[0]
                base_gold = _cosine(base_vec, gold_vec)
                # Apply the triviality cap *before* the ft generation: a
                # dropped pair never uses the ft output, so generating it
                # first was pure waste.
                if base_gold > spec.base_gold_similarity_cap:
                    dropped_trivial += 1
                    continue
                with ctx.backend.as_finetuned() as fv:
                    ft_gen = fv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                ft_vec = embed([ft_gen])[0]
                cos_ft_base = _cosine(ft_vec, base_vec)
                cos_ft_gold = _cosine(ft_vec, gold_vec)
                total += 1
                if cos_ft_base > cos_ft_gold:
                    reverts += 1
                per_case.append(
                    {
                        "prompt": pp[:80],
                        "cos_ft_base": cos_ft_base,
                        "cos_ft_gold": cos_ft_gold,
                        "reverted": cos_ft_base > cos_ft_gold,
                    }
                )

        if total == 0:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                score=0.5,
                # The counter is per (case, paraphrase) pair, not per case —
                # the message now says so.
                message=(
                    f"all {dropped_trivial} paraphrase pairs had base≈gold (> "
                    f"{spec.base_gold_similarity_cap}) — no separable signal"
                ),
                evidence={"dropped_trivial": dropped_trivial, "weight": spec.weight},
            )

        rate = reverts / total
        verdict = Verdict.PASS if rate < spec.assert_revert_rate_lt else Verdict.FAIL
        # Linear score: 1.0 at rate 0, 0.0 at/above the failure threshold.
        # Plain min/max suffices — the old `np.clip` (and its numpy import)
        # re-clamped a value already bounded by max() and the formula.
        score = min(1.0, max(0.0, 1.0 - rate / max(spec.assert_revert_rate_lt, 1e-6)))

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=rate,
            evidence={
                "revert_rate": rate,
                "reverts": reverts,
                "total": total,
                "dropped_trivial": dropped_trivial,
                "per_case": per_case[:8],  # cap to keep JSON bounded
                "weight": spec.weight,
            },
            message=f"revert_rate={rate:.2%} (reverts={reverts}/{total}, dropped_trivial={dropped_trivial})",
        )
| 149 | + | |
| 150 | + | |
def _load_embedder(model_id: str):  # type: ignore[no-untyped-def]
    """Build an encoder callable ``list[str] -> np.ndarray`` for *model_id*.

    Raises :class:`BackendNotAvailableError` when the ``semsim`` extra
    (sentence-transformers) is not installed.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise BackendNotAvailableError(
            "adapter_revert",
            extra="semsim",
            hint="adapter_revert relies on sentence embeddings.",
        ) from exc

    model = SentenceTransformer(model_id)

    def _encode(texts: list[str]):  # type: ignore[no-untyped-def]
        # Normalized numpy output keeps downstream cosines well-scaled.
        return model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    return _encode
| 167 | + | |
| 168 | + | |
| 169 | +def _cosine(a: Any, b: Any) -> float: | |
| 170 | + import numpy as np | |
| 171 | + | |
| 172 | + av = np.asarray(a, dtype=np.float64) | |
| 173 | + bv = np.asarray(b, dtype=np.float64) | |
| 174 | + na = float(np.linalg.norm(av)) | |
| 175 | + nb = float(np.linalg.norm(bv)) | |
| 176 | + if na == 0.0 or nb == 0.0: | |
| 177 | + return 0.0 | |
| 178 | + return float(np.dot(av, bv) / (na * nb)) | |
tests/unit/test_probe_adapter_revert.py (added) @@ -0,0 +1,170 @@
| 1 | +"""Tests for :mod:`dlm_sway.probes.adapter_revert`. | |
| 2 | + | |
| 3 | +We stub out the embedder so these tests don't need sentence-transformers | |
| 4 | +installed. The ``probe.py`` SKIP path for the missing-extra case is | |
| 5 | +covered separately by monkeypatching the importer. | |
| 6 | +""" | |
| 7 | + | |
| 8 | +from __future__ import annotations | |
| 9 | + | |
| 10 | +from typing import Any | |
| 11 | + | |
| 12 | +import numpy as np | |
| 13 | +import pytest | |
| 14 | + | |
| 15 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses | |
| 16 | +from dlm_sway.core.result import Verdict | |
| 17 | +from dlm_sway.probes.adapter_revert import AdapterRevertProbe | |
| 18 | +from dlm_sway.probes.base import RunContext, build_probe | |
| 19 | + | |
| 20 | + | |
def _backend(*, ft_like_base: bool = False) -> DummyDifferentialBackend:
    """Build a dummy differential backend: base talks cats, ft talks dolphins.

    With ``ft_like_base=True`` the ft side mirrors the base generations,
    simulating a fully reverted adapter.
    """
    base_generations = {
        "pp1": "cats are mammals",
        "pp2": "cats have fur",
    }
    ft_generations = (
        dict(base_generations)
        if ft_like_base
        else {"pp1": "dolphins are mammals", "pp2": "dolphins are smart"}
    )
    return DummyDifferentialBackend(
        base=DummyResponses(generations=base_generations),
        ft=DummyResponses(generations=ft_generations),
    )
| 37 | + | |
| 38 | + | |
| 39 | +def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def] | |
| 40 | + def _encode(texts: list[str]): # type: ignore[no-untyped-def] | |
| 41 | + return np.stack([text_to_vec[t] for t in texts]) | |
| 42 | + | |
| 43 | + return _encode | |
| 44 | + | |
| 45 | + | |
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Install a stub embedder with a controllable text→vec mapping.

    Tests populate the returned dict before calling ``probe.run()``; the
    probe's ``_load_embedder`` is patched to serve vectors from it.
    """
    vectors: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.probes.adapter_revert._load_embedder",
        lambda _model_id: _stub_embedder(vectors),  # type: ignore[arg-type]
    )
    return vectors
| 58 | + | |
| 59 | + | |
class TestAdapterRevert:
    """Behavioral tests for the adapter_revert probe using a stubbed embedder."""

    @staticmethod
    def _build(gold: str, **extra: Any):  # type: ignore[no-untyped-def]
        # Shared single-case config with two paraphrases; `extra` merges in
        # per-test spec fields such as `assert_revert_rate_lt`.
        return build_probe(
            {
                "name": "rev",
                "kind": "adapter_revert",
                "cases": [
                    {
                        "prompt": "anything",
                        "gold": gold,
                        "paraphrases": ["pp1", "pp2"],
                    }
                ],
                **extra,
            }
        )

    def test_healthy_adapter_passes(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # Gold and ft outputs share one axis; base outputs sit orthogonal.
        cat_axis = np.array([1.0, 0.0])
        dolphin_axis = np.array([0.0, 1.0])
        monkeyed_embed["cats are mammals"] = cat_axis
        monkeyed_embed["cats have fur"] = cat_axis
        monkeyed_embed["dolphins are mammals"] = dolphin_axis
        monkeyed_embed["dolphins are smart"] = dolphin_axis
        monkeyed_embed["the answer is dolphins"] = dolphin_axis  # gold

        probe, spec = self._build("the answer is dolphins", assert_revert_rate_lt=0.25)
        res = probe.run(spec, RunContext(backend=_backend(ft_like_base=False)))
        assert res.verdict == Verdict.PASS
        assert res.raw == 0.0

    def test_reverting_adapter_fails(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # ft mirrors base (reverted) while gold points elsewhere.
        cat_axis = np.array([1.0, 0.0])
        monkeyed_embed["cats are mammals"] = cat_axis
        monkeyed_embed["cats have fur"] = cat_axis
        monkeyed_embed["the answer is dolphins"] = np.array([0.0, 1.0])  # gold

        probe, spec = self._build("the answer is dolphins")
        res = probe.run(spec, RunContext(backend=_backend(ft_like_base=True)))
        assert res.verdict == Verdict.FAIL
        assert res.raw == 1.0  # 100% revert

    def test_trivially_similar_cases_dropped(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        # base outputs and gold embed identically, so every pair is dropped.
        shared = np.array([1.0, 0.0])
        for text in ("cats are mammals", "cats have fur", "cats are mammals too"):
            monkeyed_embed[text] = shared
        monkeyed_embed["dolphins are mammals"] = np.array([0.0, 1.0])
        monkeyed_embed["dolphins are smart"] = np.array([0.0, 1.0])

        probe, spec = self._build("cats are mammals too")
        res = probe.run(spec, RunContext(backend=_backend(ft_like_base=False)))
        # Both paraphrase pairs trivially similar → WARN (no separable signal).
        assert res.verdict == Verdict.WARN
        assert res.evidence["dropped_trivial"] == 2

    def test_no_cases_errors(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        probe, spec = build_probe({"name": "rev", "kind": "adapter_revert", "cases": []})
        res = probe.run(spec, RunContext(backend=_backend()))
        assert res.verdict == Verdict.ERROR
| 146 | + | |
class TestMissingSemsim:
    """The probe must SKIP (not crash) when the ``semsim`` extra is absent."""

    def test_skip_when_sentence_transformers_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        from dlm_sway.core.errors import BackendNotAvailableError

        def raiser(_model_id: Any) -> Any:  # type: ignore[no-untyped-def]
            # Mimic the error _load_embedder raises when the import fails.
            raise BackendNotAvailableError(
                "adapter_revert",
                extra="semsim",
                hint="adapter_revert relies on sentence embeddings.",
            )

        monkeypatch.setattr(
            "dlm_sway.probes.adapter_revert._load_embedder",
            raiser,  # type: ignore[arg-type]
        )

        spec = AdapterRevertProbe.spec_cls(
            name="rev",
            cases=[{"prompt": "x", "gold": "y", "paraphrases": ["pp1"]}],  # type: ignore[list-item]
        )
        res = AdapterRevertProbe().run(spec, RunContext(backend=_backend()))
        assert res.verdict == Verdict.SKIP
        assert "semsim" in res.message