1 """F8 prove-the-value: ``cluster_kl`` surfaces a distinction ``delta_kl`` misses.
2
3 ``delta_kl`` reports one number — the *mean* divergence across the prompt
4 set. Two very different adapters can share the same mean: a *blunt* one
5 shifts every topic by a moderate amount, and a *targeted* one shifts one
6 topic heavily while leaving another untouched. The mean is the same; the
7 stories are not.
8
9 This test constructs exactly that scenario with stubbed embeddings and
10 canned token distributions, runs ``delta_kl`` and ``cluster_kl`` on both
11 backends, and shows:
12
13 - ``delta_kl`` reports comparable mean divergence for both adapters —
14 the signal is ambiguous.
15 - ``cluster_kl`` pulls the structural difference apart: low specificity
16 on the blunt adapter (≈ 0.5), high specificity on the targeted one.
17
18 Without ``cluster_kl`` you can't tell these apart with numeric probes; the
19 F8 claim is that this split matters in practice.
20 """
21
22 from __future__ import annotations
23
24 import math
25
26 import numpy as np
27 import pytest
28
29 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
30 from dlm_sway.core.scoring import TokenDist
31 from dlm_sway.probes.base import RunContext, build_probe
32
33
def _dist_from_probs(probs: list[float]) -> TokenDist:
    """Build a ``TokenDist`` from (possibly unnormalized) probabilities.

    The input is renormalized to sum to 1 before taking logs; token ids are
    simply 0..len(probs)-1 and the vocab size is padded to at least 1000.
    """
    count = len(probs)
    weights = np.asarray(probs, dtype=np.float64)
    normalized = weights / weights.sum()
    return TokenDist(
        token_ids=np.arange(count, dtype=np.int64),
        logprobs=np.log(normalized).astype(np.float32),
        vocab_size=max(1000, count),
        tail_logprob=None,
    )
44
45
# Base is sharply peaked; three step levels of "shift" for ft.
# BASE puts 92% mass on token 0; the FT_* variants flatten it by degrees.
BASE = _dist_from_probs([0.92, 0.02, 0.02, 0.02, 0.02])
# Moderate flattening — used for every prompt by the "blunt" adapter.
FT_MODERATE = _dist_from_probs([0.55, 0.15, 0.10, 0.10, 0.10])
# Near-uniform flattening — used only on topic A by the "targeted" adapter.
FT_STRONG = _dist_from_probs([0.25, 0.20, 0.20, 0.20, 0.15])
FT_IDENTITY = BASE  # ft == base → zero divergence
51
52
53 def _stub_embedder(text_to_vec: dict[str, np.ndarray]): # type: ignore[no-untyped-def]
54 def _encode(texts: list[str]): # type: ignore[no-untyped-def]
55 return np.stack([text_to_vec[t] for t in texts])
56
57 return _encode
58
59
60 def _argmax_kmeans(embeddings: np.ndarray, *, k: int, seed: int) -> np.ndarray:
61 """sklearn-free stub — cluster by argmax of the one-hot test embeddings."""
62 del seed
63 labels = np.argmax(embeddings, axis=1).astype(np.int64)
64 return labels % k
65
66
@pytest.fixture
def monkeyed_embed(monkeypatch: pytest.MonkeyPatch) -> dict[str, np.ndarray]:
    """Swap cluster_kl's embedder and k-means for deterministic stubs.

    Returns the (initially empty) text→vector table read by the stub
    embedder; tests populate it before running the probe.
    """
    vectors: dict[str, np.ndarray] = {}
    monkeypatch.setattr(
        "dlm_sway.probes.cluster_kl._load_embedder",
        lambda _model_id: _stub_embedder(vectors),  # type: ignore[arg-type]
    )
    monkeypatch.setattr(
        "dlm_sway.probes.cluster_kl._kmeans_cluster",
        _argmax_kmeans,
    )
    return vectors
79
80
# Two disjoint topic groups of 8 prompts each; the fixture table assigns
# them orthogonal one-hot embeddings, so the stub k-means separates them
# cleanly into two clusters.
TOPIC_A = [f"A-prompt-{i}" for i in range(8)]
TOPIC_B = [f"B-prompt-{i}" for i in range(8)]
ALL_PROMPTS = TOPIC_A + TOPIC_B
84
85
def _blunt_backend() -> DummyDifferentialBackend:
    """Adapter that nudges every prompt moderately — no topic stands out."""
    base_dists = {prompt: BASE for prompt in ALL_PROMPTS}
    ft_dists = {prompt: FT_MODERATE for prompt in ALL_PROMPTS}
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists),
    )
94
95
def _targeted_backend() -> DummyDifferentialBackend:
    """Adapter that hits topic A hard and leaves topic B identical to base."""
    base_dists = {prompt: BASE for prompt in ALL_PROMPTS}
    ft_dists = {prompt: FT_STRONG for prompt in TOPIC_A}
    ft_dists.update({prompt: FT_IDENTITY for prompt in TOPIC_B})
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists=base_dists),
        ft=DummyResponses(token_dists=ft_dists),
    )
104
105
def _install_embeddings(table: dict[str, np.ndarray]) -> None:
    """Assign orthogonal one-hot vectors: topic A on axis 0, topic B on axis 1."""
    for axis, prompts in enumerate((TOPIC_A, TOPIC_B)):
        for prompt in prompts:
            vec = np.zeros(2, dtype=np.float32)
            vec[axis] = 1.0
            table[prompt] = vec
111
112
def _run_delta_kl(backend: DummyDifferentialBackend) -> float:
    """Run the delta_kl probe over ALL_PROMPTS and return its raw score."""
    config = {"name": "dk", "kind": "delta_kl", "prompts": ALL_PROMPTS}
    probe, spec = build_probe(config)
    outcome = probe.run(spec, RunContext(backend=backend))
    assert outcome.raw is not None
    return outcome.raw
118
119
def _run_cluster_kl(backend: DummyDifferentialBackend) -> float:
    """Run the cluster_kl probe over ALL_PROMPTS and return its raw score."""
    config = {
        "name": "ck",
        "kind": "cluster_kl",
        "prompts": ALL_PROMPTS,
        "num_clusters": 2,
        "min_prompts": 4,
    }
    probe, spec = build_probe(config)
    outcome = probe.run(spec, RunContext(backend=backend))
    assert outcome.raw is not None
    return outcome.raw
133
134
def test_cluster_kl_distinguishes_what_delta_kl_merges(
    monkeyed_embed: dict[str, np.ndarray],
) -> None:
    """The F8 scenario: similar-mean adapters that only cluster_kl tells apart."""
    _install_embeddings(monkeyed_embed)

    # Fresh backends for each probe run; "blunt" shifts everything
    # moderately, "targeted" shifts only topic A.
    blunt_delta = _run_delta_kl(_blunt_backend())
    blunt_spec = _run_cluster_kl(_blunt_backend())
    targeted_delta = _run_delta_kl(_targeted_backend())
    targeted_spec = _run_cluster_kl(_targeted_backend())

    # delta_kl: both adapters land in the same "meaningful shift" band —
    # neither is zero, neither is extreme — so the mean alone cannot
    # separate them.
    assert blunt_delta > 0.01, f"blunt delta_kl should be non-trivial; got {blunt_delta:.4f}"
    assert targeted_delta > 0.01, (
        f"targeted delta_kl should be non-trivial; got {targeted_delta:.4f}"
    )
    # Ambiguity: ratio stays within 3× — no clean "this one is different" signal.
    low, high = sorted((blunt_delta, targeted_delta))
    ratio = high / low
    assert ratio < 3.0, (
        f"delta_kl should be comparable across the two adapters; "
        f"got blunt={blunt_delta:.4f}, targeted={targeted_delta:.4f} (ratio {ratio:.2f})"
    )

    # cluster_kl: pulls them apart by at least 0.3 on the [0, 1] scale.
    gap = targeted_spec - blunt_spec
    assert gap > 0.3, (
        f"cluster_kl should surface the structural difference; "
        f"blunt={blunt_spec:.3f}, targeted={targeted_spec:.3f}, gap={gap:.3f}"
    )
    # And the targeted adapter specifically lands well above 0.5 (random),
    # while the blunt one lands near it.
    assert targeted_spec > 0.8, (
        f"targeted specificity should be close to 1; got {targeted_spec:.3f}"
    )
    assert blunt_spec < 0.6, f"blunt specificity should be near 0.5; got {blunt_spec:.3f}"
171
172
def test_base_and_ft_distributions_are_normalized() -> None:
    """Sanity on the hand-built dists used by the fixture."""
    dists = {"BASE": BASE, "FT_MODERATE": FT_MODERATE, "FT_STRONG": FT_STRONG}
    for name, dist in dists.items():
        total = float(np.exp(dist.logprobs).sum())
        assert math.isclose(total, 1.0, abs_tol=1e-5), f"{name} sum was {total}"