tenseleyflow/sway / cec8c23

Browse files

tests: sway mine CLI smoke — paraphrase + outliers modes (S17.5)

Authored by espadonne
SHA
cec8c2308e88916e9f864149078cb8315c647c11
Parents
28d3673
Tree
2529ab1

1 changed file

Status | File | + | -
A tests/unit/test_cli_mine.py 234 0
tests/unit/test_cli_mine.py (added)
@@ -0,0 +1,234 @@
1
+"""Smoke tests for ``sway mine`` — S17 CLI surface.
2
+
3
+Follows the test_sway_gate_exit_code pattern: stub ``backends.build``
4
+with a dummy-returning factory so the CLI runs without loading a real
5
+HF model. The paraphrase generator is also stubbed (via monkeypatch
6
+on ``dlm_sway.mining.paraphrase_miner.nlpaug_candidates``) so tests
7
+don't need the nlpaug wheel.
8
+"""
9
+
10
+from __future__ import annotations
11
+
12
+from pathlib import Path
13
+from typing import Any
14
+
15
+import numpy as np
16
+import pytest
17
+import yaml
18
+from typer.testing import CliRunner
19
+
20
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
21
+from dlm_sway.cli.app import app
22
+from dlm_sway.core.scoring import TokenDist
23
+
24
+
25
def _programmable_backend() -> DummyDifferentialBackend:
    """Build the canned dummy backend shared by both ``sway mine`` modes.

    Token distributions are registered for prompts ``p1``..``p4`` and
    logprobs for the seed prompt plus candidates ``C1``..``C3`` scored
    against the completion ``" gold"``, so each mode has deterministic
    numbers to rank.
    """

    def _dist(probs: list[float]) -> TokenDist:
        # Five explicit token ids inside a declared vocabulary of 1000.
        return TokenDist(
            token_ids=np.arange(5, dtype=np.int64),
            logprobs=np.log(np.array(probs, dtype=np.float32)),
            vocab_size=1000,
        )

    # Base side is sharply peaked on token 0; fine-tuned side is nearly flat.
    peaked = _dist([0.92, 0.02, 0.02, 0.02, 0.02])
    spread = _dist([0.25, 0.20, 0.20, 0.20, 0.15])

    kl_prompts = ["p1", "p2", "p3", "p4"]
    scored_texts = ("seed prompt", "C1", "C2", "C3")

    base_side = DummyResponses(
        token_dists=dict.fromkeys(kl_prompts, peaked),
        # Base assigns the same logprob to the gold completion everywhere.
        logprobs={(text, " gold"): -3.0 for text in scored_texts},
    )
    ft_side = DummyResponses(
        token_dists=dict.fromkeys(kl_prompts, spread),
        logprobs={
            ("seed prompt", " gold"): -1.0,  # +2 over base
            ("C1", " gold"): -3.0,  # unchanged from base -> largest gap vs. seed
            ("C2", " gold"): -2.0,  # +1 over base -> medium gap
            ("C3", " gold"): -1.0,  # +2 over base -> same lift as the seed
        },
    )
    return DummyDifferentialBackend(base=base_side, ft=ft_side)
58
+
59
+
60
@pytest.fixture
def stub_build_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``dlm_sway.backends.build`` to hand back the dummy backend.

    With this fixture active the CLI never attempts to load an HF model;
    whatever arguments it passes to ``build`` are accepted and ignored.
    """
    import dlm_sway.backends as backends_mod

    def _fake_build(*_ignored: object, **_ignored_kw: object) -> DummyDifferentialBackend:
        # A fresh programmable backend per call, regardless of arguments.
        return _programmable_backend()

    monkeypatch.setattr(backends_mod, "build", _fake_build)
70
+
71
+
72
@pytest.fixture
def stub_embedder(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the paraphrase miner's embedder loader with a lookup table.

    Lets paraphrase mining run without sentence-transformers installed:
    each known text maps to a fixed 3-d float32 vector, and encoding a text
    that is not in the table raises ``KeyError``.
    """
    raw_vectors: dict[str, list[float]] = {
        "seed prompt": [1.0, 0.0, 0.0],
        "C1": [0.0, 1.0, 0.0],
        "C2": [0.0, 0.0, 1.0],
        "C3": [0.5, 0.5, 0.0],
    }
    embeddings: dict[str, np.ndarray] = {
        text: np.array(vec, dtype=np.float32) for text, vec in raw_vectors.items()
    }

    def _encode(texts: list[str]) -> np.ndarray:
        # One row per input text, in input order.
        rows = [embeddings[text] for text in texts]
        return np.stack(rows)

    def _fake_loader(_model_id: object):  # type: ignore[no-untyped-def]
        # The requested model id is irrelevant; always return the table encoder.
        return _encode

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner._load_embedder",
        _fake_loader,
    )
90
+
91
+
92
@pytest.fixture
def stub_nlpaug(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace the nlpaug candidate generator with a fixed three-item list.

    Lets paraphrase mining run without the nlpaug wheel. The stub keeps the
    ``(prompt, *, n, seed)`` call shape but ignores all three arguments.
    """
    canned = ("C1", "C2", "C3")

    def _fixed_candidates(prompt: str, *, n: int, seed: int) -> list[str]:
        del prompt, n, seed
        # Fresh list on every call so callers can mutate it safely.
        return list(canned)

    monkeypatch.setattr(
        "dlm_sway.mining.paraphrase_miner.nlpaug_candidates",
        _fixed_candidates,
    )
102
+
103
+
104
def _write_paraphrase_spec(path: Path) -> None:
    """Write a minimal sway spec with one ``paraphrase_invariance`` probe.

    The spec declares stub ``base``/``ft`` models and a single case whose
    prompt is ``seed prompt`` and whose gold completion is ``" gold"``.

    Args:
        path: Destination file; created or overwritten.
    """
    spec_text = """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: pi
    kind: paraphrase_invariance
    cases:
      - prompt: seed prompt
        gold: " gold"
        paraphrases: ["x"]
""".strip()
    # Explicit encoding: the default is locale-dependent, and YAML is
    # expected to be UTF-8 regardless of the host's locale.
    path.write_text(spec_text, encoding="utf-8")
120
+
121
+
122
def _write_delta_kl_spec(path: Path) -> None:
    """Write a minimal sway spec with one ``delta_kl`` probe.

    The spec declares stub ``base``/``ft`` models and a probe over prompts
    ``p1``..``p4`` with ``assert_mean_gte: 0.0``.

    Args:
        path: Destination file; created or overwritten.
    """
    spec_text = """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: dk
    kind: delta_kl
    prompts: [p1, p2, p3, p4]
    assert_mean_gte: 0.0
""".strip()
    # Explicit encoding: the default is locale-dependent, and YAML is
    # expected to be UTF-8 regardless of the host's locale.
    path.write_text(spec_text, encoding="utf-8")
136
+
137
+
138
class TestMineParaphrase:
    """Smoke coverage for ``sway mine --mode paraphrase``."""

    def test_emits_yaml_fragment_with_mined_cases(
        self,
        stub_build_backend: None,  # noqa: ARG002
        stub_embedder: None,  # noqa: ARG002
        stub_nlpaug: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """The CLI exits 0 and writes a ``mined_cases`` YAML fragment.

        The single mined case must echo the seed prompt and gold, list
        ``C1`` as its first paraphrase, and carry a ``_mining_meta`` entry.
        """
        spec_file = tmp_path / "sway.yaml"
        out_file = tmp_path / "mined.yaml"
        _write_paraphrase_spec(spec_file)

        argv = ["mine", str(spec_file)]
        argv += ["--mode", "paraphrase"]
        argv += ["--out", str(out_file)]
        argv += ["--n-candidates", "3"]
        argv += ["--top-k", "3"]
        result = CliRunner().invoke(app, argv)

        assert result.exit_code == 0, result.stdout
        assert out_file.exists()

        payload: dict[str, Any] = yaml.safe_load(out_file.read_text())
        assert "mined_cases" in payload
        mined = payload["mined_cases"]
        assert len(mined) == 1

        first = mined[0]
        assert first["prompt"] == "seed prompt"
        assert first["gold"] == " gold"
        # Hardest-first ordering: the stub backend gives C1 no lift at all,
        # so it has the largest gap and must lead the list.
        assert first["paraphrases"][0] == "C1"
        assert "_mining_meta" in first
177
+
178
+
179
class TestMineOutliers:
    """Smoke coverage for ``sway mine --mode outliers``."""

    def test_emits_top_and_bottom(
        self,
        stub_build_backend: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """The CLI exits 0 and writes a ``mined_outliers`` rollup with
        ``top`` and ``bottom`` lists of three entries each."""
        spec_file = tmp_path / "sway.yaml"
        out_file = tmp_path / "outliers.yaml"
        _write_delta_kl_spec(spec_file)

        argv = ["mine", str(spec_file), "--mode", "outliers"]
        argv += ["--out", str(out_file), "--top-k", "3"]
        result = CliRunner().invoke(app, argv)

        assert result.exit_code == 0, result.stdout
        assert out_file.exists()

        payload: dict[str, Any] = yaml.safe_load(out_file.read_text())
        assert "mined_outliers" in payload
        summary = payload["mined_outliers"]
        assert summary["probe_kind"] == "delta_kl"
        assert isinstance(summary["top"], list)
        assert isinstance(summary["bottom"], list)
        # The spec lists four delta_kl prompts; --top-k 3 caps each side.
        assert len(summary["top"]) == 3
        assert len(summary["bottom"]) == 3

    def test_no_prompts_and_no_corpus_errors(
        self,
        stub_build_backend: None,  # noqa: ARG002
        tmp_path: Path,
    ) -> None:
        """With no delta_kl prompts in the spec and no --from-corpus
        flag, the command must exit non-zero (code 2) and must not write
        an output file."""
        # Only a paraphrase_invariance probe here, so outliers mode has
        # nothing to score.
        spec_body = """
version: 1
models:
  base: {base: stub, kind: hf, adapter: /tmp/stub}
  ft: {base: stub, kind: hf, adapter: /tmp/stub}
suite:
  - name: pi
    kind: paraphrase_invariance
    cases:
      - prompt: x
        gold: " y"
        paraphrases: ["z"]
""".strip()
        spec_file = tmp_path / "sway.yaml"
        spec_file.write_text(spec_body)
        out_file = tmp_path / "outliers.yaml"

        argv = ["mine", str(spec_file), "--mode", "outliers", "--out", str(out_file)]
        result = CliRunner().invoke(app, argv)

        assert result.exit_code == 2
        assert not out_file.exists()