1 """Tests for :mod:`dlm_sway.probes.cluster_kl`.
2
3 Uses stubbed embeddings + canned :class:`TokenDist` values so tests don't
4 need sentence-transformers or a GPU. The arithmetic of specificity is
5 exercised on shapes we pick: two clearly-separated topics with different
6 mean KLs yield a high specificity ratio; a uniform adapter (every prompt
7 shifted identically) produces zero variance both ways and falls back to
8 the ``0.5`` null expectation.
9 """
10
11 from __future__ import annotations
12
13 import math
14 from typing import Any
15
16 import numpy as np
17 import pytest
18
19 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
20 from dlm_sway.core.result import Verdict
21 from dlm_sway.core.scoring import TokenDist
22 from dlm_sway.probes.base import RunContext, build_probe
23 from dlm_sway.probes.cluster_kl import ClusterKLProbe
24
25
def _dist_sharp(seed_offset: int = 0) -> TokenDist:
    """Sharp distribution: nearly all mass on token 0.

    ``seed_offset`` nudges the dominant logprob so callers can mint
    distinct-but-similar sharp distributions.
    """
    top_k = 8
    logprobs = np.concatenate(
        [
            np.array([-0.1 + 0.01 * seed_offset], dtype=np.float32),
            np.full(top_k - 1, -5.0, dtype=np.float32),
        ]
    )
    # Whatever probability mass the top-k doesn't cover goes into the tail.
    leftover = 1.0 - float(np.exp(logprobs).sum())
    tail_lp = math.log(leftover) if leftover > 1e-12 else None
    return TokenDist(
        token_ids=np.arange(top_k, dtype=np.int64),
        logprobs=logprobs,
        vocab_size=1000,
        tail_logprob=tail_lp,
    )
38
39
def _dist_broad() -> TokenDist:
    """Broad distribution: uniform over top-k, with a tiny monotonic
    perturbation so it clears ``_divergence``'s uniformity guard (a
    literally-flat dist looks like a broken lm_head)."""
    top_k = 8
    logprobs = np.full(top_k, -math.log(top_k), dtype=np.float32) + np.linspace(
        -1e-4, 1e-4, top_k, dtype=np.float32
    )
    # Residual mass outside the top-k becomes the tail logprob (or None).
    leftover = 1.0 - float(np.exp(logprobs).sum())
    tail_lp = math.log(leftover) if leftover > 1e-12 else None
    return TokenDist(
        token_ids=np.arange(top_k, dtype=np.int64),
        logprobs=logprobs,
        vocab_size=1000,
        tail_logprob=tail_lp,
    )
55
56
57 def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def]
58 def _encode(texts: list[str]): # type: ignore[no-untyped-def]
59 return np.stack([text_to_vec[t] for t in texts])
60
61 return _encode
62
63
64 def _argmax_kmeans(embeddings: np.ndarray, *, k: int, seed: int) -> np.ndarray:
65 """sklearn-free stub: cluster by argmax of the one-hot test embeddings.
66
67 The tests construct embeddings in the canonical basis so each vector's
68 argmax is its intended cluster ID. Keeps the unit tests runnable on CI
69 runners that don't install the ``[semsim]`` extra.
70 """
71 del seed # deterministic by construction
72 labels = np.argmax(embeddings, axis=1).astype(np.int64)
73 return labels % k
74
75
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Install a stub embedder + sklearn-free k-means on ``cluster_kl``'s
    helpers. Matches the ``adapter_revert`` test pattern but also bypasses
    ``sklearn.cluster.KMeans`` so tests work without the ``[semsim]`` extra.
    """
    table: dict[str, np.ndarray] = {}
    stubs = {
        "dlm_sway.probes.cluster_kl._load_embedder": lambda _model_id: _stub_embedder(table),  # type: ignore[arg-type]
        "dlm_sway.probes.cluster_kl._kmeans_cluster": _argmax_kmeans,
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)
    # Tests populate ``table`` after the fixture returns; the stub embedder
    # closes over it, so later insertions are visible.
    return table
92
93
def _two_topic_backend(topic_a: list[str], topic_b: list[str]) -> DummyDifferentialBackend:
    """Base is sharp on all prompts. ft is broad on topic A (high KL) and
    sharp on topic B (near-zero KL). Produces a strong per-topic signal.
    """
    # Base model: sharp everywhere, in the same insertion order (A then B).
    base_dists = {p: _dist_sharp() for p in topic_a + topic_b}
    # Fine-tuned model: diverges sharply on topic A, matches base on topic B.
    ft_dists = {p: _dist_broad() for p in topic_a}
    ft_dists.update({p: _dist_sharp() for p in topic_b})
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists),
    )
109
110
def _uniform_backend(prompts: list[str]) -> DummyDifferentialBackend:
    """Base and ft produce *identical* distributions for every prompt.
    Divergences are all zero, so both variances are zero and the
    specificity ratio falls back to the ``0.5`` convention.
    """
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists={p: _dist_sharp() for p in prompts}),
        ft=DummyResponses(token_dists={p: _dist_sharp() for p in prompts}),
    )
121
122
class TestClusterKL:
    """End-to-end runs of the cluster-KL probe against stubbed embeddings.

    Every test here relies on the ``monkeyed_embed`` fixture, which swaps
    ``cluster_kl``'s embedder and k-means helpers for deterministic stubs,
    so only the probe's own divergence/variance arithmetic and verdict
    logic are exercised.
    """

    def test_two_topic_adapter_high_specificity(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Two clearly-separated topics where only topic A is shifted by
        the adapter → specificity ratio drives toward 1.0."""
        topic_a = [f"A-prompt-{i}" for i in range(6)]
        topic_b = [f"B-prompt-{i}" for i in range(6)]
        # One-hot embeddings in the canonical basis: the argmax-kmeans stub
        # maps each vector straight to its intended cluster.
        for p in topic_a:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)
        for p in topic_b:
            monkeyed_embed[p] = np.array([0.0, 1.0], dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": topic_a + topic_b,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_two_topic_backend(topic_a, topic_b))
        result = probe.run(spec, ctx)

        assert result.raw is not None
        assert result.raw > 0.8, (
            f"expected specificity >> 0.5 for a topic-specific adapter; got {result.raw:.3f}"
        )
        assert result.evidence["num_clusters"] == 2
        assert result.evidence["num_prompts"] == 12
        # Cluster means differ sharply: one near broad-vs-sharp KL, one near 0.
        per_cluster = result.evidence["per_cluster_mean_kl"]
        assert len(per_cluster) == 2
        hi, lo = sorted(per_cluster, reverse=True)
        assert hi > 0.1, f"expected high-cluster mean > 0.1; got {per_cluster}"
        assert lo < 0.01, f"expected low-cluster mean < 0.01; got {per_cluster}"

    def test_uniform_adapter_fallback_to_half(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """All prompts shifted identically → zero between-/within-variance
        → specificity lands on the ``0.5`` fallback (not NaN). F17: the
        degenerate branch returns WARN with a ``degenerate_zero_variance``
        marker and no z-score, not a spurious calibrated verdict."""
        prompts = [f"p-{i}" for i in range(8)]
        # Split embeddings across two centroids so k-means has a valid
        # partition; the divergence math is what drives the ratio.
        for i, p in enumerate(prompts):
            vec = [1.0, 0.0] if i % 2 == 0 else [0.0, 1.0]
            monkeyed_embed[p] = np.array(vec, dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": prompts,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_uniform_backend(prompts))
        result = probe.run(spec, ctx)

        assert result.raw == pytest.approx(0.5, abs=1e-6)
        assert result.evidence["within_cluster_variance"] == pytest.approx(0.0)
        assert result.evidence["between_cluster_variance"] == pytest.approx(0.0)
        # F17: degenerate case gets a WARN verdict + explicit evidence
        # marker; no z-score is emitted (comparing a conventional 0.5
        # to a null mean near 0.5 would produce spurious calibration).
        assert result.verdict == Verdict.WARN
        assert result.z_score is None
        assert result.evidence["degenerate_zero_variance"] is True
        assert "degenerate" in result.message.lower()

    def test_too_few_prompts_skips(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """3 prompts against min_prompts=10 → SKIP before any embedding work."""
        del monkeyed_embed  # no embedding needed — SKIP short-circuits first
        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": ["a", "b", "c"],
                "num_clusters": 2,
                "min_prompts": 10,
            }
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "≥10" in result.message

    def test_empty_prompts_errors(self) -> None:
        """An empty prompt list is a config error, not a SKIP."""
        probe, spec = build_probe(
            {"name": "ck", "kind": "cluster_kl", "prompts": [], "min_prompts": 4}
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.ERROR

    def test_num_clusters_gt_prompts_skips(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """Too few prompts for the requested cluster count → SKIP."""
        del monkeyed_embed
        # 5 prompts, k=3 → 3*2 = 6 > 5 → SKIP.
        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": [f"p{i}" for i in range(5)],
                "num_clusters": 3,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "num_clusters=3" in result.message

    def test_ci_95_populated(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """Bootstrap CI lands on the ProbeResult and brackets raw."""
        topic_a = [f"A-{i}" for i in range(5)]
        topic_b = [f"B-{i}" for i in range(5)]
        for p in topic_a:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)
        for p in topic_b:
            monkeyed_embed[p] = np.array([0.0, 1.0], dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": topic_a + topic_b,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_two_topic_backend(topic_a, topic_b))
        result = probe.run(spec, ctx)

        assert result.ci_95 is not None
        lo, hi = result.ci_95
        assert 0.0 <= lo <= hi <= 1.0
        assert result.raw is not None
        # Bootstrap on a strong signal should bracket raw; allow a
        # small slack because resampling (6,6) prompt pairs is noisy.
        assert lo - 0.05 <= result.raw <= hi + 0.05
271
272
class TestMissingSemsim:
    """SKIP paths when the optional ``[semsim]`` extra is not installed.

    Two independent import-failure branches exist in ``cluster_kl``:
    the sentence-transformers embedder and the sklearn k-means helper.
    Each test forces exactly one of them to raise.
    """

    def test_skip_when_extras_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """``_load_embedder`` raising ``BackendNotAvailableError`` → SKIP
        with a message pointing at the ``semsim`` extra."""
        from dlm_sway.core.errors import BackendNotAvailableError

        def raiser(_model_id: Any) -> Any:  # type: ignore[no-untyped-def]
            raise BackendNotAvailableError(
                "cluster_kl",
                extra="semsim",
                hint="cluster_kl needs sentence-transformers + scikit-learn.",
            )

        monkeypatch.setattr(
            "dlm_sway.probes.cluster_kl._load_embedder",
            raiser,  # type: ignore[arg-type]
        )
        probe = ClusterKLProbe()
        spec = probe.spec_cls(
            name="ck",
            kind="cluster_kl",
            prompts=[f"p-{i}" for i in range(8)],
            num_clusters=2,
            min_prompts=4,
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "semsim" in result.message

    def test_skip_when_sklearn_import_fails(
        self, monkeypatch: pytest.MonkeyPatch, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Covers the ``_kmeans_cluster`` import-error SKIP branch directly.

        The ``_load_embedder`` raise branch is tested above; this test
        stubs ``_load_embedder`` to succeed and replaces
        ``_kmeans_cluster`` with a raiser that mimics an uninstalled
        sklearn. Before this test, the sklearn-missing SKIP path in
        ``probes/cluster_kl.py`` was unreachable under any test — the
        embedder raise always fired first.
        """
        from dlm_sway.core.errors import BackendNotAvailableError

        # Embeddings must resolve so execution reaches the k-means call.
        for p in [f"p-{i}" for i in range(8)]:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)

        def sklearn_raiser(*_args: Any, **_kwargs: Any) -> Any:
            raise BackendNotAvailableError(
                "cluster_kl",
                extra="semsim",
                hint="cluster_kl needs scikit-learn for k-means clustering.",
            )

        # Overrides the fixture's argmax stub installed on the same target.
        monkeypatch.setattr(
            "dlm_sway.probes.cluster_kl._kmeans_cluster",
            sklearn_raiser,
        )
        probe = ClusterKLProbe()
        spec = probe.spec_cls(
            name="ck",
            kind="cluster_kl",
            prompts=[f"p-{i}" for i in range(8)],
            num_clusters=2,
            min_prompts=4,
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "semsim" in result.message
        assert "scikit-learn" in result.message
346
347
class TestRealKMeans:
    """Exercise the actual ``sklearn.cluster.KMeans`` primitive.

    Every other test in this file monkeypatches ``_kmeans_cluster`` with
    an argmax stub so suites can run in CI environments without the
    ``[semsim]`` extra installed. That leaves the real sklearn path —
    the probe's entire reason for existing — uncovered. The tests here
    skip when sklearn isn't available and execute the real import
    otherwise.
    """

    def test_real_kmeans_separates_two_gaussians(self) -> None:
        """Two clearly-separated clusters → k-means recovers the correct
        partition with a fixed seed."""
        pytest.importorskip("sklearn")
        from dlm_sway.probes.cluster_kl import _kmeans_cluster

        rng = np.random.default_rng(0)
        # Cluster A centered at (0, 0); cluster B centered at (5, 0).
        group_a = rng.normal(loc=0.0, scale=0.5, size=(8, 2)).astype(np.float32)
        group_b = rng.normal(loc=(5.0, 0.0), scale=0.5, size=(8, 2)).astype(np.float32)
        labels = _kmeans_cluster(np.vstack([group_a, group_b]), k=2, seed=0)
        assert labels.shape == (16,)
        # All-A should share a label; all-B should share the other.
        first_half = set(labels[:8].tolist())
        second_half = set(labels[8:].tolist())
        assert len(first_half) == 1
        assert len(second_half) == 1
        assert first_half != second_half

    def test_real_kmeans_seed_is_deterministic(self) -> None:
        """Two runs with the same seed → identical label vectors. Pins
        the determinism contract in a way that the argmax stub can't.
        """
        pytest.importorskip("sklearn")
        from dlm_sway.probes.cluster_kl import _kmeans_cluster

        embeddings = np.random.default_rng(0).normal(size=(20, 4)).astype(np.float32)
        run_one = _kmeans_cluster(embeddings, k=3, seed=42)
        run_two = _kmeans_cluster(embeddings, k=3, seed=42)
        assert np.array_equal(run_one, run_two)