| 1 | """F8 ClusterKL — distribution-shift specificity via clustered KL. |
| 2 | |
| 3 | ``delta_kl`` answers "how much did the adapter move the model, on |
| 4 | average?" — a global claim. ``cluster_kl`` asks the sharper |
| 5 | follow-up: "did the adapter move the *right* topics?" |
| 6 | |
| 7 | A great adapter trained on a cooking document should shift the |
| 8 | model heavily on cooking prompts and leave chemistry prompts alone. |
| 9 | A bad adapter shifts everything by a similar amount — broad-stroke |
| 10 | over-fitting. ``delta_kl``'s mean-across-prompts number can't tell |
| 11 | them apart. ``cluster_kl`` can. |
| 12 | |
| 13 | The probe: |
| 14 | |
| 15 | 1. Embeds every prompt via MiniLM (shared cache with |
| 16 | :mod:`adapter_revert`). |
| 17 | 2. K-means-clusters the embeddings at a fixed seed from ``ctx``. |
| 18 | 3. Measures per-prompt JS (or KL) divergence between base and ft. |
| 19 | 4. Aggregates into per-cluster mean KL; computes within- and |
| 20 | between-cluster variance. |
| 21 | 5. Reports a **specificity ratio**:: |
| 22 | |
| 23 | specificity = between / (between + within) |
| 24 | |
| 25 | Ranges in ``[0.0, 1.0]``. A topic-specific adapter produces |
| 26 | different mean KLs across clusters (high between, low within) so |
| 27 | specificity → 1. A blunt-instrument adapter shifts every cluster |
| 28 | equally (low between, whatever within) so specificity → 0.5 |
| 29 | (because within = between in expectation under a null noise |
| 30 | distribution). |
| 31 | |
| 32 | Category: ``adherence``. Complements ``delta_kl`` — the pair of |
| 33 | ``(mean_kl, specificity)`` tells a more honest story than either |
| 34 | number alone. Needs ``sentence-transformers + scikit-learn`` via |
| 35 | the ``[semsim]`` extra; SKIPs with a clear install hint otherwise. |
| 36 | """ |
| 37 | |
| 38 | from __future__ import annotations |
| 39 | |
| 40 | import statistics |
| 41 | from typing import TYPE_CHECKING, Any, Literal |
| 42 | |
| 43 | from pydantic import Field |
| 44 | |
| 45 | from dlm_sway.core.errors import BackendNotAvailableError |
| 46 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 47 | from dlm_sway.probes._divergence import Divergence, divergence |
| 48 | from dlm_sway.probes._zscore import ( |
| 49 | no_calibration_note, |
| 50 | score_from_z, |
| 51 | verdict_from_z, |
| 52 | z_score, |
| 53 | z_scores_by_rank, |
| 54 | ) |
| 55 | from dlm_sway.probes.adapter_revert import _load_embedder |
| 56 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 57 | from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank |
| 58 | |
| 59 | if TYPE_CHECKING: |
| 60 | import numpy as np |
| 61 | from numpy.typing import NDArray |
| 62 | |
| 63 | |
class ClusterKLSpec(ProbeSpec):
    """Spec for ``kind: cluster_kl``.

    Configures the prompt set, the k-means parameters, the per-prompt
    divergence kind, and the pass criteria (z-score or fallback
    threshold) for the specificity-ratio probe.
    """

    kind: Literal["cluster_kl"] = "cluster_kl"
    prompts: list[str] = Field(default_factory=list)
    """Prompt set — embedded + clustered. ``min_prompts`` floor applies."""
    num_clusters: int = Field(default=5, ge=2, le=32)
    """``k`` for k-means. Below 2 there's nothing to cluster; above 32
    the per-cluster size falls below what a specificity ratio can
    meaningfully resolve on the typical 30–100 prompt set."""
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    """Shared-cache key with :mod:`adapter_revert`. MiniLM-L6 is the
    default because it's tiny (~80 MB), CPU-fast, and sufficient for
    topic-level clustering. Users with specialty corpora can swap in
    a domain-tuned model."""
    divergence: Divergence = "js"
    """Per-prompt divergence to aggregate. ``js`` is bounded ``[0,
    ln 2]`` so specificity ratios stay interpretable; ``kl`` is
    unbounded and can push the ratio around when a single outlier
    prompt has a large divergence."""
    # Vocabulary truncation for the next-token distributions; ``None``
    # defers to ``ctx.top_k`` at run time.
    top_k: int | None = None
    min_prompts: int = Field(default=20, ge=4)
    """Floor below which k-means on MiniLM embeddings is too noisy
    for a meaningful specificity ratio. Probe SKIPs below this with
    a clear message."""
    assert_specificity_gte: float = 0.6
    """Fallback threshold when no null stats are available. Typical
    numbers: ``≈ 0.5`` on a null adapter (noise is topic-agnostic),
    ``≈ 0.7–0.9`` on a well-targeted adapter."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null baseline."""
| 95 | |
| 96 | |
class ClusterKLProbe(Probe):
    """F8 ClusterKL — distribution-shift specificity via clustered KL."""

    kind = "cluster_kl"
    spec_cls = ClusterKLSpec
    category = "adherence"
    #: S23 — same shape as delta_kl (uniform next_token_dist over a
    #: prompt list). Batched path drops HF wall time on the 8-prompt
    #: calibration pass from 8× single-forward to 1× 8-sample forward.
    batch_score = True

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ClusterKLSpec | None:
        """Return a fixed 8-prompt, k=2 spec for the null-calibration pass.

        ``ctx`` is accepted for interface uniformity but unused — the
        calibration prompt set is deliberately static.
        """
        # Null calibration path: synthesize 8 mixed-topic sentinel
        # prompts + k=2. The null's specificity distribution should
        # concentrate around 0.5 (noise is topic-agnostic); the
        # runner's z-score machinery catches the adapter's separation
        # above that baseline.
        del ctx
        return ClusterKLSpec(
            name="_calibration",
            kind="cluster_kl",
            prompts=[
                "The cat chased the mouse.",
                "Write a Python decorator that logs calls.",
                "Dogs are loyal companions.",
                "Implement binary search in Rust.",
                "Horses gallop across the plains.",
                "Debug a segfault in C++ pointer arithmetic.",
                "Elephants never forget a face.",
                "Explain ownership in Rust.",
            ],
            num_clusters=2,
            min_prompts=8,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Embed + cluster prompts, measure per-prompt divergence, and
        score the between/within specificity ratio.

        Verdict ladder: ERROR (no prompts) → SKIP (too few prompts,
        too many clusters for the prompt count, or missing ``[semsim]``
        extras) → WARN (degenerate zero-variance case) → PASS/FAIL via
        the z-score path when null stats exist, else via the
        ``assert_specificity_gte`` fallback threshold.
        """
        assert isinstance(spec, ClusterKLSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        if len(spec.prompts) < spec.min_prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"need ≥{spec.min_prompts} prompts for stable clustering; "
                    f"got {len(spec.prompts)}"
                ),
            )
        # Require at least 2 prompts per cluster on average.
        if spec.num_clusters * 2 > len(spec.prompts):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"num_clusters={spec.num_clusters} with only {len(spec.prompts)} "
                    f"prompts: per-cluster mean isn't well-resolved. Reduce "
                    f"num_clusters or add prompts."
                ),
            )

        # Embed + cluster. Both need [semsim]; surface one SKIP verdict
        # with a pip hint if either extra is missing.
        try:
            embeddings = _embed_prompts(spec.prompts, spec.embedding_model)
            labels = _kmeans_cluster(embeddings, k=spec.num_clusters, seed=ctx.seed)
        except BackendNotAvailableError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=str(exc),
            )
        except ImportError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    f"cluster_kl needs sentence-transformers + scikit-learn "
                    f"(pip install 'dlm-sway[semsim]'): {exc}"
                ),
            )

        # S23 — per-prompt divergences, now via one batched forward
        # per view (same math as ``delta_kl``).
        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
        with ctx.require_backend.as_base() as base_view:
            base_dists = base_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
        with ctx.require_backend.as_finetuned() as ft_view:
            ft_dists = ft_view.next_token_dist_batch(list(spec.prompts), top_k=top_k)
        divergences: list[float] = [
            divergence(b, f, kind=spec.divergence)
            for b, f in zip(base_dists, ft_dists, strict=True)
        ]

        # Aggregate per-cluster means + variances. A cluster that
        # ended up empty (can happen with k-means when an initial
        # centroid lands in a dead zone) contributes nothing.
        buckets: dict[int, list[float]] = {i: [] for i in range(spec.num_clusters)}
        for lab, d in zip(labels, divergences, strict=True):
            buckets[int(lab)].append(d)
        cluster_means: list[float] = []
        within_variances: list[float] = []
        per_cluster_size: list[int] = []
        per_cluster_mean_kl: list[float] = []
        cluster_exemplars: list[list[str]] = []
        for i in range(spec.num_clusters):
            vals = buckets[i]
            per_cluster_size.append(len(vals))
            if not vals:
                per_cluster_mean_kl.append(float("nan"))
                cluster_exemplars.append([])
                continue
            mean_c = statistics.fmean(vals)
            cluster_means.append(mean_c)
            per_cluster_mean_kl.append(mean_c)
            if len(vals) > 1:
                within_variances.append(statistics.pvariance(vals))
            # Exemplars — first 2 prompts in cluster order.
            exemplar_idxs = [j for j, lab in enumerate(labels) if int(lab) == i][:2]
            cluster_exemplars.append([spec.prompts[j][:120] for j in exemplar_idxs])

        between_variance = statistics.pvariance(cluster_means) if len(cluster_means) >= 2 else 0.0
        within_variance = statistics.fmean(within_variances) if within_variances else 0.0
        denom = between_variance + within_variance
        # Degenerate zero-variance case: if every prompt produced the
        # same divergence (cache hit on stub data, no adapter motion,
        # etc.), specificity is mathematically undefined. Convention:
        # 0.5 — the null-adapter expectation — so downstream z-score
        # path reports "no signal" rather than a runtime NaN.
        degenerate = denom <= 0.0
        specificity = between_variance / denom if not degenerate else 0.5
        mean_kl = statistics.fmean(divergences)

        # Bootstrap CI on specificity: resample per-prompt
        # (divergence, label) pairs and recompute the ratio.
        ci_95 = _bootstrap_specificity(divergences, labels, ctx.seed, spec.num_clusters)

        # F17 — the degenerate fallback short-circuits the z-score
        # path. Comparing a conventional 0.5 to a null whose mean is
        # marginally off-center (small-N sampling noise) would produce
        # a spurious non-zero z that downstream consumers might act
        # on. Force "no signal" semantics with a single WARN verdict.
        if degenerate:
            message = (
                f"specificity=0.50 (k={spec.num_clusters}) — degenerate: "
                f"zero within/between variance (no per-prompt spread)"
            )
            return safe_finalize(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.WARN,
                score=0.0,
                raw=specificity,
                z_score=None,
                evidence={
                    "num_clusters": spec.num_clusters,
                    "num_prompts": len(spec.prompts),
                    "divergence_kind": spec.divergence,
                    "mean_kl": mean_kl,
                    "within_cluster_variance": within_variance,
                    "between_cluster_variance": between_variance,
                    "per_cluster_size": per_cluster_size,
                    "per_cluster_mean_kl": per_cluster_mean_kl,
                    "cluster_exemplars": cluster_exemplars,
                    "weight": spec.weight,
                    "degenerate_zero_variance": True,
                    "raw_ci_95": list(ci_95) if ci_95 is not None else None,
                },
                message=message,
                ci_95=ci_95,
            )

        # Null calibration: specificity on noise ≈ 0.5 ± small.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(specificity, stats)
        z_by_rank = z_scores_by_rank(specificity, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"specificity={specificity:.2f} (k={spec.num_clusters}), "
                f"mean_kl={mean_kl:.3f}, z={z:+.2f}σ vs null"
            )
        else:
            verdict = Verdict.PASS if specificity >= spec.assert_specificity_gte else Verdict.FAIL
            score = max(0.0, min(1.0, (specificity - 0.5) * 2.0))  # map [0.5, 1.0] → [0, 1]
            message = (
                f"specificity={specificity:.2f} (k={spec.num_clusters}), "
                f"mean_kl={mean_kl:.3f} {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=specificity,
            z_score=z,
            evidence={
                "num_clusters": spec.num_clusters,
                "num_prompts": len(spec.prompts),
                "divergence_kind": spec.divergence,
                "mean_kl": mean_kl,
                "within_cluster_variance": within_variance,
                "between_cluster_variance": between_variance,
                "per_cluster_size": per_cluster_size,
                "per_cluster_mean_kl": per_cluster_mean_kl,
                "cluster_exemplars": cluster_exemplars,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )
| 327 | |
| 328 | |
| 329 | # ---------------------------------------------------------------------- |
| 330 | # helpers |
| 331 | # ---------------------------------------------------------------------- |
| 332 | |
| 333 | |
def _embed_prompts(prompts: list[str], model_id: str) -> NDArray[np.float32]:
    """Embed *prompts* through the shared embedder cache.

    Delegates to :func:`adapter_revert._load_embedder` so the ~80 MB
    MiniLM weights are loaded once and reused across probes in a
    single suite run.
    """
    import numpy as np

    encoder = _load_embedder(model_id)
    # ``SentenceTransformer.encode`` returns ``ndarray``; coerce to
    # float32 for downstream sklearn's happy path.
    return np.asarray(encoder(prompts), dtype=np.float32)
| 344 | |
| 345 | |
def _kmeans_cluster(embeddings: NDArray[np.float32], *, k: int, seed: int) -> NDArray[np.int64]:
    """Assign each embedding row to one of ``k`` k-means clusters.

    Seeded for reproducibility; ``n_init=10`` tames centroid
    initialization variance without blowing up runtime at the
    tiny ``k`` values this probe cares about.

    Raises:
        BackendNotAvailableError: when scikit-learn is not installed.
    """
    try:
        from sklearn.cluster import KMeans
    except ImportError as exc:
        raise BackendNotAvailableError(
            "cluster_kl",
            extra="semsim",
            hint="cluster_kl needs scikit-learn for k-means clustering.",
        ) from exc
    import numpy as np

    model = KMeans(n_clusters=k, random_state=seed, n_init=10)
    labels = model.fit_predict(embeddings)
    return np.asarray(labels, dtype=np.int64)
| 362 | |
| 363 | |
| 364 | def _bootstrap_specificity( |
| 365 | divergences: list[float], |
| 366 | labels: Any, |
| 367 | seed: int, |
| 368 | num_clusters: int, |
| 369 | *, |
| 370 | n_bootstrap: int = 1000, |
| 371 | ) -> tuple[float, float] | None: |
| 372 | """Percentile-bootstrap CI on the specificity ratio. |
| 373 | |
| 374 | Resamples ``(divergence, label)`` pairs with replacement, recomputes |
| 375 | the ratio, takes the 2.5/97.5 percentiles of the bootstrap |
| 376 | distribution. Returns ``None`` when the input is too small to |
| 377 | resample meaningfully (< 4 prompts) — matches the convention the |
| 378 | other aggregating probes use via :func:`core.stats.bootstrap_ci`. |
| 379 | """ |
| 380 | import numpy as np |
| 381 | |
| 382 | n = len(divergences) |
| 383 | if n < 4: |
| 384 | return None |
| 385 | divs_arr = np.asarray(divergences, dtype=np.float64) |
| 386 | labs_arr = np.asarray(labels, dtype=np.int64) |
| 387 | rng = np.random.default_rng(seed) |
| 388 | ratios: list[float] = [] |
| 389 | for _ in range(n_bootstrap): |
| 390 | idx = rng.integers(0, n, size=n) |
| 391 | d = divs_arr[idx] |
| 392 | lb = labs_arr[idx] |
| 393 | cluster_means: list[float] = [] |
| 394 | within_vars: list[float] = [] |
| 395 | for i in range(num_clusters): |
| 396 | mask = lb == i |
| 397 | if not mask.any(): |
| 398 | continue |
| 399 | vals = d[mask] |
| 400 | cluster_means.append(float(vals.mean())) |
| 401 | if vals.size > 1: |
| 402 | within_vars.append(float(vals.var())) |
| 403 | if len(cluster_means) < 2: |
| 404 | ratios.append(0.5) |
| 405 | continue |
| 406 | between = float(np.var(cluster_means)) |
| 407 | within = float(np.mean(within_vars)) if within_vars else 0.0 |
| 408 | denom = between + within |
| 409 | ratios.append(between / denom if denom > 0 else 0.5) |
| 410 | if not ratios: |
| 411 | return None |
| 412 | lo, hi = np.percentile(ratios, [2.5, 97.5]) |
| 413 | return (float(lo), float(hi)) |
| 414 | |
| 415 | |
# Public API of this module: the probe class and its spec.
__all__ = ["ClusterKLProbe", "ClusterKLSpec"]