1 """Shared math for divergence-based probes.
2
3 Extracted so :mod:`delta_kl`, :mod:`adapter_ablation`, and any future
4 probe operating on next-token distributions reuse the same aligned-
5 top-k KL / JS computation. Having one implementation keeps the numerical
6 treatment consistent across the report.
7
8 **Non-finite policy** (S01): every entry point in this module rejects
9 non-finite inputs *explicitly* by raising :class:`ProbeError`. The
10 historical bug — ``np.exp(nan) = nan`` flowing past a ``p > 0`` mask
11 that evaluates ``nan > 0`` as False — produced a mathematically
12 impossible JS divergence of 13.247 nats (bounded by ln 2 ≈ 0.693).
13 We refuse to compute on non-finite inputs rather than silently produce
14 garbage; the calling probe routes the resulting :class:`ProbeError` to
15 :attr:`Verdict.ERROR` via :func:`safe_finalize`.
16 """

from __future__ import annotations

import math
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from dlm_sway.core.errors import ProbeError
from dlm_sway.core.scoring import TokenDist

Divergence = Literal["kl", "js"]


def _check_finite_token_dist(name: str, dist: TokenDist) -> None:
    """Reject a TokenDist whose logprobs contain NaN/inf.

    Called at the entry to ``aligned_probs``. The error message names the
    side (base / ft) so a probe failure pinpoints the broken model.
    """
    if not np.all(np.isfinite(dist.logprobs)):
        n_bad = int(np.sum(~np.isfinite(dist.logprobs)))
        raise ProbeError(
            "divergence",
            f"{name} TokenDist contains {n_bad} non-finite logprob(s) — "
            f"refusing to compute divergence on a model that produces "
            f"NaN/inf logits",
        )


# Tolerance for "effectively uniform." Real models never return
# bit-identical logits across the top-k — fp32 accumulation noise
# alone produces spreads in the 1e-5 range. We reject only distributions
# that are suspiciously exact: spread below 1e-9 strongly implies a
# broken lm_head or a test fixture that zero-fills logits.
_UNIFORM_LOGPROB_TOL: float = 1e-9


def _check_non_degenerate_token_dist(name: str, dist: TokenDist) -> None:
    """Reject a TokenDist whose top-k logprobs are all (effectively) equal.

    A well-formed next-token distribution from a real model has a
    peaked top-k; perfectly uniform logprobs mean either the lm_head
    broke or upstream sampling code clobbered the logits. The
    divergence math is still defined (KL(uniform ∥ uniform) = 0) but
    the resulting probe value would be a constant across prompts,
    producing a meaningless ``delta_kl``. Surface the broken model
    explicitly rather than letting it leak a false zero.

    Called after the finite check, and the order matters: with NaN
    logprobs the spread below is itself NaN, the ``spread <
    _UNIFORM_LOGPROB_TOL`` comparison evaluates False, and this guard
    would silently pass a distribution that only the finite check rejects.
    """
    if dist.logprobs.size < 2:
        return
    spread = float(dist.logprobs.max() - dist.logprobs.min())
    if spread < _UNIFORM_LOGPROB_TOL:
        raise ProbeError(
            "divergence",
            f"{name} TokenDist has effectively-uniform top-{dist.logprobs.size} "
            f"logprobs (spread={spread:.2e}) — divergence on a degenerate "
            f"distribution would return a trivial constant; refusing to "
            f"proceed. Check the backend's lm_head and logits pipeline.",
        )

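
# Illustrative sketch of the guard above. The keyword construction of
# ``TokenDist`` is assumed here (this module only relies on its
# ``token_ids`` / ``logprobs`` attributes); the values are made up.
#
#     flat = TokenDist(
#         token_ids=np.array([3, 7, 42]),
#         logprobs=np.log(np.full(3, 1.0 / 3.0)),   # exactly uniform top-3
#     )
#     _check_non_degenerate_token_dist("base", flat)   # raises ProbeError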

def aligned_probs(
    base: TokenDist, ft: TokenDist
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Return aligned probability vectors over the union of top-k tokens.

    Two ``TokenDist`` objects may surface different top-k indices if
    the two models disagree about the hot tokens. We build a shared
    support — ``union(base.token_ids, ft.token_ids)`` — and slot the
    known probabilities in. Unknown entries fall back to the
    per-distribution tail mass divided across the missing tokens,
    which is the maximum-entropy completion under the truncation.

    Raises :class:`ProbeError` when either input contains non-finite
    logprobs or an effectively-uniform (degenerate) top-k.
    """
    _check_finite_token_dist("base", base)
    _check_finite_token_dist("ft", ft)
    _check_non_degenerate_token_dist("base", base)
    _check_non_degenerate_token_dist("ft", ft)

    union_ids = np.union1d(base.token_ids, ft.token_ids)
    k = int(union_ids.size)

    base_probs = _to_support(base, union_ids, k)
    ft_probs = _to_support(ft, union_ids, k)

    # Normalize in case of floating-point noise from the fill-in.
    base_total = float(base_probs.sum())
    ft_total = float(ft_probs.sum())
    if not (math.isfinite(base_total) and base_total > 0):
        raise ProbeError(
            "divergence", f"base distribution sums to {base_total} after support alignment"
        )
    if not (math.isfinite(ft_total) and ft_total > 0):
        raise ProbeError(
            "divergence", f"ft distribution sums to {ft_total} after support alignment"
        )
    base_probs /= base_total
    ft_probs /= ft_total
    return base_probs, ft_probs

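
# Worked example (a sketch; keyword construction of ``TokenDist`` is assumed
# and the probabilities are made up). Base puts its top-2 mass on tokens
# {5, 9}, the fine-tuned model on {5, 11}, so the shared support is the
# union {5, 9, 11} and each side's leftover tail mass fills the token it
# did not report:
#
#     base = TokenDist(token_ids=np.array([5, 9]),
#                      logprobs=np.log(np.array([0.7, 0.2])))
#     ft = TokenDist(token_ids=np.array([5, 11]),
#                    logprobs=np.log(np.array([0.5, 0.3])))
#     p, q = aligned_probs(base, ft)
#     # p ≈ [0.7, 0.2, 0.1]   (token 11 gets base's 0.1 tail mass)
#     # q ≈ [0.5, 0.2, 0.3]   (token 9 gets ft's 0.2 tail mass)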

def _to_support(dist: TokenDist, support: NDArray[np.int64], k: int) -> NDArray[np.float64]:
    """Project ``dist`` onto the length-``k`` ``support``, filling tokens it
    does not report with an even share of its leftover tail mass."""
    probs = np.exp(dist.logprobs.astype(np.float64))
    out = np.zeros(k, dtype=np.float64)
    known_mass = float(probs.sum())
    tail_mass = max(0.0, 1.0 - known_mass)

    id_to_idx = {int(tok): idx for idx, tok in enumerate(support.tolist())}
    missing = 0
    for tok, p in zip(dist.token_ids.tolist(), probs.tolist(), strict=True):
        i = id_to_idx.get(int(tok))
        if i is None:
            # Shouldn't happen given union construction.
            missing += 1
            continue
        out[i] = float(p)

    # Spread the tail mass evenly over the support entries this dist
    # doesn't explicitly provide; count that set first.
    n_unknown = int((out == 0.0).sum()) - missing
    if n_unknown > 0 and tail_mass > 0.0:
        per = tail_mass / n_unknown
        out[out == 0.0] = per

    return out


def _check_finite_array(name: str, arr: NDArray[np.float64]) -> None:
    """Reject an array containing NaN/inf."""
    if not np.all(np.isfinite(arr)):
        n_bad = int(np.sum(~np.isfinite(arr)))
        raise ProbeError(
            "divergence",
            f"{name} contains {n_bad} non-finite entry/entries — refusing to compute divergence",
        )


def kl(p: NDArray[np.float64], q: NDArray[np.float64]) -> float:
    """KL(p || q) in nats. Robust to zeros in p (treated as 0 · log 0 = 0);
    zeros in q are floored at 1e-12 so the result stays finite.

    Raises :class:`ProbeError` on non-finite inputs.
    """
    _check_finite_array("p", p)
    _check_finite_array("q", q)
    mask = p > 0.0
    safe_q = np.where(q > 0.0, q, 1e-12)
    result = float(np.sum(p[mask] * (np.log(p[mask]) - np.log(safe_q[mask]))))
    if not math.isfinite(result):
        raise ProbeError("divergence", f"kl computation produced non-finite result: {result}")
    return result

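
# Quick numeric check (illustrative values; not part of the probe path):
#
#     p = np.array([0.5, 0.5])
#     q = np.array([0.9, 0.1])
#     kl(p, q)   # 0.5*ln(0.5/0.9) + 0.5*ln(0.5/0.1) ≈ 0.511 nats
#     kl(p, p)   # 0.0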

def js(p: NDArray[np.float64], q: NDArray[np.float64]) -> float:
    """Jensen-Shannon divergence. Symmetric, bounded in [0, ln 2] (nats).

    The upper bound makes JS a nicer default for thresholding than raw
    KL — a user doesn't need to know their specific model's KL scale to
    pick a threshold.

    Raises :class:`ProbeError` on non-finite inputs or a result outside
    the theoretical bound.
    """
    _check_finite_array("p", p)
    _check_finite_array("q", q)
    m = 0.5 * (p + q)
    result = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    # Defense-in-depth around the theoretical bound JS ∈ [0, ln 2]. Tiny
    # excursions (a negative epsilon, or a hair over ln 2) are FP roundoff,
    # especially when p ≈ q: negatives are clamped to zero below and the
    # rest pass through. Broader excursions indicate upstream numerical
    # drift and are surfaced as ProbeError rather than silently returned.
    tol = 1e-9
    if result < -tol or result > math.log(2.0) + tol:
        raise ProbeError(
            "divergence",
            f"js computed {result:.4f} nats, outside theoretical bound "
            f"[0, {math.log(2.0):.4f}] — likely numerical drift in upstream "
            f"distributions",
        )
    return max(0.0, result)

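
# Sanity anchors for the bound (illustrative values):
#
#     u = np.array([1.0, 0.0])
#     v = np.array([0.0, 1.0])
#     js(u, u)   # 0.0: identical distributions
#     js(u, v)   # ln 2 ≈ 0.693: disjoint support hits the upper bound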

def divergence(base: TokenDist, ft: TokenDist, kind: Divergence = "js") -> float:
    """Compute KL or JS between two ``TokenDist`` on a shared support."""
    p, q = aligned_probs(base, ft)
    if kind == "js":
        return js(p, q)
    if kind == "kl":
        return kl(q, p)  # KL(ft || base) — "how much does ft diverge from base"
    raise ValueError(f"unknown divergence kind: {kind!r}")


def js_ln2() -> float:
    """Upper bound on JS in nats. Useful for normalization."""
    return math.log(2.0)
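
# End-to-end usage as a probe would call it (a sketch; ``base_dist`` and
# ``ft_dist`` stand in for TokenDist objects obtained by scoring the same
# prompt with both models, and the 0.1 threshold is illustrative, not tuned):
#
#     score = divergence(base_dist, ft_dist, kind="js")
#     normalized = score / js_ln2()          # rescale into [0, 1]
#     if normalized > 0.1:
#         ...  # the fine-tune meaningfully shifted the next-token distribution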