| 1 | """Shared math for divergence-based probes. |
| 2 | |
| 3 | Extracted so :mod:`delta_kl`, :mod:`adapter_ablation`, and any future |
| 4 | probe operating on next-token distributions reuse the same aligned- |
| 5 | top-k KL / JS computation. Having one implementation keeps the numerical |
| 6 | treatment consistent across the report. |
| 7 | |
| 8 | **Non-finite policy** (S01): every entry point in this module rejects |
| 9 | non-finite inputs *explicitly* by raising :class:`ProbeError`. The |
| 10 | historical bug — ``np.exp(nan) = nan`` flowing past a ``p > 0`` mask |
| 11 | that evaluates ``nan > 0`` as False — produced a mathematically |
| 12 | impossible JS divergence of 13.247 nats (bounded by ln 2 ≈ 0.693). |
| 13 | We refuse to compute on non-finite inputs rather than silently produce |
| 14 | garbage; the calling probe routes the resulting :class:`ProbeError` to |
| 15 | :attr:`Verdict.ERROR` via :func:`safe_finalize`. |
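
The failure mode in one line (IEEE-754 comparison semantics, the same
rule NumPy applies elementwise; shown for illustration only)::

    >>> float("nan") > 0   # the mask drops the entry instead of flagging it
    False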
"""

from __future__ import annotations

import math
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from dlm_sway.core.errors import ProbeError
from dlm_sway.core.scoring import TokenDist

Divergence = Literal["kl", "js"]


def _check_finite_token_dist(name: str, dist: TokenDist) -> None:
    """Reject a TokenDist whose logprobs contain NaN/inf.

    Called at the entry to ``aligned_probs``. The error message names the
    side (base / ft) so a probe failure pinpoints the broken model.
    """
    if not np.all(np.isfinite(dist.logprobs)):
        n_bad = int(np.sum(~np.isfinite(dist.logprobs)))
        raise ProbeError(
            "divergence",
            f"{name} TokenDist contains {n_bad} non-finite logprob(s) — "
            f"refusing to compute divergence on a model that produces "
            f"NaN/inf logits",
        )


# Tolerance for "effectively uniform." Real models never return
# bit-identical logits across the top-k — fp32 accumulation noise
# alone produces spreads in the 1e-5 range. We reject only distributions
# that are suspiciously exact: spread below 1e-9 strongly implies a
# broken lm_head or a test fixture that zero-fills logits.
_UNIFORM_LOGPROB_TOL: float = 1e-9


def _check_non_degenerate_token_dist(name: str, dist: TokenDist) -> None:
    """Reject a TokenDist whose top-k logprobs are all (effectively) equal.

    A well-formed next-token distribution from a real model has a
    peaked top-k; perfectly uniform logprobs mean either the lm_head
    broke or upstream sampling code clobbered the logits. The
    divergence math is still defined (KL(uniform ∥ uniform) = 0) but
    the resulting probe value would be a constant across prompts,
    producing a meaningless ``delta_kl``. Surface the broken model
    explicitly rather than letting it leak a false zero.

    Called after the finite check: with NaN logprobs the spread itself
    is NaN, and ``nan < tol`` evaluates to False, so a non-finite top-k
    would silently pass this guard; the finite check must report it
    explicitly first.
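
    Illustration (hypothetical top-4 logprobs, not from a real model)::

        logprobs = [-1.3863, -1.3863, -1.3863, -1.3863]   # ln(1/4) everywhere
        spread   = 0.0   ->   below _UNIFORM_LOGPROB_TOL   ->   ProbeError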
    """
    if dist.logprobs.size < 2:
        return
    spread = float(dist.logprobs.max() - dist.logprobs.min())
    if spread < _UNIFORM_LOGPROB_TOL:
        raise ProbeError(
            "divergence",
            f"{name} TokenDist has effectively-uniform top-{dist.logprobs.size} "
            f"logprobs (spread={spread:.2e}) — divergence on a degenerate "
            f"distribution would return a trivial constant; refusing to "
            f"proceed. Check the backend's lm_head and logits pipeline.",
        )


def aligned_probs(
    base: TokenDist, ft: TokenDist
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Return aligned probability vectors over the union of top-k tokens.

    Two ``TokenDist`` objects may surface different top-k indices if
    the two models disagree about the hot tokens. We build a shared
    support — ``union(base.token_ids, ft.token_ids)`` — and slot the
    known probabilities in. Unknown entries fall back to the
    per-distribution tail mass divided across the missing tokens,
    which is the maximum-entropy completion under the truncation.

    Raises :class:`ProbeError` when either input contains non-finite
    logprobs.
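
    Worked sketch (hypothetical top-2 distributions, rounded numbers)::

        base top-2: {7: 0.6, 12: 0.3}     tail mass 0.1
        ft   top-2: {7: 0.5, 99: 0.4}     tail mass 0.1
        union support: [7, 12, 99]
        base aligned:  [0.6, 0.3, 0.1]    (tail fills the unseen 99)
        ft   aligned:  [0.5, 0.1, 0.4]    (tail fills the unseen 12)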
    """
    _check_finite_token_dist("base", base)
    _check_finite_token_dist("ft", ft)
    _check_non_degenerate_token_dist("base", base)
    _check_non_degenerate_token_dist("ft", ft)

    union_ids = np.union1d(base.token_ids, ft.token_ids)
    k = int(union_ids.size)

    base_probs = _to_support(base, union_ids, k)
    ft_probs = _to_support(ft, union_ids, k)

    # Normalize in case of floating noise from the fill-in.
    base_total = float(base_probs.sum())
    ft_total = float(ft_probs.sum())
    if not (math.isfinite(base_total) and base_total > 0):
        raise ProbeError(
            "divergence", f"base distribution sums to {base_total} after support alignment"
        )
    if not (math.isfinite(ft_total) and ft_total > 0):
        raise ProbeError(
            "divergence", f"ft distribution sums to {ft_total} after support alignment"
        )
    base_probs /= base_total
    ft_probs /= ft_total
    return base_probs, ft_probs


def _to_support(dist: TokenDist, support: NDArray[np.int64], k: int) -> NDArray[np.float64]:
    """Project ``dist`` onto ``support`` (length ``k``): known tokens keep
    their probabilities, the remaining slots share the tail mass evenly."""
    probs = np.exp(dist.logprobs.astype(np.float64))
    out = np.zeros(k, dtype=np.float64)
    known_mass = float(probs.sum())
    tail_mass = max(0.0, 1.0 - known_mass)

    id_to_idx = {int(tok): idx for idx, tok in enumerate(support.tolist())}
    for tok, p in zip(dist.token_ids.tolist(), probs.tolist(), strict=True):
        i = id_to_idx.get(int(tok))
        if i is None:
            # Shouldn't happen given union construction.
            continue
        out[i] = float(p)

    # Spread the tail mass over the support entries that this dist
    # doesn't explicitly provide; those are exactly the entries still
    # at zero.
    n_unknown = int((out == 0.0).sum())
    if n_unknown > 0 and tail_mass > 0.0:
        per = tail_mass / n_unknown
        out[out == 0.0] = per

    return out


def _check_finite_array(name: str, arr: NDArray[np.float64]) -> None:
    """Reject an array containing NaN/inf."""
    if not np.all(np.isfinite(arr)):
        n_bad = int(np.sum(~np.isfinite(arr)))
        raise ProbeError(
            "divergence",
            f"{name} contains {n_bad} non-finite entry/entries — refusing to compute divergence",
        )


def kl(p: NDArray[np.float64], q: NDArray[np.float64]) -> float:
    """KL(p || q) in nats. Zeros in p contribute nothing (0·log 0 := 0);
    zeros in q are floored at 1e-12 so the log stays finite.

    Raises :class:`ProbeError` on non-finite inputs or a non-finite result.
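
    Worked example (exact arithmetic, for orientation only)::

        KL([0.5, 0.5] || [0.9, 0.1])
            = 0.5·ln(0.5/0.9) + 0.5·ln(0.5/0.1)
            ≈ 0.511 nats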
    """
    _check_finite_array("p", p)
    _check_finite_array("q", q)
    mask = p > 0.0
    safe_q = np.where(q > 0.0, q, 1e-12)
    result = float(np.sum(p[mask] * (np.log(p[mask]) - np.log(safe_q[mask]))))
    if not math.isfinite(result):
        raise ProbeError("divergence", f"kl computation produced non-finite result: {result}")
    return result


def js(p: NDArray[np.float64], q: NDArray[np.float64]) -> float:
    """Jensen-Shannon divergence. Symmetric, bounded in [0, ln 2] (nats).

    The upper bound makes JS a nicer default for thresholding than raw
    KL — a user doesn't need to know their specific model's KL scale to
    pick a threshold.

    Raises :class:`ProbeError` on non-finite inputs or non-finite output.
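
    Bound check, worked by hand (m is the 50/50 mixture of p and q)::

        js([1, 0], [0, 1]) = 0.5·KL([1,0] || [.5,.5]) + 0.5·KL([0,1] || [.5,.5])
                           = 0.5·ln 2 + 0.5·ln 2 = ln 2    (maximal disagreement)
        js(p, p)           = 0.0                           (identical distributions)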
    """
    _check_finite_array("p", p)
    _check_finite_array("q", q)
    m = 0.5 * (p + q)
    result = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    # Defense-in-depth: clamp into the theoretical bound JS ∈ [0, ln 2].
    # Tiny negative values (p ≈ q) or values a hair above ln 2
    # (near-disjoint supports) are FP roundoff and get clamped;
    # anything larger indicates upstream numerical drift, so surface
    # it as a ProbeError rather than silently exceeding the bound.
    tol = 1e-9
    ln2 = math.log(2.0)
    if result < -tol or result > ln2 + tol:
        raise ProbeError(
            "divergence",
            f"js computed {result:.4f} nats, outside theoretical bound "
            f"[0, {ln2:.4f}] — likely numerical drift in upstream "
            f"distributions",
        )
    return min(max(0.0, result), ln2)


def divergence(base: TokenDist, ft: TokenDist, kind: Divergence = "js") -> float:
    """Compute KL or JS between two ``TokenDist`` on a shared support."""
    p, q = aligned_probs(base, ft)
    if kind == "js":
        return js(p, q)
    if kind == "kl":
        return kl(q, p)  # KL(ft || base) — "how much does ft diverge from base"
    raise ValueError(f"unknown divergence kind: {kind!r}")
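

# Illustrative usage sketch (comments only, not executed): the probes pass
# real ``TokenDist`` objects obtained from the base and fine-tuned backends.
# The constructor call below *assumes* ``token_ids`` / ``logprobs`` keyword
# fields and uses hypothetical values; it only shows the expected call shape.
#
#     base = TokenDist(token_ids=np.array([11, 42, 7]),
#                      logprobs=np.log(np.array([0.60, 0.25, 0.10])))
#     ft   = TokenDist(token_ids=np.array([42, 11, 99]),
#                      logprobs=np.log(np.array([0.50, 0.20, 0.20])))
#     divergence(base, ft, kind="js")   # symmetric, in [0, ln 2]
#     divergence(base, ft, kind="kl")   # KL(ft || base), unbounded above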


def js_ln2() -> float:
    """Upper bound on JS in nats. Useful for normalization."""
    return math.log(2.0)