1 """F8 ClusterKL — distribution-shift specificity via clustered KL.
2
3 ``delta_kl`` answers "how much did the adapter move the model, on
4 average?" — a global claim. ``cluster_kl`` asks the sharper
5 follow-up: "did the adapter move the *right* topics?"
6
7 A great adapter trained on a cooking document should shift the
8 model heavily on cooking prompts and leave chemistry prompts alone.
9 A bad adapter shifts everything by a similar amount — broad-stroke
10 over-fitting. ``delta_kl``'s mean-across-prompts number can't tell
11 them apart. ``cluster_kl`` can.
12
13 The probe:
14
15 1. Embeds every prompt via MiniLM (shared cache with
16 :mod:`adapter_revert`).
17 2. K-means-clusters the embeddings at a fixed seed from ``ctx``.
18 3. Measures per-prompt JS (or KL) divergence between base and ft.
19 4. Aggregates into per-cluster mean KL; computes within- and
20 between-cluster variance.
21 5. Reports a **specificity ratio**::
22
23 specificity = between / (between + within)
24
25 Ranges in ``[0.0, 1.0]``. A topic-specific adapter produces
26 different mean KLs across clusters (high between, low within) so
27 specificity → 1. A blunt-instrument adapter shifts every cluster
28 equally (low between, whatever within) so specificity → 0.5
29 (because within = between in expectation under a null noise
30 distribution).
31
32 Category: ``adherence``. Complements ``delta_kl`` — the pair of
33 ``(mean_kl, specificity)`` tells a more honest story than either
34 number alone. Needs ``sentence-transformers + scikit-learn`` via
35 the ``[semsim]`` extra; SKIPs with a clear install hint otherwise.
36 """
37
38 from __future__ import annotations
39
40 import statistics
41 from typing import TYPE_CHECKING, Any, Literal
42
43 from pydantic import Field
44
45 from dlm_sway.core.errors import BackendNotAvailableError
46 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
47 from dlm_sway.probes._divergence import Divergence, divergence
48 from dlm_sway.probes._zscore import (
49 no_calibration_note,
50 score_from_z,
51 verdict_from_z,
52 z_score,
53 z_scores_by_rank,
54 )
55 from dlm_sway.probes.adapter_revert import _load_embedder
56 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
57 from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank
58
59 if TYPE_CHECKING:
60 import numpy as np
61 from numpy.typing import NDArray
62
63
class ClusterKLSpec(ProbeSpec):
    """Spec for ``kind: cluster_kl``.

    Configures the prompt set, the clustering granularity, and the
    pass criteria for the specificity ratio described in the module
    docstring. All thresholds have defaults tuned for the typical
    30–100 prompt suite.
    """

    kind: Literal["cluster_kl"] = "cluster_kl"
    prompts: list[str] = Field(default_factory=list)
    """Prompt set — embedded + clustered. ``min_prompts`` floor applies."""
    num_clusters: int = Field(default=5, ge=2, le=32)
    """``k`` for k-means. Below 2 there's nothing to cluster; above 32
    the per-cluster size falls below what a specificity ratio can
    meaningfully resolve on the typical 30–100 prompt set."""
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    """Shared-cache key with :mod:`adapter_revert`. MiniLM-L6 is the
    default because it's tiny (~80 MB), CPU-fast, and sufficient for
    topic-level clustering. Users with specialty corpora can swap in
    a domain-tuned model."""
    divergence: Divergence = "js"
    """Per-prompt divergence to aggregate. ``js`` is bounded ``[0,
    ln 2]`` so specificity ratios stay interpretable; ``kl`` is
    unbounded and can push the ratio around when a single outlier
    prompt has a large divergence."""
    #: Truncation for the per-prompt next-token distributions; ``None``
    #: defers to ``ctx.top_k`` at run time (see ``ClusterKLProbe.run``).
    top_k: int | None = None
    min_prompts: int = Field(default=20, ge=4)
    """Floor below which k-means on MiniLM embeddings is too noisy
    for a meaningful specificity ratio. Probe SKIPs below this with
    a clear message."""
    assert_specificity_gte: float = 0.6
    """Fallback threshold when no null stats are available. Typical
    numbers: ``≈ 0.5`` on a null adapter (noise is topic-agnostic),
    ``≈ 0.7–0.9`` on a well-targeted adapter."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null baseline."""
95
96
class ClusterKLProbe(Probe):
    """F8 ClusterKL — distribution-shift specificity via clustered KL.

    Embeds and k-means-clusters the prompt set, measures per-prompt
    base→finetuned divergence, then reports how much of the divergence
    variance lies *between* clusters (topic-specific shift) versus
    *within* them (indiscriminate shift). See the module docstring for
    the specificity metric and its expected ranges.
    """

    kind = "cluster_kl"
    spec_cls = ClusterKLSpec
    category = "adherence"
    #: S23 — same shape as delta_kl (uniform next_token_dist over a
    #: prompt list). Batched path drops HF wall time on the 8-prompt
    #: calibration pass from 8× single-forward to 1× 8-sample forward.
    batch_score = True

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ClusterKLSpec | None:
        """Return the fixed spec used for null-adapter calibration.

        ``ctx`` is deliberately unused: the sentinel prompts are a
        constant so every calibration run sees identical inputs.
        """
        # Null calibration path: synthesize 8 mixed-topic sentinel
        # prompts + k=2. The null's specificity distribution should
        # concentrate around 0.5 (noise is topic-agnostic); the
        # runner's z-score machinery catches the adapter's separation
        # above that baseline.
        del ctx
        return ClusterKLSpec(
            name="_calibration",
            kind="cluster_kl",
            # Two clearly separable topics (animals vs programming) so
            # k=2 clustering has an unambiguous partition to find.
            prompts=[
                "The cat chased the mouse.",
                "Write a Python decorator that logs calls.",
                "Dogs are loyal companions.",
                "Implement binary search in Rust.",
                "Horses gallop across the plains.",
                "Debug a segfault in C++ pointer arithmetic.",
                "Elephants never forget a face.",
                "Explain ownership in Rust.",
            ],
            num_clusters=2,
            min_prompts=8,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Execute the probe.

        Pipeline: validate prompt counts → embed + k-means cluster →
        one batched divergence pass per model view → per-cluster
        within/between variance → specificity ratio + bootstrap CI →
        verdict (z-score vs. null when calibration stats exist, fixed
        ``assert_specificity_gte`` threshold otherwise).

        Verdicts: ERROR on an empty prompt list; SKIP when the prompt
        set is too small for the requested clustering or the
        ``[semsim]`` extras are missing; WARN on the degenerate
        zero-variance case; PASS/FAIL from the z-score or threshold
        comparison otherwise.
        """
        assert isinstance(spec, ClusterKLSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        if len(spec.prompts) < spec.min_prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥{spec.min_prompts} prompts for stable clustering; "
                    f"got {len(spec.prompts)}"
                ),
            )
        # Require at least 2 prompts per cluster on average; below that
        # the per-cluster means carry too little signal.
        if spec.num_clusters * 2 > len(spec.prompts):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"num_clusters={spec.num_clusters} with only {len(spec.prompts)} "
                    f"prompts: per-cluster mean isn't well-resolved. Reduce "
                    f"num_clusters or add prompts."
                ),
            )

        # Embed + cluster. Both need [semsim]; surface one SKIP verdict
        # with a pip hint if either extra is missing.
        try:
            embeddings = _embed_prompts(spec.prompts, spec.embedding_model)
            labels = _kmeans_cluster(embeddings, k=spec.num_clusters, seed=ctx.seed)
        except BackendNotAvailableError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )
        except ImportError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"cluster_kl needs sentence-transformers + scikit-learn "
                    f"(pip install 'dlm-sway[semsim]'): {exc}"
                ),
            )

        # S23 — per-prompt divergences, now via one batched forward
        # per view (same math as ``delta_kl``).
        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
        with ctx.require_backend.as_base() as base_view:
            base_dists = base_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
        with ctx.require_backend.as_finetuned() as ft_view:
            ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
        divergences: list[float] = [
            divergence(b, f, kind=spec.divergence)
            for b, f in zip(base_dists, ft_dists, strict=True)
        ]

        # Aggregate per-cluster means + variances. A cluster that
        # ended up empty (can happen with k-means when an initial
        # centroid lands in a dead zone) contributes nothing.
        buckets: dict[int, list[float]] = {i: [] for i in range(spec.num_clusters)}
        for lab, d in zip(labels, divergences, strict=True):
            buckets[int(lab)].append(d)
        cluster_means: list[float] = []
        within_variances: list[float] = []
        per_cluster_size: list[int] = []
        per_cluster_mean_kl: list[float] = []
        cluster_exemplars: list[list[str]] = []
        for i in range(spec.num_clusters):
            vals = buckets[i]
            per_cluster_size.append(len(vals))
            if not vals:
                # Empty cluster: NaN marks "no data" in the evidence
                # without skewing cluster_means / within_variances.
                per_cluster_mean_kl.append(float("nan"))
                cluster_exemplars.append([])
                continue
            mean_c = statistics.fmean(vals)
            cluster_means.append(mean_c)
            per_cluster_mean_kl.append(mean_c)
            if len(vals) > 1:
                within_variances.append(statistics.pvariance(vals))
            # Exemplars — first 2 prompts in cluster order.
            exemplar_idxs = [j for j, lab in enumerate(labels) if int(lab) == i][:2]
            cluster_exemplars.append([spec.prompts[j][:120] for j in exemplar_idxs])

        between_variance = statistics.pvariance(cluster_means) if len(cluster_means) >= 2 else 0.0
        within_variance = statistics.fmean(within_variances) if within_variances else 0.0
        denom = between_variance + within_variance
        # Degenerate zero-variance case: if every prompt produced the
        # same divergence (cache hit on stub data, no adapter motion,
        # etc.), specificity is mathematically undefined. Convention:
        # 0.5 — the null-adapter expectation — so downstream z-score
        # path reports "no signal" rather than a runtime NaN.
        degenerate = denom <= 0.0
        specificity = between_variance / denom if not degenerate else 0.5
        mean_kl = statistics.fmean(divergences)

        # Bootstrap CI on specificity: resample per-prompt
        # (divergence, label) pairs and recompute the ratio.
        ci_95 = _bootstrap_specificity(divergences, labels, ctx.seed, spec.num_clusters)

        # F17 — the degenerate fallback short-circuits the z-score
        # path. Comparing a conventional 0.5 to a null whose mean is
        # marginally off-center (small-N sampling noise) would produce
        # a spurious non-zero z that downstream consumers might act
        # on. Force "no signal" semantics with a single WARN verdict.
        if degenerate:
            message = (
                f"specificity=0.50 (k={spec.num_clusters}) — degenerate: "
                f"zero within/between variance (no per-prompt spread)"
            )
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                score=0.0,
                raw=specificity,
                z_score=None,
                evidence={
                    "num_clusters": spec.num_clusters,
                    "num_prompts": len(spec.prompts),
                    "divergence_kind": spec.divergence,
                    "mean_kl": mean_kl,
                    "within_cluster_variance": within_variance,
                    "between_cluster_variance": between_variance,
                    "per_cluster_size": per_cluster_size,
                    "per_cluster_mean_kl": per_cluster_mean_kl,
                    "cluster_exemplars": cluster_exemplars,
                    "weight": spec.weight,
                    "degenerate_zero_variance": True,
                    "raw_ci_95": list(ci_95) if ci_95 is not None else None,
                },
                message=message,
                ci_95=ci_95,
            )

        # Null calibration: specificity on noise ≈ 0.5 ± small.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(specificity, stats)
        z_by_rank = z_scores_by_rank(specificity, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            # Calibrated path: verdict + score both derive from z.
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"specificity={specificity:.2f} (k={spec.num_clusters}), "
                f"mean_kl={mean_kl:.3f}, z={z:+.2f}σ vs null"
            )
        else:
            # Uncalibrated fallback: fixed threshold on specificity.
            verdict = Verdict.PASS if specificity >= spec.assert_specificity_gte else Verdict.FAIL
            score = max(0.0, min(1.0, (specificity - 0.5) * 2.0))  # map [0.5, 1.0] → [0, 1]
            message = (
                f"specificity={specificity:.2f} (k={spec.num_clusters}), "
                f"mean_kl={mean_kl:.3f} {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=specificity,
            z_score=z,
            evidence={
                "num_clusters": spec.num_clusters,
                "num_prompts": len(spec.prompts),
                "divergence_kind": spec.divergence,
                "mean_kl": mean_kl,
                "within_cluster_variance": within_variance,
                "between_cluster_variance": between_variance,
                "per_cluster_size": per_cluster_size,
                "per_cluster_mean_kl": per_cluster_mean_kl,
                "cluster_exemplars": cluster_exemplars,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )
327
328
329 # ----------------------------------------------------------------------
330 # helpers
331 # ----------------------------------------------------------------------
332
333
def _embed_prompts(prompts: list[str], model_id: str) -> NDArray[np.float32]:
    """Embed *prompts* with the suite-shared MiniLM loader.

    Goes through :func:`adapter_revert._load_embedder` so one suite run
    pays the ~80 MB model load at most once across probes.
    """
    import numpy as np

    encoder = _load_embedder(model_id)
    # sklearn's happy path wants a float32 ndarray, so coerce whatever
    # array type ``SentenceTransformer.encode`` handed back.
    return np.asarray(encoder(prompts), dtype=np.float32)
344
345
def _kmeans_cluster(embeddings: NDArray[np.float32], *, k: int, seed: int) -> NDArray[np.int64]:
    """Assign each embedding row to one of ``k`` k-means clusters.

    Seeded for reproducibility; ``n_init=10`` keeps the centroid
    initialization variance in check at negligible cost for the
    tiny ``k`` values this probe works with.
    """
    try:
        from sklearn.cluster import KMeans
    except ImportError as exc:
        raise BackendNotAvailableError(
            "cluster_kl",
            extra="semsim",
            hint="cluster_kl needs scikit-learn for k-means clustering.",
        ) from exc
    import numpy as np

    assignments = KMeans(n_clusters=k, random_state=seed, n_init=10).fit_predict(embeddings)
    return np.asarray(assignments, dtype=np.int64)
362
363
364 def _bootstrap_specificity(
365 divergences: list[float],
366 labels: Any,
367 seed: int,
368 num_clusters: int,
369 *,
370 n_bootstrap: int = 1000,
371 ) -> tuple[float, float] | None:
372 """Percentile-bootstrap CI on the specificity ratio.
373
374 Resamples ``(divergence, label)`` pairs with replacement, recomputes
375 the ratio, takes the 2.5/97.5 percentiles of the bootstrap
376 distribution. Returns ``None`` when the input is too small to
377 resample meaningfully (< 4 prompts) — matches the convention the
378 other aggregating probes use via :func:`core.stats.bootstrap_ci`.
379 """
380 import numpy as np
381
382 n = len(divergences)
383 if n < 4:
384 return None
385 divs_arr = np.asarray(divergences, dtype=np.float64)
386 labs_arr = np.asarray(labels, dtype=np.int64)
387 rng = np.random.default_rng(seed)
388 ratios: list[float] = []
389 for _ in range(n_bootstrap):
390 idx = rng.integers(0, n, size=n)
391 d = divs_arr[idx]
392 lb = labs_arr[idx]
393 cluster_means: list[float] = []
394 within_vars: list[float] = []
395 for i in range(num_clusters):
396 mask = lb == i
397 if not mask.any():
398 continue
399 vals = d[mask]
400 cluster_means.append(float(vals.mean()))
401 if vals.size > 1:
402 within_vars.append(float(vals.var()))
403 if len(cluster_means) < 2:
404 ratios.append(0.5)
405 continue
406 between = float(np.var(cluster_means))
407 within = float(np.mean(within_vars)) if within_vars else 0.0
408 denom = between + within
409 ratios.append(between / denom if denom > 0 else 0.5)
410 if not ratios:
411 return None
412 lo, hi = np.percentile(ratios, [2.5, 97.5])
413 return (float(lo), float(hi))
414
415
416 __all__ = ["ClusterKLProbe", "ClusterKLSpec"]