tenseleyflow/sway / 7ad0e9a

mining: paraphrase_miner — generator + diversity filter + lift-gap ranker (S17.2)

Authored by espadonne
SHA: 7ad0e9ad19a3373499ead7ef26887a6f36e32666
Parents: 416601f
Tree: c2e4d72

2 changed files

Status  File                                      +    -
A       src/dlm_sway/mining/__init__.py           12   0
A       src/dlm_sway/mining/paraphrase_miner.py   283  0
src/dlm_sway/mining/__init__.py  (added)
@@ -0,0 +1,12 @@
+"""Paraphrase + outlier miners (F11 / S17).
+
+Miners are companion tools to the shipped probes, not probes themselves.
+``paraphrase_miner`` sharpens :mod:`dlm_sway.probes.paraphrase_invariance`
+by finding the paraphrases an adapter most reliably *fails* on — a
+memorizing adapter that passes a user's hand-picked paraphrase list can
+still lose on the mined ones. ``outlier_miner`` does the same for any
+probe that aggregates over prompts.
+
+Both miners are deterministic under a fixed seed; the CLI entry point
+is ``sway mine``.
+"""
src/dlm_sway/mining/paraphrase_miner.py  (added)
@@ -0,0 +1,283 @@
+"""Adversarial paraphrase miner — F11 / S17.
+
+The shipped ``paraphrase_invariance`` probe measures whether the
+adapter lifts the gold answer equally when the prompt is paraphrased.
+Today the probe scores whatever paraphrase list the user hands it —
+typically 2–4 template-based rewordings. A memorizing adapter can pass
+that list cleanly if the memorized prompt happens to be in it.
+
+A *miner* searches the paraphrase neighborhood of each case and ranks
+candidates by the gap between verbatim lift and paraphrased lift. The
+top-K become a "hardest" paraphrase list — concrete evidence that the
+adapter is memorizing rather than generalizing.
+
+Pipeline:
+
+1. **Generate** candidates from the case's ``prompt`` via
+   :mod:`nlpaug` (``SynonymAug`` by default; back-translation comes
+   from a user-injected generator). Augmenters are seeded explicitly
+   so two mining runs against the same spec produce the same list.
+2. **Filter for diversity** via MiniLM embeddings: keep the ``K`` most
+   pairwise-distant candidates so the output isn't three variants of
+   the same sentence.
+3. **Rank** by the per-token log-probability gap:
+   ``gap = (ft(prompt, gold) - base(prompt, gold))
+          - (ft(candidate, gold) - base(candidate, gold))``
+   Large positive gap ⇒ the candidate breaks the adapter's lift, so
+   the adapter generalizes less than the verbatim list suggested.
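+
+   Worked example (illustrative numbers): a verbatim lift of ``2.1``
+   nats/token against a paraphrased lift of ``0.4`` gives
+   ``gap = 2.1 - 0.4 = 1.7``; most of the adapter's advantage
+   evaporates under that rewording.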
+
+Category: evaluation tool. No probe registration — miners don't
+produce verdicts, they emit YAML fragments the user folds back into a
+spec.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from dlm_sway.core.errors import BackendNotAvailableError
+from dlm_sway.probes.adapter_revert import _load_embedder
+
+if TYPE_CHECKING:
+    import numpy as np
+    from numpy.typing import NDArray
+
+    from dlm_sway.core.scoring import DifferentialBackend
+
+#: Type alias for a paraphrase-candidate generator.
+#: ``(prompt, *, n, seed) -> list[str]``. Kept structural so nlpaug's
+#: untyped Python API and test-time stub closures both satisfy it.
+CandidateGenerator = Callable[..., list[str]]
+
+
+@dataclass(frozen=True, slots=True)
+class ParaphraseCandidate:
+    """One ranked paraphrase candidate.
+
+    ``gap`` is the verbatim-vs-paraphrase lift delta (per-token nats)
+    the ranker scores on; ``diversity_rank`` records the candidate's
+    selection order in the diversity-filtered pool (0 = farthest from
+    the seed; later picks maximize min-distance to those already
+    selected).
+    """
+
+    prompt: str
+    gap: float
+    verbatim_lift: float
+    paraphrase_lift: float
+    diversity_rank: int
+
+
+@dataclass(frozen=True, slots=True)
+class MiningResult:
+    """Top-K paraphrase candidates for one ``(prompt, gold)`` case."""
+
+    seed_prompt: str
+    gold: str
+    candidates: list[ParaphraseCandidate]
+
+
+def mine_paraphrases(
+    *,
+    prompt: str,
+    gold: str,
+    backend: DifferentialBackend,
+    generate_candidates: CandidateGenerator | None = None,
+    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+    n_candidates: int = 50,
+    top_k: int = 10,
+    seed: int = 0,
+) -> MiningResult:
+    """Mine the top-``top_k`` adversarial paraphrases for one case.
+
+    Parameters
+    ----------
+    prompt, gold:
+        The case's verbatim prompt and its gold continuation.
+    backend:
+        Differential backend with ``as_base()`` + ``as_finetuned()``.
+        Scored via ``logprob_of`` on each side.
+    generate_candidates:
+        Callable producing paraphrase candidates. Defaults to
+        :func:`nlpaug_candidates`, which is synonym-only (WordNet via
+        nlpaug); back-translation requires a custom generator. Tests
+        can inject a deterministic stub.
+    embedding_model:
+        MiniLM checkpoint for the diversity filter (shared cache with
+        :mod:`adapter_revert`).
+    n_candidates, top_k:
+        Generate ``n_candidates``, diversity-filter to ``top_k``,
+        then rank by lift gap. ``top_k ≤ n_candidates``.
+    seed:
+        Passed to the generator so its augmentation picks are
+        deterministic.
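+
+    Example
+    -------
+    A sketch with a deterministic stub generator (any
+    ``DifferentialBackend`` in scope as ``backend`` works)::
+
+        def stub(prompt: str, *, n: int, seed: int) -> list[str]:
+            # Trivially reworded variants, stable under ``seed``.
+            return [f"{prompt} (variant {i})" for i in range(n)]
+
+        result = mine_paraphrases(
+            prompt="Who wrote Hamlet?",
+            gold="William Shakespeare",
+            backend=backend,
+            generate_candidates=stub,
+            top_k=5,
+        )
+        hardest = result.candidates[0]  # largest lift gap first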
+    """
+    if top_k <= 0:
+        raise ValueError(f"top_k must be positive; got {top_k}")
+    if n_candidates < top_k:
+        raise ValueError(f"n_candidates ({n_candidates}) must be ≥ top_k ({top_k})")
+
+    gen = generate_candidates or nlpaug_candidates
+    raw_candidates = gen(prompt, n=n_candidates, seed=seed)
+    # Drop the verbatim seed if the generator echoed it back and
+    # deduplicate. Preserve generator order so the diversity filter's
+    # "first-seen wins" heuristic is stable.
+    seen: set[str] = {prompt}
+    unique: list[str] = []
+    for c in raw_candidates:
+        if c and c not in seen:
+            seen.add(c)
+            unique.append(c)
+    if not unique:
+        return MiningResult(seed_prompt=prompt, gold=gold, candidates=[])
+
+    # Diversity filter: MiniLM embeddings + greedy farthest-first.
+    diversified = _diversity_filter(
+        seed_prompt=prompt,
+        candidates=unique,
+        embedding_model=embedding_model,
+        k=min(top_k, len(unique)),
+    )
+
+    # Rank by lift gap.
+    with backend.as_base() as base:
+        base_verbatim = base.logprob_of(prompt, gold) / max(1, _tok_estimate(gold))
+    with backend.as_finetuned() as ft:
+        ft_verbatim = ft.logprob_of(prompt, gold) / max(1, _tok_estimate(gold))
+    verbatim_lift = ft_verbatim - base_verbatim
+
+    ranked: list[ParaphraseCandidate] = []
+    for rank, cand in enumerate(diversified):
+        with backend.as_base() as base:
+            base_p = base.logprob_of(cand, gold) / max(1, _tok_estimate(gold))
+        with backend.as_finetuned() as ft:
+            ft_p = ft.logprob_of(cand, gold) / max(1, _tok_estimate(gold))
+        paraphrase_lift = ft_p - base_p
+        ranked.append(
+            ParaphraseCandidate(
+                prompt=cand,
+                gap=verbatim_lift - paraphrase_lift,
+                verbatim_lift=verbatim_lift,
+                paraphrase_lift=paraphrase_lift,
+                diversity_rank=rank,
+            )
+        )
+
+    # Largest gap first — "hardest" paraphrases lead the list.
+    ranked.sort(key=lambda c: c.gap, reverse=True)
+    return MiningResult(seed_prompt=prompt, gold=gold, candidates=ranked[:top_k])
+
+
+# ----------------------------------------------------------------------
+# Generators
+# ----------------------------------------------------------------------
+
+
+def nlpaug_candidates(prompt: str, *, n: int, seed: int) -> list[str]:
+    """Generate ``n`` paraphrase candidates via nlpaug synonym augmentation.
+
+    Synonym-only by default — back-translation adds a 1-GB model load
+    and network call on first use, and the gain over synonyms is
+    modest for the typical short-prompt paraphrase case. Users who
+    want back-translation can write their own ``CandidateGenerator``
+    and inject it via ``mine_paraphrases(generate_candidates=…)``.
+
+    Determinism: nlpaug samples under ``random`` / ``numpy``; we seed
+    both before calling ``augment`` so repeated mining runs on the
+    same prompt produce the same candidate set.
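+
+    Example (a sketch of such a custom generator; assumes nlpaug's
+    ``BackTranslationAug`` with its default English/German
+    checkpoints, downloads the translation models on first use, and
+    its output depends on those models rather than on ``seed``)::
+
+        import nlpaug.augmenter.word as naw
+
+        def backtranslate(prompt: str, *, n: int, seed: int) -> list[str]:
+            aug = naw.BackTranslationAug()
+            return [str(s) for s in aug.augment(prompt, n=n)]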
+    """
+    try:
+        import nlpaug.augmenter.word as naw
+    except ImportError as exc:
+        raise BackendNotAvailableError(
+            "paraphrase_miner",
+            extra="style",
+            hint="paraphrase_miner's default generator uses nlpaug word-level augmenters.",
+        ) from exc
+    import random
+
+    import numpy as np
+
+    random.seed(seed)
+    np.random.seed(seed)
+    # ``SynonymAug`` uses WordNet — no model download, fast enough
+    # for the N=50 candidate default.
+    aug = naw.SynonymAug(aug_src="wordnet", aug_min=1, aug_max=3)
+    out: list[str] = []
+    for _ in range(n):
+        result = aug.augment(prompt)
+        if isinstance(result, list):
+            out.extend(str(s) for s in result)
+        elif isinstance(result, str):
+            out.append(result)
+    return out
+
+
+# ----------------------------------------------------------------------
+# Diversity filter
+# ----------------------------------------------------------------------
+
+
+def _diversity_filter(
+    *,
+    seed_prompt: str,
+    candidates: list[str],
+    embedding_model: str,
+    k: int,
+) -> list[str]:
+    """Greedy farthest-first selection under MiniLM embeddings.
+
+    Keeps the first ``k`` candidates that are maximally distant from
+    both the seed and each other. This dodges nlpaug's tendency to
+    emit the same synonym twice under slightly different positions.
+    """
+    if not candidates:
+        return []
+    if len(candidates) <= k:
+        return list(candidates)
+
+    embed = _load_embedder(embedding_model)
+    import numpy as np
+
+    vecs: NDArray[np.float32] = np.asarray(embed([seed_prompt, *candidates]), dtype=np.float32)
+    seed_vec = vecs[0]
+    cand_vecs = vecs[1:]
+
+    # Distance from seed: 1 - cosine (embeddings are unit-normalized).
+    dist_from_seed = 1.0 - (cand_vecs @ seed_vec)
+
+    selected_idx: list[int] = []
+    # Start with the most-distant-from-seed candidate.
+    first = int(np.argmax(dist_from_seed))
+    selected_idx.append(first)
+
+    # Greedy: each subsequent pick maximizes min-distance to already-selected.
+    while len(selected_idx) < k:
+        selected_vecs = cand_vecs[selected_idx]
+        # Distances from every candidate to every already-selected one.
+        # Shape: (N, |selected|).
+        sims = cand_vecs @ selected_vecs.T
+        min_dist = 1.0 - sims.max(axis=1)
+        # Exclude already-selected rows from the argmax.
+        min_dist_masked = min_dist.copy()
+        min_dist_masked[selected_idx] = -np.inf
+        next_idx = int(np.argmax(min_dist_masked))
+        if min_dist_masked[next_idx] == -np.inf:
+            break  # pool exhausted
+        selected_idx.append(next_idx)
+
+    return [candidates[i] for i in selected_idx]
+
+
+def _tok_estimate(s: str) -> int:
+    """Same token-count heuristic used by ``paraphrase_invariance``."""
+    return max(1, len(s) // 4)
+
+
+__all__ = [
+    "MiningResult",
+    "ParaphraseCandidate",
+    "mine_paraphrases",
+    "nlpaug_candidates",
+]