1 """Adversarial paraphrase miner — F11 / S17.
2
3 The shipped ``paraphrase_invariance`` probe measures whether the
4 adapter lifts the gold answer equally when the prompt is paraphrased.
5 Today the probe scores whatever paraphrase list the user hands it —
6 typically 2–4 template-based rewordings. A memorizing adapter can pass
7 that list cleanly if the memorized prompt happens to be in it.
8
9 A *miner* searches the paraphrase neighborhood of each case and ranks
10 candidates by the gap between verbatim lift and paraphrased lift. The
11 top-K become a "hardest" paraphrase list — concrete evidence that the
12 adapter is memorizing rather than generalizing.
13
14 Pipeline:
15
16 1. **Generate** candidates from the case's ``prompt`` via
17 :mod:`nlpaug` (``SynonymAug`` + optional ``BackTranslationAug``).
18 All augmenters are seeded explicitly so two mining runs against
19 the same spec produce the same list.
20 2. **Filter for diversity** via MiniLM embeddings: keep the ``K`` most
21 pairwise-distant candidates so the output isn't three variants of
22 the same sentence.
23 3. **Rank** by the per-token log-probability gap:
24 ``gap = (ft(prompt, gold) - base(prompt, gold))
25 - (ft(candidate, gold) - base(candidate, gold))``
26 Large positive gap ⇒ the candidate breaks the adapter's lift, so
27 the adapter generalizes less than the verbatim list suggested.
28
29 Category: evaluation tool. No probe registration — miners don't
30 produce verdicts, they emit YAML fragments the user folds back into a
31 spec.
32 """
33
34 from __future__ import annotations
35
36 from collections.abc import Callable
37 from dataclasses import dataclass
38 from typing import TYPE_CHECKING
39
40 from dlm_sway.core.errors import BackendNotAvailableError
41 from dlm_sway.probes.adapter_revert import _load_embedder
42
43 if TYPE_CHECKING:
44 import numpy as np
45 from numpy.typing import NDArray
46
47 from dlm_sway.core.scoring import DifferentialBackend
48
#: Type alias for a paraphrase-candidate generator.
#: Call signature: ``(prompt, *, n, seed) -> list[str]``.
#: Deliberately kept structural (``Callable[..., ...]``) rather than a
#: ``Protocol`` so nlpaug's untyped Python API and test-time stub
#: closures both satisfy it without adapter shims.
CandidateGenerator = Callable[..., list[str]]
53
54
@dataclass(frozen=True, slots=True)
class ParaphraseCandidate:
    """One ranked paraphrase candidate.

    ``gap`` is the verbatim-vs-paraphrase lift delta (per-token nats)
    the ranker scores on; ``diversity_rank`` records the candidate's
    position in the diversity-filtered pool. Rank 0 is the filter's
    *first* pick — the candidate **most** distant from the seed prompt
    (see ``_diversity_filter``); later ranks were picked to maximize
    distance from the earlier picks, not from the seed.
    """

    # The paraphrased prompt text.
    prompt: str
    # verbatim_lift - paraphrase_lift; large positive => this paraphrase
    # breaks the adapter's lift (evidence of memorization).
    gap: float
    # Per-token (ft - base) log-prob lift on the verbatim seed prompt.
    verbatim_lift: float
    # Per-token (ft - base) log-prob lift on this candidate prompt.
    paraphrase_lift: float
    # Position in the diversity-filtered pool (0 = first greedy pick).
    diversity_rank: int
70
71
@dataclass(frozen=True, slots=True)
class MiningResult:
    """Top-K paraphrase candidates for one ``(prompt, gold)`` case.

    ``candidates`` is sorted hardest-first (largest ``gap`` first) and
    is empty when the generator produced nothing beyond the verbatim
    seed prompt.
    """

    # The verbatim prompt the candidates paraphrase.
    seed_prompt: str
    # Gold continuation scored under both model sides.
    gold: str
    # Ranked candidates, largest gap first; at most ``top_k`` entries.
    candidates: list[ParaphraseCandidate]
79
80
def mine_paraphrases(
    *,
    prompt: str,
    gold: str,
    backend: DifferentialBackend,
    generate_candidates: CandidateGenerator | None = None,
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    n_candidates: int = 50,
    top_k: int = 10,
    seed: int = 0,
) -> MiningResult:
    """Mine the top-``top_k`` adversarial paraphrases for one case.

    Parameters
    ----------
    prompt, gold:
        The case's verbatim prompt and its gold continuation.
    backend:
        Differential backend with ``as_base()`` + ``as_finetuned()``.
        Scored via ``logprob_of`` on each side.
    generate_candidates:
        Callable producing paraphrase candidates. Defaults to
        :func:`nlpaug_candidates` which uses nlpaug's synonym +
        back-translation pipeline. Tests can inject a deterministic
        stub.
    embedding_model:
        MiniLM checkpoint for the diversity filter (shared cache with
        :mod:`adapter_revert`).
    n_candidates, top_k:
        Generate ``n_candidates``, diversity-filter to ``top_k``,
        then rank by lift gap. ``top_k ≤ n_candidates``.
    seed:
        Passed to the generator so back-translation and synonym
        picks are deterministic.

    Returns
    -------
    MiningResult
        Candidates sorted by ``gap`` descending; empty when the
        generator produced nothing beyond the verbatim prompt.

    Raises
    ------
    ValueError
        If ``top_k`` is not positive or ``n_candidates < top_k``.
    """
    if top_k <= 0:
        raise ValueError(f"top_k must be positive; got {top_k}")
    if n_candidates < top_k:
        raise ValueError(f"n_candidates ({n_candidates}) must be ≥ top_k ({top_k})")

    gen = generate_candidates or nlpaug_candidates
    raw_candidates = gen(prompt, n=n_candidates, seed=seed)
    # Drop the verbatim seed if the generator echoed it back and
    # deduplicate. Preserve generator order so the diversity filter's
    # "first-seen wins" heuristic is stable.
    seen: set[str] = {prompt}
    unique: list[str] = []
    for c in raw_candidates:
        if c and c not in seen:
            seen.add(c)
            unique.append(c)
    if not unique:
        return MiningResult(seed_prompt=prompt, gold=gold, candidates=[])

    # Diversity filter: MiniLM embeddings + greedy farthest-first.
    diversified = _diversity_filter(
        seed_prompt=prompt,
        candidates=unique,
        embedding_model=embedding_model,
        k=min(top_k, len(unique)),
    )

    # Score the verbatim prompt and every surviving candidate under ONE
    # base context and ONE fine-tuned context.  The previous version
    # re-entered as_base()/as_finetuned() per candidate — 2·K + 2 adapter
    # swaps where 2 suffice.  Scores are unchanged: logprob_of depends
    # only on (prompt, gold), not on when the context was entered.
    # _tok_estimate already floors at 1, so dividing is always safe;
    # hoist it — the gold string never changes inside the loop.
    gold_tokens = _tok_estimate(gold)
    scored_prompts = [prompt, *diversified]
    with backend.as_base() as base:
        base_scores = [base.logprob_of(p, gold) / gold_tokens for p in scored_prompts]
    with backend.as_finetuned() as ft:
        ft_scores = [ft.logprob_of(p, gold) / gold_tokens for p in scored_prompts]

    verbatim_lift = ft_scores[0] - base_scores[0]
    ranked = [
        ParaphraseCandidate(
            prompt=cand,
            gap=verbatim_lift - (ft_s - base_s),
            verbatim_lift=verbatim_lift,
            paraphrase_lift=ft_s - base_s,
            diversity_rank=rank,
        )
        for rank, (cand, base_s, ft_s) in enumerate(
            zip(diversified, base_scores[1:], ft_scores[1:], strict=True)
        )
    ]

    # Largest gap first — "hardest" paraphrases lead the list.
    ranked.sort(key=lambda c: c.gap, reverse=True)
    return MiningResult(seed_prompt=prompt, gold=gold, candidates=ranked[:top_k])
170
171
172 # ----------------------------------------------------------------------
173 # Generators
174 # ----------------------------------------------------------------------
175
176
def nlpaug_candidates(prompt: str, *, n: int, seed: int) -> list[str]:
    """Generate at most ``n`` paraphrase candidates via nlpaug synonyms.

    Synonym-only by default — back-translation adds a 1-GB model load
    and network call on first use, and the gain over synonyms is
    modest for the typical short-prompt paraphrase case. Users who
    want back-translation can write their own ``CandidateGenerator``
    and inject it via ``mine_paraphrases(generate_candidates=…)``.

    Determinism: nlpaug samples under ``random`` / ``numpy``; we seed
    both before calling ``augment`` so repeated mining runs on the
    same prompt produce the same candidate set.

    Raises
    ------
    BackendNotAvailableError
        If nlpaug is not installed (it lives in the "style" extra).
    """
    try:
        import nlpaug.augmenter.word as naw
    except ImportError as exc:
        raise BackendNotAvailableError(
            "paraphrase_miner",
            extra="style",
            hint="paraphrase_miner's default generator uses nlpaug word-level augmenters.",
        ) from exc
    import random

    import numpy as np

    # nlpaug draws from the *global* RNGs — seed both so mining is
    # reproducible. (Side effect on global RNG state is accepted here;
    # nlpaug offers no per-instance seeding hook.)
    random.seed(seed)
    np.random.seed(seed)
    # ``SynonymAug`` uses WordNet — no model download, fast enough
    # for the N=50 candidate default.
    aug = naw.SynonymAug(aug_src="wordnet", aug_min=1, aug_max=3)
    out: list[str] = []
    for _ in range(n):
        result = aug.augment(prompt)
        if isinstance(result, list):
            out.extend(str(s) for s in result)
        elif isinstance(result, str):
            out.append(result)
    # ``augment`` can return several strings per call (the list branch
    # above), so ``out`` could exceed ``n``.  Cap it — the contract is
    # "n candidates", and callers size downstream work on that.
    return out[:n]
215
216
217 # ----------------------------------------------------------------------
218 # Diversity filter
219 # ----------------------------------------------------------------------
220
221
222 def _diversity_filter(
223 *,
224 seed_prompt: str,
225 candidates: list[str],
226 embedding_model: str,
227 k: int,
228 ) -> list[str]:
229 """Greedy farthest-first selection under MiniLM embeddings.
230
231 Keeps the first ``k`` candidates that are maximally distant from
232 both the seed and each other. This dodges nlpaug's tendency to
233 emit the same synonym twice under slightly different positions.
234 """
235 if not candidates:
236 return []
237 if len(candidates) <= k:
238 return list(candidates)
239
240 embed = _load_embedder(embedding_model)
241 import numpy as np
242
243 vecs: NDArray[np.float32] = np.asarray(embed([seed_prompt, *candidates]), dtype=np.float32)
244 seed_vec = vecs[0]
245 cand_vecs = vecs[1:]
246
247 # Distance from seed: 1 - cosine (embeddings are unit-normalized).
248 dist_from_seed = 1.0 - (cand_vecs @ seed_vec)
249
250 selected_idx: list[int] = []
251 # Start with the most-distant-from-seed candidate.
252 first = int(np.argmax(dist_from_seed))
253 selected_idx.append(first)
254
255 # Greedy: each subsequent pick maximizes min-distance to already-selected.
256 while len(selected_idx) < k:
257 selected_vecs = cand_vecs[selected_idx]
258 # Distances from every candidate to every already-selected one.
259 # Shape: (N, |selected|).
260 sims = cand_vecs @ selected_vecs.T
261 min_dist = 1.0 - sims.max(axis=1)
262 # Exclude already-selected rows from the argmax.
263 min_dist_masked = min_dist.copy()
264 min_dist_masked[selected_idx] = -np.inf
265 next_idx = int(np.argmax(min_dist_masked))
266 if min_dist_masked[next_idx] == -np.inf:
267 break # pool exhausted
268 selected_idx.append(next_idx)
269
270 return [candidates[i] for i in selected_idx]
271
272
273 def _tok_estimate(s: str) -> int:
274 """Same token-count heuristic used by ``paraphrase_invariance``."""
275 return max(1, len(s) // 4)
276
277
# Public surface: the miner entry point, its result types, and the
# default generator. ``_diversity_filter`` / ``_tok_estimate`` are
# internal helpers and deliberately excluded.
__all__ = [
    "MiningResult",
    "ParaphraseCandidate",
    "mine_paraphrases",
    "nlpaug_candidates",
]