tenseleyflow/sway / f0843c2

Browse files

tests: paraphrase_miner — ranker + diversity filter + input validation (S17.5)

Authored by espadonne
SHA
f0843c29053adbc30352d64090a69f268c1fdcc3
Parents
69300ef
Tree
ccc079a

1 changed file

Status | File | Additions (+) | Deletions (−)
A tests/unit/test_paraphrase_miner.py 228 0
tests/unit/test_paraphrase_miner.py — added
@@ -0,0 +1,228 @@
1
+"""Tests for :mod:`dlm_sway.mining.paraphrase_miner`.
2
+
3
+We stub the embedder + the nlpaug generator so tests run without the
4
+80-MB MiniLM load or the nlpaug wheel. What's under test is the ranker
5
+and the diversity filter — the infrastructure around them (the
6
+embedder, the generator) is exercised by the integration test.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
from collections.abc import Callable

import numpy as np
import pytest

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
from dlm_sway.mining.paraphrase_miner import (
    MiningResult,
    ParaphraseCandidate,
    mine_paraphrases,
)
20
+
21
+
22
+def _stub_embedder(text_to_vec: dict[str, np.ndarray]):  # type: ignore[no-untyped-def]
23
+    def _encode(texts: list[str]):  # type: ignore[no-untyped-def]
24
+        return np.stack([text_to_vec[t] for t in texts])
25
+
26
+    return _encode
27
+
28
+
29
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Swap the miner's embedder loader for a table-backed stub.

    Returns the (initially empty) ``text -> vector`` table; tests plant
    entries in it before invoking the miner.
    """
    vectors: dict[str, np.ndarray] = {}

    def _fake_loader(_model_id):  # type: ignore[no-untyped-def]
        # The model id is irrelevant — every load yields the same stub.
        return _stub_embedder(vectors)

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        _fake_loader,  # type: ignore[arg-type]
    )
    return vectors
37
+
38
+
39
+def _canned_candidates(candidates: list[str]):  # type: ignore[no-untyped-def]
40
+    """Build a generator closure that returns ``candidates`` verbatim."""
41
+
42
+    def _gen(_prompt: str, *, n: int, seed: int) -> list[str]:
43
+        del n, seed
44
+        return list(candidates)
45
+
46
+    return _gen
47
+
48
+
49
class TestRanker:
    """Ranking behaviour of ``mine_paraphrases``: gap ordering, seed
    dedup, and the empty-input edge case."""

    def test_gap_sort_ranks_hardest_first(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """The paraphrase with the largest verbatim-vs-paraphrase lift
        gap ranks first. Planted: three candidates with known
        ``(base, ft)`` logprob pairs; the ranker must surface them in
        gap-descending order regardless of the generator's input
        order."""
        # Seed + candidates
        prompt = "Q: what is the capital?"
        gold = " Paris"
        candidates = ["Q1", "Q2", "Q3"]
        # Embed all four distinctly so the diversity filter keeps
        # every candidate.
        monkeyed_embed[prompt] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q1"] = np.array([0.0, 1.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q2"] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
        monkeyed_embed["Q3"] = np.array([0.5, 0.5, 0.0], dtype=np.float32)

        # Logprob table — verbatim lift is large; paraphrase lifts
        # vary by design: Q1 lift ≈ 0 (big gap), Q2 lift ≈ verbatim
        # (no gap), Q3 lift somewhere in between.
        base_logprobs = {
            (prompt, gold): -4.0,
            ("Q1", gold): -4.0,  # base unchanged
            ("Q2", gold): -4.0,
            ("Q3", gold): -4.0,
        }
        ft_logprobs = {
            (prompt, gold): -1.0,  # verbatim lift: +3.0
            ("Q1", gold): -4.0,  # paraphrase lift: 0 (big gap of 3)
            ("Q2", gold): -1.0,  # paraphrase lift: +3 (gap of 0)
            ("Q3", gold): -2.5,  # paraphrase lift: +1.5 (gap of 1.5)
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )

        assert isinstance(result, MiningResult)
        assert [c.prompt for c in result.candidates] == ["Q1", "Q3", "Q2"]
        # Gaps are in descending order.
        gaps = [c.gap for c in result.candidates]
        assert gaps == sorted(gaps, reverse=True)
        # Top candidate's gap is meaningfully larger than 0.
        assert result.candidates[0].gap > 0.1

    def test_dedup_drops_verbatim_seed(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """If the generator echoes the seed prompt back in its output,
        it's dropped before the ranker sees it — otherwise the seed
        would bubble to the top with gap = 0 and pollute the list."""
        prompt = "prompt"
        gold = " gold"
        # Generator returns the seed + two real candidates; the seed
        # must not appear in the final result.
        candidates = ["prompt", "C1", "C2"]
        # FIX: the previous seeding used ``hash(p)``, which is salted
        # per interpreter run (PYTHONHASHSEED), so the planted
        # embeddings — and hence the diversity filter's input — were
        # nondeterministic across runs.  Seed from the enumeration
        # index instead so the test is fully reproducible.
        for i, p in enumerate([prompt, "C1", "C2"]):
            monkeyed_embed[p] = np.random.RandomState(i).randn(4).astype(np.float32)

        base_logprobs = {(p, gold): -3.0 for p in [prompt, "C1", "C2"]}
        ft_logprobs = {(prompt, gold): -1.0, ("C1", gold): -2.5, ("C2", gold): -2.0}
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )
        assert prompt not in {c.prompt for c in result.candidates}

    def test_empty_candidate_list_returns_empty_result(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Zero candidates from the generator → an empty-but-valid result."""
        del monkeyed_embed  # fixture needed only for its monkeypatching
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        result = mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates([]),
            n_candidates=3,
            top_k=3,
        )
        assert isinstance(result, MiningResult)
        assert result.candidates == []
152
+
153
+
154
class TestDiversityFilter:
    def test_keeps_farthest_when_candidates_cluster(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Two near-duplicate candidates + one distant one → the
        distant candidate must survive the k=2 diversity filter."""
        prompt = "seed"
        gold = " gold"
        candidates = ["near1", "near2", "far"]
        # near1/near2 collapse onto the same embedding while far sits
        # orthogonally; with k=2 the filter has to keep far plus
        # exactly one of the two nears (never both nears).
        planted = {
            prompt: [1.0, 0.0, 0.0],
            "near1": [0.9, 0.1, 0.0],
            "near2": [0.9, 0.1, 0.0],
            "far": [0.0, 0.0, 1.0],
        }
        for text, vec in planted.items():
            monkeyed_embed[text] = np.array(vec, dtype=np.float32)

        texts = [prompt, *candidates]
        base_logprobs = {(t, gold): -3.0 for t in texts}
        ft_logprobs = {
            (prompt, gold): -1.0,
            ("near1", gold): -2.0,
            ("near2", gold): -2.0,
            ("far", gold): -1.0,
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=2,
            seed=0,
        )
        survivors = {c.prompt for c in result.candidates}
        assert "far" in survivors
        assert len(survivors & {"near1", "near2"}) == 1
192
+
193
+
194
class TestInputValidation:
    """Argument validation in ``mine_paraphrases``: bad k/n combos raise."""

    @staticmethod
    def _mine(*, n_candidates: int, top_k: int) -> None:
        # Minimal invocation shared by both tests — validation is
        # expected to fire before the backend or generator matter.
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates(["a"]),
            n_candidates=n_candidates,
            top_k=top_k,
        )

    def test_rejects_top_k_zero(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed  # present only for its monkeypatching side effect
        with pytest.raises(ValueError, match="top_k must be positive"):
            self._mine(n_candidates=5, top_k=0)

    def test_rejects_n_candidates_below_top_k(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed  # present only for its monkeypatching side effect
        with pytest.raises(ValueError, match="must be ≥ top_k"):
            self._mine(n_candidates=2, top_k=5)
220
+
221
+
222
class TestParaphraseCandidate:
    def test_is_frozen_dataclass(self) -> None:
        """Field assignment on a constructed candidate must raise."""
        candidate = ParaphraseCandidate(
            prompt="p", gap=0.5, verbatim_lift=1.0, paraphrase_lift=0.5, diversity_rank=0
        )
        # Only "some exception on mutation" is asserted, since the
        # concrete FrozenInstanceError type varies.
        with pytest.raises(Exception):  # noqa: B017, PT011 — FrozenInstanceError varies
            candidate.gap = 0.0  # type: ignore[misc]