1 """Tests for :mod:`dlm_sway.probes._divergence`.
2
3 Includes property-based tests (`hypothesis`) on the divergence
4 invariants and explicit tests that non-finite inputs raise
5 ``ProbeError`` rather than producing silent garbage. The latter pin the
6 S01 fix for the +11639σ bug.
7 """
8
9 from __future__ import annotations
10
11 import math
12
13 import numpy as np
14 import pytest
15 from hypothesis import HealthCheck, given, settings
16 from hypothesis import strategies as st
17 from hypothesis.extra.numpy import arrays
18
19 from dlm_sway.core.errors import ProbeError
20 from dlm_sway.core.scoring import TokenDist
21 from dlm_sway.probes._divergence import aligned_probs, divergence, js, kl
22
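# For orientation: a minimal sketch of the invariants these tests pin.
# This is NOT the real ``dlm_sway.probes._divergence`` implementation,
# only the assumed shape of the S01 fix (a finiteness guard that raises
# ``ProbeError`` before any masking), with KL restricted to p's support
# and JS as the mean KL to the 50/50 mixture.
def _kl_sketch(p: np.ndarray, q: np.ndarray) -> float:
    if not (np.isfinite(p).all() and np.isfinite(q).all()):
        raise ProbeError("non-finite input")  # NaN/inf must never slip past
    mask = p > 0  # 0 * log(0) is taken as 0, so zero-mass terms drop out
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def _js_sketch(p: np.ndarray, q: np.ndarray) -> float:
    m = 0.5 * (p + q)  # the mixture has support wherever p or q does
    return 0.5 * _kl_sketch(p, m) + 0.5 * _kl_sketch(q, m)
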

def _dist(ids: list[int], probs: list[float], vocab: int = 100) -> TokenDist:
    return TokenDist(
        token_ids=np.asarray(ids, dtype=np.int64),
        logprobs=np.log(np.asarray(probs, dtype=np.float32)),
        vocab_size=vocab,
    )


def _normalized_simplex(size: int) -> st.SearchStrategy[np.ndarray]:
    """Hypothesis strategy: arrays of `size` non-negative floats summing to 1."""

    def _norm(a: np.ndarray) -> np.ndarray:
        a = np.abs(a) + 1e-6  # avoid all-zeros
        return a / a.sum()

    return arrays(
        dtype=np.float64,
        shape=size,
        elements=st.floats(min_value=0.001, max_value=10.0, allow_nan=False, allow_infinity=False),
    ).map(_norm)


class TestAligned:
    def test_identical_distributions(self) -> None:
        d = _dist([1, 2, 3], [0.5, 0.3, 0.2])
        p, q = aligned_probs(d, d)
        np.testing.assert_allclose(p, q)

    def test_union_support_fills_missing(self) -> None:
        base = _dist([1, 2, 3], [0.5, 0.3, 0.2])
        ft = _dist([2, 3, 4], [0.4, 0.4, 0.2])
        p, q = aligned_probs(base, ft)
        assert p.shape == (4,)
        assert abs(p.sum() - 1.0) < 1e-9
        assert abs(q.sum() - 1.0) < 1e-9

    def test_disjoint_top_k_supports_produce_finite_divergence(self) -> None:
        """C12: base peaks on {1, 2, 3}, ft peaks on {7, 8, 9}. The aligned
        probabilities have 6 entries (the union), with tail mass
        redistributed to each side's missing tokens via the
        ``tail_logprob`` fallback. The divergence must be finite,
        positive, and close to ln(2), the JS upper bound attained by
        distributions with disjoint support."""
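        # Assumed alignment behavior (hedged; inferred from this test's
        # docstring, not from the probe's source): a dist's missing
        # tokens receive a share of that dist's own ``tail_logprob``
        # mass, so base assigns roughly 1e-9 to each of tokens 7, 8, 9.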
        base = TokenDist(
            token_ids=np.asarray([1, 2, 3], dtype=np.int64),
            logprobs=np.log(np.asarray([0.6, 0.25, 0.15], dtype=np.float32)),
            vocab_size=1000,
            # Top-3 covers all of base; tail is effectively zero.
            tail_logprob=float(math.log(1e-9)),
        )
        ft = TokenDist(
            token_ids=np.asarray([7, 8, 9], dtype=np.int64),
            logprobs=np.log(np.asarray([0.5, 0.3, 0.2], dtype=np.float32)),
            vocab_size=1000,
            tail_logprob=float(math.log(1e-9)),
        )
        p, q = aligned_probs(base, ft)
        # Union of {1, 2, 3} and {7, 8, 9} is 6 tokens.
        assert p.shape == (6,)
        assert q.shape == (6,)
        # Both distributions should be (approximately) normalized — any
        # tail redistribution leaves a small residual below 1e-3.
        assert abs(p.sum() - 1.0) < 1e-3
        assert abs(q.sum() - 1.0) < 1e-3

        d = divergence(base, ft, kind="js")
        assert math.isfinite(d)
        assert d > 0.0
        # Fully disjoint support → JS approaches its ln(2) upper bound.
        assert d < math.log(2.0) + 1e-6
        assert d > 0.5  # meaningfully above zero


class TestKL:
    def test_zero_when_equal(self) -> None:
        p = np.array([0.5, 0.3, 0.2])
        assert kl(p, p) == 0.0

    def test_positive_when_different(self) -> None:
        p = np.array([0.7, 0.2, 0.1])
        q = np.array([0.2, 0.3, 0.5])
        assert kl(p, q) > 0.0


class TestJS:
    def test_zero_when_equal(self) -> None:
        p = np.array([0.5, 0.3, 0.2])
        assert js(p, p) == 0.0

    def test_symmetric(self) -> None:
        p = np.array([0.7, 0.2, 0.1])
        q = np.array([0.2, 0.3, 0.5])
        assert math.isclose(js(p, q), js(q, p), rel_tol=1e-9)

    def test_bounded_by_ln2(self) -> None:
        p = np.array([1.0, 0.0])
        q = np.array([0.0, 1.0])
        # With zeros handled as 0·log0 = 0 this approaches ln(2).
        assert js(p, q) <= math.log(2.0) + 1e-9
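        # Worked bound for this exact pair: m = (p + q) / 2 = [0.5, 0.5],
        # so JS = 0.5 * KL([1, 0] || m) + 0.5 * KL([0, 1] || m)
        #       = 0.5 * ln(2) + 0.5 * ln(2) = ln(2).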


class TestDivergenceDispatch:
    def test_default_is_js(self) -> None:
        d1 = _dist([1, 2], [0.6, 0.4])
        d2 = _dist([1, 2], [0.3, 0.7])
        assert divergence(d1, d2) == divergence(d1, d2, kind="js")

    def test_kl_available(self) -> None:
        d1 = _dist([1, 2], [0.6, 0.4])
        d2 = _dist([1, 2], [0.3, 0.7])
        assert divergence(d1, d2, kind="kl") >= 0.0


class TestNonFiniteRejection:
    """Pins the S01 fix: NaN / inf inputs raise ProbeError, never silent garbage.

    The historical bug: ``np.exp(nan) = nan`` flowed past the ``p > 0``
    mask in ``kl()`` (because ``nan > 0`` is False), producing a
    ``js`` of 13.247 nats — algebraically impossible for JS (≤ ln 2).
    """

    def test_kl_rejects_nan_in_p(self) -> None:
        p = np.array([0.5, math.nan, 0.5])
        q = np.array([0.3, 0.4, 0.3])
        with pytest.raises(ProbeError, match="non-finite"):
            kl(p, q)

    def test_kl_rejects_inf_in_q(self) -> None:
        p = np.array([0.5, 0.5])
        q = np.array([0.5, math.inf])
        with pytest.raises(ProbeError, match="non-finite"):
            kl(p, q)

    def test_js_rejects_nan(self) -> None:
        p = np.array([math.nan, 1.0])
        q = np.array([0.5, 0.5])
        with pytest.raises(ProbeError, match="non-finite"):
            js(p, q)

    def test_aligned_probs_rejects_nan_logprobs(self) -> None:
        bad = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([-0.5, math.nan], dtype=np.float32),
            vocab_size=100,
        )
        good = _dist([1, 2], [0.5, 0.5])
        with pytest.raises(ProbeError, match="ft TokenDist contains 1 non-finite"):
            aligned_probs(good, bad)
        with pytest.raises(ProbeError, match="base TokenDist contains 1 non-finite"):
            aligned_probs(bad, good)

    def test_divergence_rejects_nan_token_dist(self) -> None:
        bad = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([math.nan, math.nan], dtype=np.float32),
            vocab_size=100,
        )
        good = _dist([1, 2], [0.5, 0.5])
        with pytest.raises(ProbeError):
            divergence(good, bad)

    def test_js_caps_at_ln2_bound(self) -> None:
        """Defense-in-depth: near-extreme but well-formed p, q where a
        naive computation could drift past ln(2) due to FP noise."""
        p = np.array([1.0 - 1e-15, 1e-15])
        q = np.array([1e-15, 1.0 - 1e-15])
        # The true JS here is ≤ ln(2). The function must not raise on
        # well-formed near-extreme distributions, and the result must
        # stay within the bound (modulo 1e-9 of FP slack).
        result = js(p, q)
        assert 0.0 <= result <= math.log(2.0) + 1e-9


class TestDegenerateUniformRejection:
    """Stronger-test #9 — reject a TokenDist whose top-k logprobs are
    bit-for-bit identical. A real model never emits perfectly uniform
    logits; getting one means lm_head broke or a fixture zeroed out the
    logits. Silently computing ``divergence`` on such a dist would
    return a trivial constant across prompts, contaminating
    ``delta_kl`` / ``cluster_kl``.
    """

    def test_perfectly_uniform_dist_is_rejected(self) -> None:
        k = 8
        uniform = TokenDist(
            token_ids=np.arange(k, dtype=np.int64),
            logprobs=np.full(k, -math.log(k), dtype=np.float32),
            vocab_size=1000,
        )
        good = _dist([1, 2], [0.9, 0.1])
        with pytest.raises(ProbeError, match="effectively-uniform"):
            aligned_probs(good, uniform)

    def test_near_uniform_real_model_shape_is_accepted(self) -> None:
        """A broad-but-not-literally-flat dist (the shape a real model
        with high entropy produces) must still compute a divergence."""
        k = 8
        lp = np.full(k, -math.log(k), dtype=np.float32)
        # Tiny monotonic perturbation — enough to clear the 1e-9
        # uniformity threshold without meaningfully changing the
        # entropy.
        lp += np.linspace(-1e-5, 1e-5, k, dtype=np.float32)
        broad = TokenDist(
            token_ids=np.arange(k, dtype=np.int64),
            logprobs=lp,
            vocab_size=1000,
        )
        sharp = TokenDist(
            token_ids=np.arange(k, dtype=np.int64),
            logprobs=np.array([-0.1] + [-5.0] * (k - 1), dtype=np.float32),
            vocab_size=1000,
        )
        # No exception — and KL/JS are finite and positive.
        result = js(*aligned_probs(sharp, broad))
        assert math.isfinite(result)
        assert result > 0.0

    def test_single_token_dist_not_rejected(self) -> None:
        """A distribution with only one token can't be "uniform" —
        there's no spread to compute. The guard must short-circuit."""
        one = TokenDist(
            token_ids=np.array([0], dtype=np.int64),
            logprobs=np.array([0.0], dtype=np.float32),
            vocab_size=1000,
        )
        # Must not raise (``aligned_probs`` handles single-token dists
        # fine; the degenerate check short-circuits at ``size < 2``).
        aligned_probs(one, one)


# ---- Hypothesis property tests ------------------------------------


@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow])
@given(p=_normalized_simplex(5))
def test_kl_self_is_zero(p: np.ndarray) -> None:
    """KL(p || p) == 0 for any well-formed p."""
    assert kl(p, p) == pytest.approx(0.0, abs=1e-9)


@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow])
@given(p=_normalized_simplex(5))
def test_js_self_is_zero(p: np.ndarray) -> None:
    """JS(p, p) == 0 for any well-formed p."""
    assert js(p, p) == pytest.approx(0.0, abs=1e-9)


@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow])
@given(p=_normalized_simplex(5), q=_normalized_simplex(5))
def test_js_symmetric(p: np.ndarray, q: np.ndarray) -> None:
    """JS(p, q) == JS(q, p)."""
    assert js(p, q) == pytest.approx(js(q, p), abs=1e-9)


@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow])
@given(p=_normalized_simplex(5), q=_normalized_simplex(5))
def test_js_bounded_by_ln2(p: np.ndarray, q: np.ndarray) -> None:
    """JS(p, q) ∈ [0, ln 2] for any pair of distributions."""
    v = js(p, q)
    assert 0.0 <= v <= math.log(2.0) + 1e-9


@settings(max_examples=50, suppress_health_check=[HealthCheck.too_slow])
@given(p=_normalized_simplex(5), q=_normalized_simplex(5))
def test_kl_non_negative(p: np.ndarray, q: np.ndarray) -> None:
    """KL(p || q) ≥ 0 for any pair of distributions (Gibbs' inequality)."""
    assert kl(p, q) >= -1e-9
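    # Why the -1e-9 slack: KL(p || q) >= 0 holds exactly (Jensen's
    # inequality applied to -log), but float64 log/sum rounding can dip
    # a hair below zero when p ≈ q; the slack absorbs that.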