Python · 6563 bytes Raw Blame History
1 """A2 AdapterRevert — does the fine-tuned model drift back to base under pressure?
2
3 For each test case the user provides a prompt, a "gold" answer (the
4 adapter's intended response), and one or more adversarial paraphrases of
5 the prompt. We generate base-model and ft-model completions on every
6 paraphrase and ask: does the ft output cluster semantically with the
7 base's output (revert) or with the gold (adhere)?
8
9 Signal: ``revert_rate`` = fraction of (case, paraphrase) pairs where
10 ``cos(ft, base) > cos(ft, gold)``. A healthy fine-tune holds below 25%.
11
12 Needs sentence embeddings. Without the ``semsim`` extra installed the
13 probe returns :attr:`Verdict.SKIP` with a pip hint — deterministic
14 n-gram fallbacks don't carry semantic equivalence reliably enough to
15 drive a revert decision, and we'd rather be honest than lossy.
16 """
17
18 from __future__ import annotations
19
20 from typing import Any, Literal
21
22 from pydantic import BaseModel, ConfigDict, Field
23
24 from dlm_sway.core.errors import BackendNotAvailableError
25 from dlm_sway.core.result import ProbeResult, Verdict
26 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
27
28
class AdapterRevertCase(BaseModel):
    """One revert test case.

    A prompt, the adapter's intended ("gold") answer, and at least one
    adversarial paraphrase of the prompt under which revert is measured.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    prompt: str
    gold: str
    """What the adapter is supposed to produce."""
    # Required, no default: pydantic v2 does not validate defaults, so the
    # previous ``default_factory=list`` silently admitted an empty list that
    # violated ``min_length=1`` and contributed zero (case, paraphrase) pairs.
    paraphrases: list[str] = Field(min_length=1)
    """At least one paraphrase is required — revert is observed under
    reframing, not on the original prompt."""
40
41
class AdapterRevertSpec(ProbeSpec):
    """Spec for the ``adapter_revert`` probe.

    Configures the test cases, the generation budget, the embedder used
    for semantic comparison, and the pass/fail threshold.
    """

    kind: Literal["adapter_revert"] = "adapter_revert"
    # Revert test cases; an empty list produces a Verdict.ERROR at run time.
    cases: list[AdapterRevertCase] = Field(default_factory=list)
    # Generation budget applied to both base and ft completions per paraphrase.
    max_new_tokens: int = 64
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    """HF id of the embedder. Default is ~80 MB, CPU-friendly."""
    base_gold_similarity_cap: float = 0.75
    """Skip pairs where base and gold are trivially similar — those
    can't distinguish revert from adherence, and including them would
    inflate the revert rate with noise."""
    # FAIL when revert_rate reaches this threshold; PASS strictly below it.
    assert_revert_rate_lt: float = 0.25
53
54
class AdapterRevertProbe(Probe):
    """Does the fine-tuned model drift back to base-model behavior?

    For every (case, paraphrase) pair, generates base and ft completions
    and checks whether the ft output is semantically closer to the base
    output (revert) or to the gold answer (adhere). Aggregates into a
    ``revert_rate`` checked against ``spec.assert_revert_rate_lt``.
    """

    kind = "adapter_revert"
    spec_cls = AdapterRevertSpec
    category = "adherence"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the probe.

        Returns ERROR when no cases are configured, SKIP when sentence
        embeddings are unavailable, WARN when every pair is dropped as
        trivially similar, otherwise PASS/FAIL on the revert rate.
        """
        assert isinstance(spec, AdapterRevertSpec)
        if not spec.cases:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no cases provided",
            )

        try:
            embed = _load_embedder(spec.embedding_model)
        except BackendNotAvailableError as exc:
            # Missing optional dependency: skip honestly (see module
            # docstring) rather than fall back to a lossy n-gram metric.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )

        import numpy as np

        total = 0
        reverts = 0
        dropped_trivial = 0  # counts (case, paraphrase) PAIRS, not cases
        per_case: list[dict[str, Any]] = []
        for case in spec.cases:
            # Embed gold once per case; it is shared by all paraphrases.
            gold_vec = embed([case.gold])[0]
            for pp in case.paraphrases:
                with ctx.backend.as_base() as bv:
                    base_gen = bv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                with ctx.backend.as_finetuned() as fv:
                    ft_gen = fv.generate(pp, max_new_tokens=spec.max_new_tokens, seed=ctx.seed)
                vecs = embed([base_gen, ft_gen])
                base_vec, ft_vec = vecs[0], vecs[1]
                base_gold = _cosine(base_vec, gold_vec)
                if base_gold > spec.base_gold_similarity_cap:
                    # base ≈ gold: this pair cannot separate revert from
                    # adherence, so it would only inflate the rate with noise.
                    dropped_trivial += 1
                    continue
                cos_ft_base = _cosine(ft_vec, base_vec)
                cos_ft_gold = _cosine(ft_vec, gold_vec)
                # Compute the revert decision once; it feeds both the
                # counter and the evidence record (was duplicated before).
                reverted = cos_ft_base > cos_ft_gold
                total += 1
                if reverted:
                    reverts += 1
                per_case.append(
                    {
                        "prompt": pp[:80],
                        "cos_ft_base": cos_ft_base,
                        "cos_ft_gold": cos_ft_gold,
                        "reverted": reverted,
                    }
                )

        if total == 0:
            # Every pair was dropped as trivially similar. Message now says
            # "pairs": dropped_trivial counts pairs, not cases.
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                score=0.5,
                message=(
                    f"all {dropped_trivial} pairs had base≈gold (> "
                    f"{spec.base_gold_similarity_cap}) — no separable signal"
                ),
                evidence={"dropped_trivial": dropped_trivial, "weight": spec.weight},
            )

        rate = reverts / total
        verdict = Verdict.PASS if rate < spec.assert_revert_rate_lt else Verdict.FAIL
        # Linear score: 1.0 at rate 0, reaching 0.0 at the assert threshold.
        score = max(0.0, 1.0 - rate / max(spec.assert_revert_rate_lt, 1e-6))
        score = float(np.clip(score, 0.0, 1.0))

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=rate,
            evidence={
                "revert_rate": rate,
                "reverts": reverts,
                "total": total,
                "dropped_trivial": dropped_trivial,
                "per_case": per_case[:8],  # cap to keep JSON bounded
                "weight": spec.weight,
            },
            message=f"revert_rate={rate:.2%} (reverts={reverts}/{total}, dropped_trivial={dropped_trivial})",
        )
149
150
def _load_embedder(model_id: str):  # type: ignore[no-untyped-def]
    """Build a ``list[str] -> np.ndarray`` encoder for *model_id*.

    Vectors come back L2-normalized, so downstream cosine reduces to a
    dot product. Raises :class:`BackendNotAvailableError` when the
    ``sentence-transformers`` package is not installed.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:
        raise BackendNotAvailableError(
            "adapter_revert",
            extra="semsim",
            hint="adapter_revert relies on sentence embeddings.",
        ) from exc

    model = SentenceTransformer(model_id)

    def _encode(batch: list[str]):  # type: ignore[no-untyped-def]
        return model.encode(batch, convert_to_numpy=True, normalize_embeddings=True)

    return _encode
167
168
169 def _cosine(a: Any, b: Any) -> float:
170 import numpy as np
171
172 av = np.asarray(a, dtype=np.float64)
173 bv = np.asarray(b, dtype=np.float64)
174 na = float(np.linalg.norm(av))
175 nb = float(np.linalg.norm(bv))
176 if na == 0.0 or nb == 0.0:
177 return 0.0
178 return float(np.dot(av, bv) / (na * nb))