1 """Tests for :mod:`dlm_sway.mining.paraphrase_miner`.
2
3 We stub the embedder + the nlpaug generator so tests run without the
4 80-MB MiniLM load or the nlpaug wheel. What's under test is the ranker
5 and the diversity filter — the infrastructure around them (the
6 embedder, the generator) is exercised by the integration test.
7 """
8
from __future__ import annotations

import zlib

import numpy as np
import pytest

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
from dlm_sway.mining.paraphrase_miner import (
    MiningResult,
    ParaphraseCandidate,
    mine_paraphrases,
)
20
21
22 def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def]
23 def _encode(texts: list[str]): # type: ignore[no-untyped-def]
24 return np.stack([text_to_vec[t] for t in texts])
25
26 return _encode
27
28
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Patch the miner's embedder loader with a lookup-table stub.

    Returns the (initially empty) table; each test fills in the vectors
    it needs before invoking ``mine_paraphrases``.
    """
    vectors: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        lambda _model_id: _stub_embedder(vectors),  # type: ignore[arg-type]
    )
    return vectors
37
38
39 def _canned_candidates(candidates: list[str]): # type: ignore[no-untyped-def]
40 """Build a generator closure that returns ``candidates`` verbatim."""
41
42 def _gen(_prompt: str, *, n: int, seed: int) -> list[str]:
43 del n, seed
44 return list(candidates)
45
46 return _gen
47
48
class TestRanker:
    """Ranking, seed-dedup, and empty-input behavior of ``mine_paraphrases``."""

    def test_gap_sort_ranks_hardest_first(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """The paraphrase with the largest verbatim-vs-paraphrase lift
        gap ranks first. Planted: three candidates with known
        ``(base, ft)`` logprob pairs; the ranker must surface them in
        gap-descending order regardless of the generator's input
        order."""
        # Seed + candidates
        prompt = "Q: what is the capital?"
        gold = " Paris"
        candidates = ["Q1", "Q2", "Q3"]
        # Embed all four distinctly so the diversity filter keeps
        # every candidate.
        monkeyed_embed[prompt] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q1"] = np.array([0.0, 1.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q2"] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
        monkeyed_embed["Q3"] = np.array([0.5, 0.5, 0.0], dtype=np.float32)

        # Logprob table — verbatim lift is large; paraphrase lifts
        # vary by design: Q1 lift ≈ 0 (big gap), Q2 lift ≈ verbatim
        # (no gap), Q3 lift somewhere in between.
        base_logprobs = {
            (prompt, gold): -4.0,
            ("Q1", gold): -4.0,  # base unchanged
            ("Q2", gold): -4.0,
            ("Q3", gold): -4.0,
        }
        ft_logprobs = {
            (prompt, gold): -1.0,  # verbatim lift: +3.0
            ("Q1", gold): -4.0,  # paraphrase lift: 0 (big gap of 3)
            ("Q2", gold): -1.0,  # paraphrase lift: +3 (gap of 0)
            ("Q3", gold): -2.5,  # paraphrase lift: +1.5 (gap of 1.5)
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )

        assert isinstance(result, MiningResult)
        assert [c.prompt for c in result.candidates] == ["Q1", "Q3", "Q2"]
        # Gaps are in descending order.
        gaps = [c.gap for c in result.candidates]
        assert gaps == sorted(gaps, reverse=True)
        # Top candidate's gap is meaningfully larger than 0.
        assert result.candidates[0].gap > 0.1

    def test_dedup_drops_verbatim_seed(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """If the generator echoes the seed prompt back in its output,
        it's dropped before the ranker sees it — otherwise the seed
        would bubble to the top with gap = 0 and pollute the list."""
        prompt = "prompt"
        gold = " gold"
        # Generator returns the seed + two real candidates; the seed
        # must not appear in the final result.
        candidates = ["prompt", "C1", "C2"]
        for p in candidates + [prompt]:
            # BUGFIX: previously seeded with hash(p), but str hashing is
            # salted per process (PYTHONHASHSEED), so the planted
            # embeddings changed between runs. zlib.crc32 is stable and
            # fits RandomState's 32-bit seed range.
            monkeyed_embed[p] = (
                np.random.RandomState(zlib.crc32(p.encode("utf-8"))).randn(4).astype(np.float32)
            )

        base_logprobs = {(p, gold): -3.0 for p in [prompt, "C1", "C2"]}
        ft_logprobs = {(prompt, gold): -1.0, ("C1", gold): -2.5, ("C2", gold): -2.0}
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )
        assert prompt not in {c.prompt for c in result.candidates}

    def test_empty_candidate_list_returns_empty_result(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """A generator that yields nothing produces an empty (but
        well-typed) ``MiningResult`` rather than raising."""
        del monkeyed_embed  # fixture needed only for its patching side effect
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        result = mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates([]),
            n_candidates=3,
            top_k=3,
        )
        assert isinstance(result, MiningResult)
        assert result.candidates == []
152
153
class TestDiversityFilter:
    def test_keeps_farthest_when_candidates_cluster(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Two near-duplicate candidates + one distant one → the
        distant candidate must survive the k=2 diversity filter."""
        prompt = "seed"
        gold = " gold"
        candidates = ["near1", "near2", "far"]
        # near1 / near2 collapse onto the same embedding; far lives
        # orthogonally. With k=2, the filter must pick one near + far
        # (not both nears).
        planted_vectors = {
            prompt: [1.0, 0.0, 0.0],
            "near1": [0.9, 0.1, 0.0],
            "near2": [0.9, 0.1, 0.0],
            "far": [0.0, 0.0, 1.0],
        }
        for text, vec in planted_vectors.items():
            monkeyed_embed[text] = np.array(vec, dtype=np.float32)

        base_logprobs = {
            (prompt, gold): -3.0,
            ("near1", gold): -3.0,
            ("near2", gold): -3.0,
            ("far", gold): -3.0,
        }
        ft_logprobs = {
            (prompt, gold): -1.0,
            ("near1", gold): -2.0,
            ("near2", gold): -2.0,
            ("far", gold): -1.0,
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=2,
            seed=0,
        )
        chosen = {c.prompt for c in result.candidates}
        assert "far" in chosen
        assert len(chosen & {"near1", "near2"}) == 1
192
193
class TestInputValidation:
    def test_rejects_top_k_zero(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """``top_k=0`` is rejected up front with a clear ValueError."""
        del monkeyed_embed  # fixture needed only for its patching side effect
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        call_kwargs = dict(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates(["a"]),
            n_candidates=5,
            top_k=0,
        )
        with pytest.raises(ValueError, match="top_k must be positive"):
            mine_paraphrases(**call_kwargs)

    def test_rejects_n_candidates_below_top_k(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """Asking for more survivors than generated candidates is a
        caller error, not something to silently clamp."""
        del monkeyed_embed  # fixture needed only for its patching side effect
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        call_kwargs = dict(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates(["a"]),
            n_candidates=2,
            top_k=5,
        )
        with pytest.raises(ValueError, match="must be ≥ top_k"):
            mine_paraphrases(**call_kwargs)
220
221
class TestParaphraseCandidate:
    def test_is_frozen_dataclass(self) -> None:
        """Candidates are immutable: assigning to a field must raise."""
        candidate = ParaphraseCandidate(
            prompt="p",
            gap=0.5,
            verbatim_lift=1.0,
            paraphrase_lift=0.5,
            diversity_rank=0,
        )
        with pytest.raises(Exception):  # noqa: B017, PT011 — FrozenInstanceError varies
            candidate.gap = 0.0  # type: ignore[misc]