@@ -0,0 +1,228 @@ |
| 1 | +"""Tests for :mod:`dlm_sway.mining.paraphrase_miner`. |
| 2 | + |
| 3 | +We stub the embedder + the nlpaug generator so tests run without the |
| 4 | +80-MB MiniLM load or the nlpaug wheel. What's under test is the ranker |
| 5 | +and the diversity filter — the infrastructure around them (the |
| 6 | +embedder, the generator) is exercised by the integration test. |
| 7 | +""" |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pytest |
| 13 | + |
| 14 | +from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 15 | +from dlm_sway.mining.paraphrase_miner import ( |
| 16 | + MiningResult, |
| 17 | + ParaphraseCandidate, |
| 18 | + mine_paraphrases, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def] |
| 23 | + def _encode(texts: list[str]): # type: ignore[no-untyped-def] |
| 24 | + return np.stack([text_to_vec[t] for t in texts]) |
| 25 | + |
| 26 | + return _encode |
| 27 | + |
| 28 | + |
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Patch the miner's embedder loader with a lookup-table stub.

    Returns the (initially empty) text -> vector table; tests populate
    it with whatever embeddings they need before calling the miner.
    """
    vectors: dict[str, np.ndarray] = {}

    def _fake_loader(_model_id):  # type: ignore[no-untyped-def]
        return _stub_embedder(vectors)

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        _fake_loader,  # type: ignore[arg-type]
    )
    return vectors
| 37 | + |
| 38 | + |
| 39 | +def _canned_candidates(candidates: list[str]): # type: ignore[no-untyped-def] |
| 40 | + """Build a generator closure that returns ``candidates`` verbatim.""" |
| 41 | + |
| 42 | + def _gen(_prompt: str, *, n: int, seed: int) -> list[str]: |
| 43 | + del n, seed |
| 44 | + return list(candidates) |
| 45 | + |
| 46 | + return _gen |
| 47 | + |
| 48 | + |
class TestRanker:
    """Ranking behaviour of :func:`mine_paraphrases`."""

    def test_gap_sort_ranks_hardest_first(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """The paraphrase with the largest verbatim-vs-paraphrase lift
        gap ranks first. Planted: three candidates with known
        ``(base, ft)`` logprob pairs; the ranker must surface them in
        gap-descending order regardless of the generator's input
        order."""
        # Seed + candidates
        prompt = "Q: what is the capital?"
        gold = " Paris"
        candidates = ["Q1", "Q2", "Q3"]
        # Embed all four distinctly so the diversity filter keeps
        # every candidate.
        monkeyed_embed[prompt] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q1"] = np.array([0.0, 1.0, 0.0], dtype=np.float32)
        monkeyed_embed["Q2"] = np.array([0.0, 0.0, 1.0], dtype=np.float32)
        monkeyed_embed["Q3"] = np.array([0.5, 0.5, 0.0], dtype=np.float32)

        # Logprob table — verbatim lift is large; paraphrase lifts
        # vary by design: Q1 lift ≈ 0 (big gap), Q2 lift ≈ verbatim
        # (no gap), Q3 lift somewhere in between.
        base_logprobs = {
            (prompt, gold): -4.0,
            ("Q1", gold): -4.0,  # base unchanged
            ("Q2", gold): -4.0,
            ("Q3", gold): -4.0,
        }
        ft_logprobs = {
            (prompt, gold): -1.0,  # verbatim lift: +3.0
            ("Q1", gold): -4.0,  # paraphrase lift: 0 (big gap of 3)
            ("Q2", gold): -1.0,  # paraphrase lift: +3 (gap of 0)
            ("Q3", gold): -2.5,  # paraphrase lift: +1.5 (gap of 1.5)
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )

        assert isinstance(result, MiningResult)
        assert [c.prompt for c in result.candidates] == ["Q1", "Q3", "Q2"]
        # Gaps are in descending order.
        gaps = [c.gap for c in result.candidates]
        assert gaps == sorted(gaps, reverse=True)
        # Top candidate's gap is meaningfully larger than 0.
        assert result.candidates[0].gap > 0.1

    def test_dedup_drops_verbatim_seed(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """If the generator echoes the seed prompt back in its output,
        it's dropped before the ranker sees it — otherwise the seed
        would bubble to the top with gap = 0 and pollute the list."""
        prompt = "prompt"
        gold = " gold"
        # Generator returns the seed + two real candidates; the seed
        # must not appear in the final result.
        candidates = ["prompt", "C1", "C2"]
        # Seed each RandomState from the enumerate index, NOT from
        # hash(p): Python's str hash is randomised per process
        # (PYTHONHASHSEED), which made the planted embeddings — and
        # thus the test — nondeterministic across runs.
        for idx, p in enumerate(candidates + [prompt]):
            monkeyed_embed[p] = np.random.RandomState(idx).randn(4).astype(np.float32)

        base_logprobs = {(p, gold): -3.0 for p in [prompt, "C1", "C2"]}
        ft_logprobs = {(prompt, gold): -1.0, ("C1", gold): -2.5, ("C2", gold): -2.0}
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=3,
            seed=0,
        )
        assert prompt not in {c.prompt for c in result.candidates}

    def test_empty_candidate_list_returns_empty_result(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """A generator yielding nothing produces an empty (not None,
        not raising) result."""
        del monkeyed_embed
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        result = mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates([]),
            n_candidates=3,
            top_k=3,
            seed=0,  # pinned for consistency with every sibling call
        )
        assert isinstance(result, MiningResult)
        assert result.candidates == []
| 152 | + |
| 153 | + |
class TestDiversityFilter:
    def test_keeps_farthest_when_candidates_cluster(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Two near-duplicate candidates + one distant one → the
        distant candidate must survive the k=2 diversity filter."""
        prompt = "seed"
        gold = " gold"
        candidates = ["near1", "near2", "far"]
        # near1 / near2 collapse onto the same embedding; far lives
        # orthogonally. With k=2, the filter must pick one near + far
        # (not both nears).
        planted_embeddings = {
            prompt: [1.0, 0.0, 0.0],
            "near1": [0.9, 0.1, 0.0],
            "near2": [0.9, 0.1, 0.0],
            "far": [0.0, 0.0, 1.0],
        }
        for text, vec in planted_embeddings.items():
            monkeyed_embed[text] = np.array(vec, dtype=np.float32)

        base_logprobs = {(t, gold): -3.0 for t in [prompt, *candidates]}
        ft_logprobs = {
            (prompt, gold): -1.0,
            ("near1", gold): -2.0,
            ("near2", gold): -2.0,
            ("far", gold): -1.0,
        }
        backend = DummyDifferentialBackend(
            base=DummyResponses(logprobs=base_logprobs),
            ft=DummyResponses(logprobs=ft_logprobs),
        )

        result = mine_paraphrases(
            prompt=prompt,
            gold=gold,
            backend=backend,
            generate_candidates=_canned_candidates(candidates),
            n_candidates=3,
            top_k=2,
            seed=0,
        )
        surviving = {c.prompt for c in result.candidates}
        assert "far" in surviving
        assert len(surviving & {"near1", "near2"}) == 1
| 192 | + |
| 193 | + |
class TestInputValidation:
    """Argument validation in :func:`mine_paraphrases`."""

    @staticmethod
    def _invoke(*, n_candidates: int, top_k: int) -> None:
        # Shared driver: only the two knobs under test vary.
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        mine_paraphrases(
            prompt="x",
            gold=" y",
            backend=backend,
            generate_candidates=_canned_candidates(["a"]),
            n_candidates=n_candidates,
            top_k=top_k,
        )

    def test_rejects_top_k_zero(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed
        with pytest.raises(ValueError, match="top_k must be positive"):
            self._invoke(n_candidates=5, top_k=0)

    def test_rejects_n_candidates_below_top_k(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed
        with pytest.raises(ValueError, match="must be ≥ top_k"):
            self._invoke(n_candidates=2, top_k=5)
| 220 | + |
| 221 | + |
class TestParaphraseCandidate:
    def test_is_frozen_dataclass(self) -> None:
        """Assigning to a field after construction must raise."""
        candidate = ParaphraseCandidate(
            prompt="p",
            gap=0.5,
            verbatim_lift=1.0,
            paraphrase_lift=0.5,
            diversity_rank=0,
        )
        # Deliberately broad: the concrete FrozenInstanceError type
        # varies between frozen-class implementations.
        with pytest.raises(Exception):  # noqa: B017, PT011
            candidate.gap = 0.0  # type: ignore[misc]