| 1 | """Tests for :mod:`dlm_sway.mining.paraphrase_miner`. |
| 2 | |
| 3 | We stub the embedder + the nlpaug generator so tests run without the |
| 4 | 80-MB MiniLM load or the nlpaug wheel. What's under test is the ranker |
| 5 | and the diversity filter — the infrastructure around them (the |
| 6 | embedder, the generator) is exercised by the integration test. |
| 7 | """ |
| 8 | |
| 9 | from __future__ import annotations |
| 10 | |
| 11 | import numpy as np |
| 12 | import pytest |
| 13 | |
| 14 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 15 | from dlm_sway.mining.paraphrase_miner import ( |
| 16 | MiningResult, |
| 17 | ParaphraseCandidate, |
| 18 | mine_paraphrases, |
| 19 | ) |
| 20 | |
| 21 | |
| 22 | def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def] |
| 23 | def _encode(texts: list[str]): # type: ignore[no-untyped-def] |
| 24 | return np.stack([text_to_vec[t] for t in texts]) |
| 25 | |
| 26 | return _encode |
| 27 | |
| 28 | |
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Patch the miner's embedder loader with a table-backed stub.

    Returns the (initially empty) text → vector table. Tests populate it
    with planted embeddings before calling ``mine_paraphrases``; any text
    the miner embeds that is missing from the table raises ``KeyError``.
    """
    table: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        lambda _model_id: _stub_embedder(table),  # type: ignore[arg-type]
    )
    return table
| 37 | |
| 38 | |
| 39 | def _canned_candidates(candidates: list[str]): # type: ignore[no-untyped-def] |
| 40 | """Build a generator closure that returns ``candidates`` verbatim.""" |
| 41 | |
| 42 | def _gen(_prompt: str, *, n: int, seed: int) -> list[str]: |
| 43 | del n, seed |
| 44 | return list(candidates) |
| 45 | |
| 46 | return _gen |
| 47 | |
| 48 | |
class TestRanker:
    def test_gap_sort_ranks_hardest_first(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """The paraphrase with the largest verbatim-vs-paraphrase lift
        gap ranks first. Planted: three candidates with known
        ``(base, ft)`` logprob pairs; the ranker must surface them in
        gap-descending order regardless of the generator's input
        order."""
        # Seed + candidates
        prompt = "Q: what is the capital?"
        gold = " Paris"
        candidates = ["Q1", "Q2", "Q3"]
        # Embed all four distinctly so the diversity filter keeps
        # every candidate.
        monkeyed_embed[prompt] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q1"] = np.array([0.0, 1.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q2"] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
        monkeyed_embed["Q3"] = np.array([0.5, 0.5, 0.0], dtype=np.float32)

        # Logprob table — verbatim lift is large; paraphrase lifts
        # vary by design: Q1 lift ≈ 0 (big gap), Q2 lift ≈ verbatim
        # (no gap), Q3 lift somewhere in between.
        base_logprobs = {
            (prompt, gold): -4.0,
            ("Q1", gold): -4.0,  # base unchanged
            ("Q2", gold): -4.0,
            ("Q3", gold): -4.0,
        }
        ft_logprobs = {
            (prompt, gold): -1.0,  # verbatim lift: +3.0
            ("Q1", gold): -4.0,  # paraphrase lift: 0 (big gap of 3)
            ("Q2", gold): -1.0,  # paraphrase lift: +3 (gap of 0)
            ("Q3", gold): -2.5,  # paraphrase lift: +1.5 (gap of 1.5)
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )

        assert isinstance(result, MiningResult)
        assert [c.prompt for c in result.candidates] == ["Q1", "Q3", "Q2"]
        # Gaps are in descending order.
        gaps = [c.gap for c in result.candidates]
        assert gaps == sorted(gaps, reverse=True)
        # Top candidate's gap is meaningfully larger than 0.
        assert result.candidates[0].gap > 0.1

    def test_dedup_drops_verbatim_seed(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """If the generator echoes the seed prompt back in its output,
        it's dropped before the ranker sees it — otherwise the seed
        would bubble to the top with gap = 0 and pollute the list."""
        prompt = "prompt"
        gold = " gold"
        # Generator returns the seed + two real candidates; the seed
        # must not appear in the final result.
        candidates = ["prompt", "C1", "C2"]
        # Deterministic per-text embeddings: draw from one seeded RNG in
        # fixed order. Seeding from ``hash(p)`` (as before) is
        # nondeterministic across interpreter runs because str hashing is
        # salted via PYTHONHASHSEED, which made the planted embeddings —
        # and hence the diversity filter's input — vary run to run.
        rng = np.random.RandomState(0)
        for p in candidates + [prompt]:
            monkeyed_embed[p] = rng.randn(4).astype(np.float32)

        base_logprobs = {(p, gold): -3.0 for p in [prompt, "C1", "C2"]}
        ft_logprobs = {(prompt, gold): -1.0, ("C1", gold): -2.5, ("C2", gold): -2.0}
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )
        assert prompt not in {c.prompt for c in result.candidates}

    def test_empty_candidate_list_returns_empty_result(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """An empty generator output yields an empty (but well-typed)
        result rather than raising."""
        del monkeyed_embed
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        result = mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates([]),
            n_candidates=3,
            top_k=3,
        )
        assert isinstance(result, MiningResult)
        assert result.candidates == []
| 152 | |
| 153 | |
class TestDiversityFilter:
    def test_keeps_farthest_when_candidates_cluster(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """With two clustered candidates and one outlier, the k=2
        diversity filter must retain the outlier plus exactly one
        member of the cluster — never both near-duplicates."""
        prompt = "seed"
        gold = " gold"
        candidates = ["near1", "near2", "far"]
        # near1/near2 share an embedding; far is orthogonal to both.
        planted = {
            prompt: [1.0, 0.0, 0.0],
            "near1": [0.9, 0.1, 0.0],
            "near2": [0.9, 0.1, 0.0],
            "far": [0.0, 0.0, 1.0],
        }
        for text, vec in planted.items():
            monkeyed_embed[text] = np.array(vec, dtype=np.float32)

        base_logprobs = {(c, gold): -3.0 for c in candidates}
        base_logprobs[(prompt, gold)] = -3.0
        ft_logprobs = {
            (prompt, gold): -1.0,
            ("near1", gold): -2.0,
            ("near2", gold): -2.0,
            ("far", gold): -1.0,
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=2,
            seed=0,
        )
        picked = {c.prompt for c in result.candidates}
        assert "far" in picked
        assert len(picked & {"near1", "near2"}) == 1
| 192 | |
| 193 | |
class TestInputValidation:
    def test_rejects_top_k_zero(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """A non-positive ``top_k`` is rejected before any mining work."""
        del monkeyed_embed
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        with pytest.raises(ValueError, match="top_k must be positive"):
            mine_paraphrases(
                prompt="x",
                gold=" y",
                backend=backend,
                generate_candidates=_canned_candidates(["a"]),
                n_candidates=5,
                top_k=0,
            )

    def test_rejects_n_candidates_below_top_k(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """Asking for more survivors than candidates is a ValueError."""
        del monkeyed_embed
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        with pytest.raises(ValueError, match="must be ≥ top_k"):
            mine_paraphrases(
                prompt="x",
                gold=" y",
                backend=backend,
                generate_candidates=_canned_candidates(["a"]),
                n_candidates=2,
                top_k=5,
            )
| 220 | |
| 221 | |
class TestParaphraseCandidate:
    def test_is_frozen_dataclass(self) -> None:
        """Field assignment on a constructed candidate must raise."""
        candidate = ParaphraseCandidate(
            prompt="p",
            gap=0.5,
            verbatim_lift=1.0,
            paraphrase_lift=0.5,
            diversity_rank=0,
        )
        with pytest.raises(Exception):  # noqa: B017, PT011 — FrozenInstanceError varies
            candidate.gap = 0.0  # type: ignore[misc]