| 1 | """Tests for :mod:`dlm_sway.probes.cluster_kl`. |
| 2 | |
| 3 | Uses stubbed embeddings + canned :class:`TokenDist` values so tests don't |
| 4 | need sentence-transformers or a GPU. The arithmetic of specificity is |
| 5 | exercised on shapes we pick: two clearly-separated topics with different |
| 6 | mean KLs yield a high specificity ratio; a uniform adapter (every prompt |
| 7 | shifted identically) produces zero variance both ways and falls back to |
| 8 | the ``0.5`` null expectation. |
| 9 | """ |
| 10 | |
| 11 | from __future__ import annotations |
| 12 | |
| 13 | import math |
| 14 | from typing import Any |
| 15 | |
| 16 | import numpy as np |
| 17 | import pytest |
| 18 | |
| 19 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 20 | from dlm_sway.core.result import Verdict |
| 21 | from dlm_sway.core.scoring import TokenDist |
| 22 | from dlm_sway.probes.base import RunContext, build_probe |
| 23 | from dlm_sway.probes.cluster_kl import ClusterKLProbe |
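

# A reader's sketch of the arithmetic these tests exercise. This is a
# hypothetical helper, never called by the tests and not part of the probe's
# API; the authoritative formula lives in ``dlm_sway.probes.cluster_kl``. It
# only illustrates the *assumed* shape of the score: the share of KL variance
# explained by the cluster assignment, with a ``0.5`` convention when both
# variance terms vanish.
def _specificity_sketch(per_prompt_kl: np.ndarray, labels: np.ndarray) -> float:
    clusters = np.unique(labels)
    # Per-cluster mean KLs give the between-cluster variance; the average
    # per-cluster spread gives the within-cluster variance.
    cluster_means = np.array([per_prompt_kl[labels == c].mean() for c in clusters])
    between = float(np.var(cluster_means))
    within = float(
        np.mean([np.var(per_prompt_kl[labels == c]) for c in clusters])
    )
    if between + within == 0.0:
        return 0.5  # degenerate: every prompt shifted identically
    return between / (between + within)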


def _dist_sharp(seed_offset: int = 0) -> TokenDist:
    """Sharp distribution: nearly all mass on token 0."""
    k = 8
    lp = np.array([-0.1 + 0.01 * seed_offset] + [-5.0] * (k - 1), dtype=np.float32)
    residual = 1.0 - float(np.exp(lp).sum())
    tail = math.log(residual) if residual > 1e-12 else None
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=lp,
        vocab_size=1000,
        tail_logprob=tail,
    )


def _dist_broad() -> TokenDist:
    """Broad distribution: uniform over top-k, with a tiny monotonic
    perturbation so it clears ``_divergence``'s uniformity guard (a
    literally-flat dist looks like a broken lm_head)."""
    k = 8
    lp = np.full(k, -math.log(k), dtype=np.float32)
    lp += np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
    residual = 1.0 - float(np.exp(lp).sum())
    tail = math.log(residual) if residual > 1e-12 else None
    return TokenDist(
        token_ids=np.arange(k, dtype=np.int64),
        logprobs=lp,
        vocab_size=1000,
        tail_logprob=tail,
    )
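

# For context on the thresholds asserted below: with roughly 0.9 of the mass on
# token 0 in the sharp dist versus ~1/8 per token in the broad dist, token 0
# alone contributes about 0.9 * ln(0.9 / 0.125) ≈ 1.8 nats to KL(sharp || broad),
# ignoring tail handling. The exact number the probe reports depends on its
# divergence implementation, but it comfortably clears the 0.1 bar used for the
# "hot" cluster below, while sharp-vs-sharp stays at ~0.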


def _stub_embedder(text_to_vec: dict[str, np.ndarray]):  # type: ignore[no-untyped-def]
    def _encode(texts: list[str]):  # type: ignore[no-untyped-def]
        return np.stack([text_to_vec[t] for t in texts])

    return _encode


def _argmax_kmeans(embeddings: np.ndarray, *, k: int, seed: int) -> np.ndarray:
    """sklearn-free stub: cluster by argmax of the one-hot test embeddings.

    The tests construct embeddings in the canonical basis so each vector's
    argmax is its intended cluster ID. Keeps the unit tests runnable on CI
    runners that don't install the ``[semsim]`` extra.
    """
    del seed  # deterministic by construction
    labels = np.argmax(embeddings, axis=1).astype(np.int64)
    return labels % k
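

# Intended mapping, assuming one-hot embeddings as the tests construct them:
#   _argmax_kmeans(np.array([[1.0, 0.0], [0.0, 1.0]]), k=2, seed=0) -> [0, 1]
# i.e. prompts embedded on the first basis vector share cluster 0 and prompts
# on the second basis vector share the other cluster.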


@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Install a stub embedder + sklearn-free k-means on ``cluster_kl``'s
    helpers. Matches the ``adapter_revert`` test pattern but also bypasses
    ``sklearn.cluster.KMeans`` so tests work without the ``[semsim]`` extra.
    """
    table: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.probes.cluster_kl._load_embedder",
        lambda _model_id: _stub_embedder(table),  # type: ignore[arg-type]
    )
    monkeypatch.setattr(
        "dlm_sway.probes.cluster_kl._kmeans_cluster",
        _argmax_kmeans,
    )
    return table


def _two_topic_backend(topic_a: list[str], topic_b: list[str]) -> DummyDifferentialBackend:
    """Base is sharp on all prompts. ft is broad on topic A (high KL) and
    sharp on topic B (near-zero KL). Produces a strong per-topic signal.
    """
    base_dists: dict[str, TokenDist] = {}
    ft_dists: dict[str, TokenDist] = {}
    for p in topic_a:
        base_dists[p] = _dist_sharp()
        ft_dists[p] = _dist_broad()  # diverges sharply from base
    for p in topic_b:
        base_dists[p] = _dist_sharp()
        ft_dists[p] = _dist_sharp()  # matches base → ~0 divergence
    base = DummyResponses(token_dists=base_dists)
    ft = DummyResponses(token_dists=ft_dists)
    return DummyDifferentialBackend(base=base, ft=ft)


def _uniform_backend(prompts: list[str]) -> DummyDifferentialBackend:
    """Base and ft produce *identical* distributions for every prompt.
    Divergences are all zero, so both variances are zero and the
    specificity ratio falls back to the ``0.5`` convention.
    """
    base_dists = {p: _dist_sharp() for p in prompts}
    ft_dists = {p: _dist_sharp() for p in prompts}
    base = DummyResponses(token_dists=base_dists)
    ft = DummyResponses(token_dists=ft_dists)
    return DummyDifferentialBackend(base=base, ft=ft)


class TestClusterKL:
    def test_two_topic_adapter_high_specificity(
        self, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
| 127 | """Two clearly-separated topics where only topic A is shifted by |
| 128 | the adapter → specificity ratio drives toward 1.0.""" |
        topic_a = [f"A-prompt-{i}" for i in range(6)]
        topic_b = [f"B-prompt-{i}" for i in range(6)]
        for p in topic_a:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)
        for p in topic_b:
            monkeyed_embed[p] = np.array([0.0, 1.0], dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": topic_a + topic_b,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_two_topic_backend(topic_a, topic_b))
        result = probe.run(spec, ctx)

        assert result.raw is not None
        assert result.raw > 0.8, (
            f"expected specificity >> 0.5 for a topic-specific adapter; got {result.raw:.3f}"
        )
        assert result.evidence["num_clusters"] == 2
        assert result.evidence["num_prompts"] == 12
        # Cluster means differ sharply: one near broad-vs-sharp KL, one near 0.
        per_cluster = result.evidence["per_cluster_mean_kl"]
        assert len(per_cluster) == 2
        hi, lo = sorted(per_cluster, reverse=True)
        assert hi > 0.1, f"expected high-cluster mean > 0.1; got {per_cluster}"
        assert lo < 0.01, f"expected low-cluster mean < 0.01; got {per_cluster}"

    def test_uniform_adapter_fallback_to_half(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """All prompts shifted identically → zero between-/within-variance
        → specificity lands on the ``0.5`` fallback (not NaN). F17: the
        degenerate branch returns WARN with a ``degenerate_zero_variance``
        marker and no z-score, not a spurious calibrated verdict."""
        prompts = [f"p-{i}" for i in range(8)]
        # Split embeddings across two centroids so k-means has a valid
        # partition; the divergence math is what drives the ratio.
        for i, p in enumerate(prompts):
            vec = [1.0, 0.0] if i % 2 == 0 else [0.0, 1.0]
            monkeyed_embed[p] = np.array(vec, dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": prompts,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_uniform_backend(prompts))
        result = probe.run(spec, ctx)

        assert result.raw == pytest.approx(0.5, abs=1e-6)
        assert result.evidence["within_cluster_variance"] == pytest.approx(0.0)
        assert result.evidence["between_cluster_variance"] == pytest.approx(0.0)
        # F17: degenerate case gets a WARN verdict + explicit evidence
        # marker; no z-score is emitted (comparing a conventional 0.5
        # to a null mean near 0.5 would produce spurious calibration).
        assert result.verdict == Verdict.WARN
        assert result.z_score is None
        assert result.evidence["degenerate_zero_variance"] is True
        assert "degenerate" in result.message.lower()

    def test_too_few_prompts_skips(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed  # no embedding needed — SKIP short-circuits first
        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": ["a", "b", "c"],
                "num_clusters": 2,
                "min_prompts": 10,
            }
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "≥10" in result.message

    def test_empty_prompts_errors(self) -> None:
        probe, spec = build_probe(
            {"name": "ck", "kind": "cluster_kl", "prompts": [], "min_prompts": 4}
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.ERROR

    def test_num_clusters_gt_prompts_skips(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        del monkeyed_embed
        # 5 prompts with k=3 fails the two-prompts-per-cluster minimum (3*2 = 6 > 5) → SKIP.
        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": [f"p{i}" for i in range(5)],
                "num_clusters": 3,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "num_clusters=3" in result.message

    def test_ci_95_populated(self, monkeyed_embed: dict[str, np.ndarray]) -> None:
        """Bootstrap CI lands on the ProbeResult and brackets raw."""
        topic_a = [f"A-{i}" for i in range(5)]
        topic_b = [f"B-{i}" for i in range(5)]
        for p in topic_a:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)
        for p in topic_b:
            monkeyed_embed[p] = np.array([0.0, 1.0], dtype=np.float32)

        probe, spec = build_probe(
            {
                "name": "ck",
                "kind": "cluster_kl",
                "prompts": topic_a + topic_b,
                "num_clusters": 2,
                "min_prompts": 4,
            }
        )
        ctx = RunContext(backend=_two_topic_backend(topic_a, topic_b))
        result = probe.run(spec, ctx)

        assert result.ci_95 is not None
        lo, hi = result.ci_95
        assert 0.0 <= lo <= hi <= 1.0
        assert result.raw is not None
        # Bootstrap on a strong signal should bracket raw; allow a
        # small slack because resampling (5,5) prompt pairs is noisy.
        assert lo - 0.05 <= result.raw <= hi + 0.05


class TestMissingSemsim:
    def test_skip_when_extras_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        from dlm_sway.core.errors import BackendNotAvailableError

        def raiser(_model_id: Any) -> Any:  # type: ignore[no-untyped-def]
            raise BackendNotAvailableError(
                "cluster_kl",
                extra="semsim",
                hint="cluster_kl needs sentence-transformers + scikit-learn.",
            )

        monkeypatch.setattr(
            "dlm_sway.probes.cluster_kl._load_embedder",
            raiser,  # type: ignore[arg-type]
        )
        probe = ClusterKLProbe()
        spec = probe.spec_cls(
            name="ck",
            kind="cluster_kl",
            prompts=[f"p-{i}" for i in range(8)],
            num_clusters=2,
            min_prompts=4,
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "semsim" in result.message

    def test_skip_when_sklearn_import_fails(
        self, monkeypatch: pytest.MonkeyPatch, monkeyed_embed: dict[str, np.ndarray]
    ) -> None:
        """Covers the ``_kmeans_cluster`` import-error SKIP branch directly.

        The ``_load_embedder`` raise branch is tested above; this test
        stubs ``_load_embedder`` to succeed and replaces
        ``_kmeans_cluster`` with a raiser that mimics an uninstalled
        sklearn. Before this test, the sklearn-missing SKIP path in
        ``probes/cluster_kl.py`` was unreachable under any test — the
        embedder raise always fired first.
        """
        from dlm_sway.core.errors import BackendNotAvailableError

        for p in [f"p-{i}" for i in range(8)]:
            monkeyed_embed[p] = np.array([1.0, 0.0], dtype=np.float32)

        def sklearn_raiser(*_args: Any, **_kwargs: Any) -> Any:
            raise BackendNotAvailableError(
                "cluster_kl",
                extra="semsim",
                hint="cluster_kl needs scikit-learn for k-means clustering.",
            )

        monkeypatch.setattr(
            "dlm_sway.probes.cluster_kl._kmeans_cluster",
            sklearn_raiser,
        )
        probe = ClusterKLProbe()
        spec = probe.spec_cls(
            name="ck",
            kind="cluster_kl",
            prompts=[f"p-{i}" for i in range(8)],
            num_clusters=2,
            min_prompts=4,
        )
        ctx = RunContext(
            backend=DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        )
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "semsim" in result.message
        assert "scikit-learn" in result.message


class TestRealKMeans:
    """Exercise the actual ``sklearn.cluster.KMeans`` primitive.

    Every other test in this file monkeypatches ``_kmeans_cluster`` with
    an argmax stub so the suite can run in CI environments without the
    ``[semsim]`` extra installed. That leaves the real sklearn path —
    the probe's entire reason for existing — uncovered. The tests here
    skip when sklearn isn't available and execute the real import
    otherwise.
    """

    def test_real_kmeans_separates_two_gaussians(self) -> None:
        """Two clearly-separated clusters → k-means recovers the correct
        partition with a fixed seed."""
        pytest.importorskip("sklearn")
        from dlm_sway.probes.cluster_kl import _kmeans_cluster

        rng = np.random.default_rng(0)
        # Cluster A centered at (0, 0); cluster B centered at (5, 0).
        group_a = rng.normal(loc=0.0, scale=0.5, size=(8, 2)).astype(np.float32)
        group_b = rng.normal(loc=(5.0, 0.0), scale=0.5, size=(8, 2)).astype(np.float32)
        embeddings = np.vstack([group_a, group_b])
        labels = _kmeans_cluster(embeddings, k=2, seed=0)
        assert labels.shape == (16,)
        # All-A should share a label; all-B should share the other.
        label_a = set(labels[:8].tolist())
        label_b = set(labels[8:].tolist())
        assert len(label_a) == 1
        assert len(label_b) == 1
        assert label_a != label_b

    def test_real_kmeans_seed_is_deterministic(self) -> None:
        """Two runs with the same seed → identical label vectors. Pins
        the determinism contract in a way that the argmax stub can't.
        """
        pytest.importorskip("sklearn")
        from dlm_sway.probes.cluster_kl import _kmeans_cluster

        rng = np.random.default_rng(0)
        embeddings = rng.normal(size=(20, 4)).astype(np.float32)
        labels_a = _kmeans_cluster(embeddings, k=3, seed=42)
        labels_b = _kmeans_cluster(embeddings, k=3, seed=42)
        assert np.array_equal(labels_a, labels_b)