| 1 | """F8 prove-the-value: ``cluster_kl`` surfaces a distinction ``delta_kl`` misses. |
| 2 | |
| 3 | ``delta_kl`` reports one number — the *mean* divergence across the prompt |
| 4 | set. Two very different adapters can share the same mean: a *blunt* one |
| 5 | shifts every topic by a moderate amount, and a *targeted* one shifts one |
| 6 | topic heavily while leaving another untouched. The mean is the same; the |
| 7 | stories are not. |
| 8 | |
| 9 | This test constructs exactly that scenario with stubbed embeddings and |
| 10 | canned token distributions, runs ``delta_kl`` and ``cluster_kl`` on both |
| 11 | backends, and shows: |
| 12 | |
| 13 | - ``delta_kl`` reports comparable mean divergence for both adapters — |
| 14 | the signal is ambiguous. |
| 15 | - ``cluster_kl`` pulls the structural difference apart: low specificity |
| 16 | on the blunt adapter (≈ 0.5), high specificity on the targeted one. |
| 17 | |
| 18 | Without ``cluster_kl`` you can't tell these apart with numeric probes; the |
| 19 | F8 claim is that this split matters in practice. |
| 20 | """ |
| 21 | |
| 22 | from __future__ import annotations |
| 23 | |
| 24 | import math |
| 25 | |
| 26 | import numpy as np |
| 27 | import pytest |
| 28 | |
| 29 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 30 | from dlm_sway.core.scoring import TokenDist |
| 31 | from dlm_sway.probes.base import RunContext, build_probe |
| 32 | |
| 33 | |
def _dist_from_probs(probs: list[float]) -> TokenDist:
    """Build a ``TokenDist`` from raw probabilities, normalizing them first.

    Token ids are just ``0..len(probs)-1``; vocab size is padded to at least
    1000 so the dist looks like a realistic top-k slice of a larger vocab.
    """
    weights = np.asarray(probs, dtype=np.float64)
    normalized = weights / weights.sum()
    return TokenDist(
        token_ids=np.arange(len(probs), dtype=np.int64),
        logprobs=np.log(normalized).astype(np.float32),
        vocab_size=max(1000, len(probs)),
        tail_logprob=None,
    )
| 44 | |
| 45 | |
# Base is sharply peaked; three step levels of "shift" for ft.
BASE = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])  # near-deterministic reference dist
FT_MODERATE = _dist_from_probs([0.55, 0.15, 0.10, 0.10, 0.10])  # moderate shift (blunt adapter)
FT_STRONG = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])  # heavy shift (targeted adapter, topic A)
FT_IDENTITY = BASE  # ft == base → zero divergence
| 51 | |
| 52 | |
| 53 | def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def] |
| 54 | def _encode(texts: list[str]): # type: ignore[no-untyped-def] |
| 55 | return np.stack([text_to_vec[t] for t in texts]) |
| 56 | |
| 57 | return _encode |
| 58 | |
| 59 | |
| 60 | def _argmax_kmeans(embeddings: np.ndarray, *, k: int, seed: int) -> np.ndarray: |
| 61 | """sklearn-free stub — cluster by argmax of the one-hot test embeddings.""" |
| 62 | del seed |
| 63 | labels = np.argmax(embeddings, axis=1).astype(np.int64) |
| 64 | return labels % k |
| 65 | |
| 66 | |
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Patch cluster_kl's embedder and kmeans with deterministic stubs.

    Returns the (initially empty) text→vector table the stub embedder reads,
    so each test can fill in its own embeddings.
    """
    vectors: dict[str, np.ndarray] = {}
    module = "dlm_sway.probes.cluster_kl."
    monkeypatch.setattr(
        module + "_load_embedder",
        lambda _model_id: _stub_embedder(vectors),  # type: ignore[arg-type]
    )
    monkeypatch.setattr(module + "_kmeans_cluster", _argmax_kmeans)
    return vectors
| 79 | |
| 80 | |
# Two disjoint topic pools; the stub clustering should split them cleanly.
TOPIC_A = ["A-prompt-%d" % i for i in range(8)]
TOPIC_B = ["B-prompt-%d" % i for i in range(8)]
ALL_PROMPTS = [*TOPIC_A, *TOPIC_B]
| 84 | |
| 85 | |
def _blunt_backend() -> DummyDifferentialBackend:
    """Every prompt is shifted moderately by the adapter. No topic spike."""
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists={p: BASE for p in ALL_PROMPTS}),
        ft=DummyResponses(token_dists={p: FT_MODERATE for p in ALL_PROMPTS}),
    )
| 94 | |
| 95 | |
def _targeted_backend() -> DummyDifferentialBackend:
    """Topic A gets a strong shift; topic B is untouched."""
    ft_dists: dict[str, TokenDist] = {p: FT_STRONG for p in TOPIC_A}
    ft_dists.update({p: FT_IDENTITY for p in TOPIC_B})
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists={p: BASE for p in ALL_PROMPTS}),
        ft=DummyResponses(token_dists=ft_dists),
    )
| 104 | |
| 105 | |
def _install_embeddings(table: dict[str, np.ndarray]) -> None:
    """Give the two topics orthogonal one-hot embeddings so argmax splits them."""
    # The table values are read-only downstream (np.stack), so sharing one
    # vector per topic is equivalent to allocating a fresh copy per prompt.
    axis_a = np.array([1.0, 0.0], dtype=np.float32)
    axis_b = np.array([0.0, 1.0], dtype=np.float32)
    table.update({p: axis_a for p in TOPIC_A})
    table.update({p: axis_b for p in TOPIC_B})
| 111 | |
| 112 | |
def _run_delta_kl(backend: DummyDifferentialBackend) -> float:
    """Run a delta_kl probe over every prompt and return its raw mean divergence."""
    config = {"name": "dk", "kind": "delta_kl", "prompts": ALL_PROMPTS}
    probe, spec = build_probe(config)
    outcome = probe.run(spec, RunContext(backend=backend))
    assert outcome.raw is not None
    return outcome.raw
| 118 | |
| 119 | |
def _run_cluster_kl(backend: DummyDifferentialBackend) -> float:
    """Run a 2-cluster cluster_kl probe and return its raw specificity score."""
    config = {
        "name": "ck",
        "kind": "cluster_kl",
        "prompts": ALL_PROMPTS,
        "num_clusters": 2,
        "min_prompts": 4,
    }
    probe, spec = build_probe(config)
    outcome = probe.run(spec, RunContext(backend=backend))
    assert outcome.raw is not None
    return outcome.raw
| 133 | |
| 134 | |
def test_cluster_kl_distinguishes_what_delta_kl_merges(
    monkeyed_embed: dict[str, np.ndarray],
) -> None:
    """delta_kl is ambiguous between the two adapters; cluster_kl pulls them apart."""
    _install_embeddings(monkeyed_embed)

    blunt_delta = _run_delta_kl(_blunt_backend())
    targeted_delta = _run_delta_kl(_targeted_backend())
    blunt_spec = _run_cluster_kl(_blunt_backend())
    targeted_spec = _run_cluster_kl(_targeted_backend())

    # Step 1 — delta_kl alone is ambiguous: both adapters produce a
    # non-trivial, non-extreme mean divergence, so the scalar view can't
    # tell the blunt story from the targeted one.
    assert blunt_delta > 0.01, f"blunt delta_kl should be non-trivial; got {blunt_delta:.4f}"
    assert targeted_delta > 0.01, (
        f"targeted delta_kl should be non-trivial; got {targeted_delta:.4f}"
    )
    # The two means stay within a 3× band of each other — no clean signal.
    ratio = max(blunt_delta, targeted_delta) / min(blunt_delta, targeted_delta)
    assert ratio < 3.0, (
        f"delta_kl should be comparable across the two adapters; "
        f"got blunt={blunt_delta:.4f}, targeted={targeted_delta:.4f} (ratio {ratio:.2f})"
    )

    # Step 2 — cluster_kl separates them by at least 0.3 on its [0, 1] scale.
    gap = targeted_spec - blunt_spec
    assert gap > 0.3, (
        f"cluster_kl should surface the structural difference; "
        f"blunt={blunt_spec:.3f}, targeted={targeted_spec:.3f}, gap={gap:.3f}"
    )
    # Absolute positions: targeted well above the 0.5 random baseline,
    # blunt hovering near it.
    assert targeted_spec > 0.8, (
        f"targeted specificity should be close to 1; got {targeted_spec:.3f}"
    )
    assert blunt_spec < 0.6, f"blunt specificity should be near 0.5; got {blunt_spec:.3f}"
| 171 | |
| 172 | |
def test_base_and_ft_distributions_are_normalized() -> None:
    """Sanity on the hand-built dists used by the fixture."""
    named = {"BASE": BASE, "FT_MODERATE": FT_MODERATE, "FT_STRONG": FT_STRONG}
    for name, dist in named.items():
        total = float(np.exp(dist.logprobs).sum())
        assert math.isclose(total, 1.0, abs_tol=1e-5), f"{name} sum was {total}"