Python · 7124 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.mining.outlier_miner`.
2
3 Uses the dummy backend's synthesized per-prompt divergences: base is
4 sharply peaked, ft is broad, and the dummy's ``next_token_dist`` cache
5 keys on the prompt string — so each candidate prompt produces the same
6 divergence unless we overlay per-prompt TokenDists. Tests that need
7 variation build per-prompt TokenDists directly on ``DummyResponses``.
8 """
9
10 from __future__ import annotations
11
12 import math
13
14 import numpy as np
15 import pytest
16
17 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
18 from dlm_sway.core.scoring import TokenDist
19 from dlm_sway.mining.outlier_miner import (
20 OutlierCandidate,
21 OutlierResult,
22 corpus_prompts,
23 mine_outliers,
24 )
25
26
def _dist_from_probs(probs: list[float]) -> TokenDist:
    """Build a ``TokenDist`` from a list of probability weights.

    The weights are normalized in float64 so they sum to one, then stored
    as float32 log-probabilities. Token ids are ``0 .. len(probs) - 1``,
    the reported vocabulary size is never below 1000, and no tail
    log-probability is attached.
    """
    n_tokens = len(probs)
    weights = np.asarray(probs, dtype=np.float64)
    normalized = weights / weights.sum()
    return TokenDist(
        token_ids=np.arange(n_tokens, dtype=np.int64),
        logprobs=np.log(normalized).astype(np.float32),
        vocab_size=max(1000, n_tokens),
        tail_logprob=None,
    )
37
38
class TestMineOutliers:
    """Exercise ``mine_outliers`` against the dummy differential backend."""

    def test_ranks_prompts_by_per_prompt_divergence(self) -> None:
        """Planted divergences must land in the expected buckets.

        The two ``hi*`` prompts get the widest base/ft gap, the two
        ``lo*`` prompts get none at all, and the ``mid*`` prompts sit in
        between. With ``top_k=2`` the top bucket has to be the ``hi``
        pair and the bottom bucket the ``lo`` pair.
        """
        peaked = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])
        flat = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])  # big KL
        mild = _dist_from_probs([0.70, 0.10, 0.10, 0.05, 0.05])  # mid KL
        unchanged = peaked  # zero KL

        # F04 — the guard needs at least 2·top_k = 4 distinct prompts;
        # six planted prompts clears it.
        planted = [
            ("hi1", flat),
            ("hi2", flat),
            ("mid1", mild),
            ("mid2", mild),
            ("lo1", unchanged),
            ("lo2", unchanged),
        ]
        prompts = [name for name, _ in planted]
        backend = DummyDifferentialBackend(
            base=DummyResponses(token_dists=dict.fromkeys(prompts, peaked)),
            ft=DummyResponses(token_dists=dict(planted)),
        )

        outcome = mine_outliers(
            probe_kind="delta_kl",
            candidate_prompts=prompts,
            backend=backend,
            top_k=2,
        )

        assert isinstance(outcome, OutlierResult)
        assert outcome.probe_kind == "delta_kl"
        # Compared as sets: bucket membership is what is pinned here, not
        # the ordering inside a bucket.
        top_names = {cand.prompt for cand in outcome.top}
        bottom_names = {cand.prompt for cand in outcome.bottom}
        assert top_names == {"hi1", "hi2"}
        assert bottom_names == {"lo1", "lo2"}
        # Every top score has to be a finite, non-negative number.
        top_scores = [cand.raw for cand in outcome.top]
        assert all(math.isfinite(score) for score in top_scores)
        assert all(score >= 0.0 for score in top_scores)

    def test_small_pool_raises_f04_guard(self) -> None:
        """F04 (Audit 03) — a pool that is too small is an error.

        Fewer than ``2·top_k`` distinct prompts must raise ``SwayError``
        with an actionable hint. This supersedes the pre-F04
        ``test_top_k_clipped_to_pool_size``, which leaned on the very
        single-prompt case the audit flagged for yielding identical
        ``top=[p]`` and ``bottom=[p]`` lists.
        """
        from dlm_sway.core.errors import SwayError

        lone_prompt = "p"
        base_dist = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])
        ft_dist = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])
        backend = DummyDifferentialBackend(
            base=DummyResponses(token_dists={lone_prompt: base_dist}),
            ft=DummyResponses(token_dists={lone_prompt: ft_dist}),
        )

        with pytest.raises(SwayError, match="below the 2·top_k"):
            mine_outliers(
                probe_kind="delta_kl",
                candidate_prompts=[lone_prompt],
                backend=backend,
                top_k=10,
            )

    def test_small_pool_error_suggests_smaller_top_k(self) -> None:
        """The guard's message carries a ready-to-paste ``--top-k N`` hint."""
        from dlm_sway.core.errors import SwayError

        prompts = ["p1", "p2", "p3"]
        base_dist = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])
        ft_dist = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])
        backend = DummyDifferentialBackend(
            base=DummyResponses(token_dists=dict.fromkeys(prompts, base_dist)),
            ft=DummyResponses(token_dists=dict.fromkeys(prompts, ft_dist)),
        )

        with pytest.raises(SwayError, match="Pass --top-k 1"):
            mine_outliers(
                probe_kind="delta_kl",
                candidate_prompts=prompts,
                backend=backend,
                top_k=5,
            )

    def test_empty_pool_returns_empty_result(self) -> None:
        """No candidate prompts yields empty top and bottom lists."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

        outcome = mine_outliers(
            probe_kind="delta_kl",
            candidate_prompts=[],
            backend=backend,
            top_k=5,
        )

        assert outcome.top == []
        assert outcome.bottom == []

    def test_unsupported_probe_kind_returns_empty(self) -> None:
        """Unsupported probe kinds produce an empty result, not an error.

        Probes that need a non-``prompts`` spec (leakage, etc.) silently
        skip every candidate. Because nothing gets scored, the F04 floor
        does not fire and the empty-result path is kept for the
        unsupported-kind UX.
        """
        base_dist = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])
        ft_dist = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])
        backend = DummyDifferentialBackend(
            base=DummyResponses(token_dists={"p": base_dist}),
            ft=DummyResponses(token_dists={"p": ft_dist}),
        )

        outcome = mine_outliers(
            probe_kind="leakage",
            candidate_prompts=["p"],
            backend=backend,
            top_k=5,
        )

        assert outcome.top == []
        assert outcome.bottom == []

    def test_rejects_top_k_zero(self) -> None:
        """``top_k=0`` is rejected with a ``ValueError``."""
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

        with pytest.raises(ValueError, match="top_k must be positive"):
            mine_outliers(
                probe_kind="delta_kl",
                candidate_prompts=["a"],
                backend=backend,
                top_k=0,
            )
166
167
class TestCorpusPrompts:
    """Checks for the ``corpus_prompts`` candidate-pool helper."""

    def test_pulls_chunks_from_public_domain(self) -> None:
        """``--from-corpus public_domain_en`` gives usable string chunks.

        The result must be non-empty, respect ``max_chunks``, and hold
        only strings long enough for probe scoring.
        """
        cap = 4
        chunks = corpus_prompts("public_domain_en", chunk_chars=512, max_chunks=cap)

        assert chunks
        assert len(chunks) <= cap
        assert all(isinstance(chunk, str) for chunk in chunks)
        # 64 is chunk_corpus's minimum chunk length.
        assert all(len(chunk) >= 64 for chunk in chunks)

    def test_unknown_corpus_raises(self) -> None:
        """An unrecognized corpus name raises ``KeyError``."""
        with pytest.raises(KeyError):
            corpus_prompts("doesnotexist")
182
183
class TestOutlierCandidate:
    """Checks on the ``OutlierCandidate`` record itself."""

    def test_is_frozen(self) -> None:
        """Assigning to a field after construction must raise."""
        candidate = OutlierCandidate(prompt="p", raw=0.5, index=0)

        # The concrete exception type is left open on purpose; only the
        # fact that the write is refused matters here.
        with pytest.raises(Exception):  # noqa: B017, PT011
            candidate.raw = 0.0  # type: ignore[misc]