"""Smoke tests for ``sway mine`` — S17 CLI surface.

Follows the test_sway_gate_exit_code pattern: stub ``backends.build``
with a dummy-returning factory so the CLI runs without loading a real
HF model. The paraphrase generator is also stubbed (via monkeypatch
on ``dlm_sway.mining.paraphrase_miner.nlpaug_candidates``), as is the
MiniLM embedder loader (``_load_embedder``), so tests need neither the
nlpaug wheel nor sentence-transformers.
"""
9
10 from __future__ import annotations
11
12 from pathlib import Path
13 from typing import Any
14
15 import numpy as np
16 import pytest
17 import yaml
18 from typer.testing import CliRunner
19
20 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
21 from dlm_sway.cli.app import app
22 from dlm_sway.core.scoring import TokenDist
23
24
def _programmable_backend() -> DummyDifferentialBackend:
    """Build a dummy differential backend with fixed, predictable outputs.

    Both ``sway mine`` modes get something to chew on:

    * ``token_dists`` (keyed by prompt) — every delta_kl prompt maps to the
      same peaked base distribution and the same flatter fine-tuned
      distribution, so base and ft differ on every prompt.
    * ``logprobs`` (keyed by ``(prompt, continuation)``) — the paraphrase
      seed and the stubbed candidates ``C1``–``C3`` get hand-picked base/ft
      scores for the ``" gold"`` continuation. This fixes the lift gaps in
      advance: C1 largest, C2 medium, C3 zero.
    """
    base_dist = TokenDist(
        token_ids=np.arange(5, dtype=np.int64),
        logprobs=np.log(np.array([0.92, 0.02, 0.02, 0.02, 0.02], dtype=np.float32)),
        vocab_size=1000,
    )
    ft_dist = TokenDist(
        token_ids=np.arange(5, dtype=np.int64),
        logprobs=np.log(np.array([0.25, 0.20, 0.20, 0.20, 0.15], dtype=np.float32)),
        vocab_size=1000,
    )
    # Must cover every prompt listed in the delta_kl spec written by
    # _write_delta_kl_spec (p1..p8), so no spec prompt depends on how the
    # dummy backend handles an unprogrammed key.
    prompts = [f"p{i}" for i in range(1, 9)]
    base_token_dists = dict.fromkeys(prompts, base_dist)
    ft_token_dists = dict.fromkeys(prompts, ft_dist)
    # Logprobs for the paraphrase case scoring.
    base_lp = {
        ("seed prompt", " gold"): -3.0,
        ("C1", " gold"): -3.0,
        ("C2", " gold"): -3.0,
        ("C3", " gold"): -3.0,
    }
    ft_lp = {
        ("seed prompt", " gold"): -1.0,  # lift +2
        ("C1", " gold"): -3.0,  # no lift → big gap
        ("C2", " gold"): -2.0,  # partial lift → medium gap
        ("C3", " gold"): -1.0,  # full lift → zero gap
    }
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_token_dists, logprobs=base_lp),
        ft=DummyResponses(token_dists=ft_token_dists, logprobs=ft_lp),
    )
58
59
@pytest.fixture
def stub_build_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap ``dlm_sway.backends.build`` for a dummy-backend factory.

    The replacement ignores every argument and returns a fresh
    ``_programmable_backend()`` on each call, so the CLI never tries to
    load an HF model.
    """
    import dlm_sway.backends as backends_mod

    def _fake_build(*_args: object, **_kwargs: object) -> DummyDifferentialBackend:
        return _programmable_backend()

    monkeypatch.setattr(backends_mod, "build", _fake_build)
70
71
@pytest.fixture
def stub_embedder(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the MiniLM embedder loader with a fixed lookup table.

    ``_load_embedder`` is patched to ignore its model id and return an
    encoder that stacks hand-picked 3-d vectors, one row per input text.
    Paraphrase mining then runs without sentence-transformers installed.
    """
    vectors: dict[str, np.ndarray] = {
        "seed prompt": np.array([1.0, 0.0, 0.0], dtype=np.float32),
        "C1": np.array([0.0, 1.0, 0.0], dtype=np.float32),
        "C2": np.array([0.0, 0.0, 1.0], dtype=np.float32),
        "C3": np.array([0.5, 0.5, 0.0], dtype=np.float32),
    }

    def _fake_encode(texts: list[str]) -> np.ndarray:
        rows = [vectors[text] for text in texts]
        return np.stack(rows)

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        lambda _model_id: _fake_encode,  # type: ignore[arg-type]
    )
90
91
@pytest.fixture
def stub_nlpaug(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``nlpaug_candidates`` with a generator that always yields C1–C3.

    The arguments are accepted so the signature matches the real generator,
    and then ignored. Every call gets the same three candidates, and the
    nlpaug wheel is never needed.
    """
    fixed_candidates = ("C1", "C2", "C3")

    def _fake_candidates(prompt: str, *, n: int, seed: int) -> list[str]:
        del prompt, n, seed  # inputs are irrelevant to the stub
        return list(fixed_candidates)

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner.nlpaug_candidates", _fake_candidates
    )
102
103
def _write_paraphrase_spec(path: Path) -> None:
    """Write a minimal sway spec with one paraphrase_invariance probe to *path*.

    The probe's single case uses the ``"seed prompt"`` / ``" gold"`` pair
    that ``_programmable_backend`` has logprobs for. It has one
    hand-written paraphrase, ``"x"``.
    """
    spec_text = """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: pi
    kind: paraphrase_invariance
    cases:
      - prompt: seed prompt
        gold: " gold"
        paraphrases: ["x"]
""".strip()
    path.write_text(spec_text)
120
121
def _write_delta_kl_spec(path: Path) -> None:
    """Write a minimal sway spec with one delta_kl probe to *path*.

    The probe lists prompts ``p1`` … ``p8`` and sets ``assert_mean_gte: 0.0``.
    """
    spec_text = """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: dk
    kind: delta_kl
    prompts: [p1, p2, p3, p4, p5, p6, p7, p8]
    assert_mean_gte: 0.0
""".strip()
    path.write_text(spec_text)
136
137
class TestMineParaphrase:
    """``sway mine --mode paraphrase`` end-to-end with backend, embedder and
    candidate generator all stubbed."""

    def test_emits_yaml_fragment_with_mined_cases(
        self,
        stub_build_backend: None,  # noqa: ARG002
        stub_embedder: None,  # noqa: ARG002
        stub_nlpaug: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """The miner writes a ``mined_cases`` fragment for the spec's single
        case, with candidate paraphrases ranked hardest-first."""
        spec = tmp_path / "sway.yaml"
        _write_paraphrase_spec(spec)
        out = tmp_path / "mined.yaml"

        result = CliRunner().invoke(
            app,
            [
                "mine",
                str(spec),
                "--mode",
                "paraphrase",
                "--out",
                str(out),
                "--n-candidates",
                "3",
                "--top-k",
                "3",
            ],
        )
        # invoke() traps exceptions by default, and ``stdout`` alone may omit
        # stderr. Report the combined output plus any trapped exception so a
        # failure is diagnosable.
        assert result.exit_code == 0, f"{result.output}\n{result.exception!r}"
        assert out.exists()
        payload: dict[str, Any] = yaml.safe_load(out.read_text())
        assert "mined_cases" in payload
        cases = payload["mined_cases"]
        assert len(cases) == 1
        case = cases[0]
        assert case["prompt"] == "seed prompt"
        assert case["gold"] == " gold"
        # Paraphrases are ranked hardest-first; C1 had the largest gap.
        assert case["paraphrases"][0] == "C1"
        assert "_mining_meta" in case
177
178
class TestMineOutliers:
    """``sway mine --mode outliers`` against the stubbed backend."""

    def test_emits_top_and_bottom(
        self,
        stub_build_backend: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """The miner writes a ``mined_outliers`` rollup for the delta_kl probe,
        with ``top``/``bottom`` lists clipped to ``--top-k``."""
        spec = tmp_path / "sway.yaml"
        _write_delta_kl_spec(spec)
        out = tmp_path / "outliers.yaml"

        result = CliRunner().invoke(
            app,
            ["mine", str(spec), "--mode", "outliers", "--out", str(out), "--top-k", "3"],
        )
        # invoke() traps exceptions by default, and ``stdout`` alone may omit
        # stderr. Report the combined output plus any trapped exception.
        assert result.exit_code == 0, f"{result.output}\n{result.exception!r}"
        assert out.exists()
        payload: dict[str, Any] = yaml.safe_load(out.read_text())
        assert "mined_outliers" in payload
        rollup = payload["mined_outliers"]
        assert rollup["probe_kind"] == "delta_kl"
        assert isinstance(rollup["top"], list)
        assert isinstance(rollup["bottom"], list)
        # The spec lists eight delta_kl prompts and --top-k is 3, so each
        # side of the rollup is expected to hold exactly three entries.
        assert len(rollup["top"]) == 3
        assert len(rollup["bottom"]) == 3

    def test_no_prompts_and_no_corpus_errors(
        self,
        stub_build_backend: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """A spec with no delta_kl prompts and no ``--from-corpus`` flag
        must exit with code 2 and a pointed message rather than write an
        empty YAML."""
        spec = tmp_path / "sway.yaml"
        spec.write_text(
            """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: pi
    kind: paraphrase_invariance
    cases:
      - prompt: x
        gold: " y"
        paraphrases: ["z"]
""".strip()
        )
        out = tmp_path / "outliers.yaml"
        result = CliRunner().invoke(
            app, ["mine", str(spec), "--mode", "outliers", "--out", str(out)]
        )
        assert result.exit_code == 2, f"{result.output}\n{result.exception!r}"
        assert not out.exists()