@@ -194,6 +194,62 @@ class TestNonFiniteRejection: |
| 194 | 194 | assert 0.0 <= result <= math.log(2.0) + 1e-9 |
| 195 | 195 | |
| 196 | 196 | |
| 197 | +class TestDegenerateUniformRejection: |
| 198 | + """Stronger-test #9 — reject a TokenDist whose top-k logprobs are |
| 199 | + identical. A real model never emits bit-uniform logits; getting |
| 200 | + one means lm_head broke or a fixture zeroed out logits. Silently |
| 201 | + computing ``divergence`` on such a dist returns a trivial constant |
| 202 | + across prompts that would contaminate ``delta_kl`` / ``cluster_kl``. |
| 203 | + """ |
| 204 | + |
| 205 | + def test_perfectly_uniform_dist_is_rejected(self) -> None: |
| 206 | + k = 8 |
| 207 | + uniform = TokenDist( |
| 208 | + token_ids=np.arange(k, dtype=np.int64), |
| 209 | + logprobs=np.full(k, -math.log(k), dtype=np.float32), |
| 210 | + vocab_size=1000, |
| 211 | + ) |
| 212 | + good = _dist([1, 2], [0.9, 0.1]) |
| 213 | + with pytest.raises(ProbeError, match="effectively-uniform"): |
| 214 | + aligned_probs(good, uniform) |
| 215 | + |
| 216 | + def test_near_uniform_real_model_shape_is_accepted(self) -> None: |
| 217 | + """A broad-but-not-literally-flat dist (the shape a real model |
| 218 | + with high entropy produces) must still compute a divergence.""" |
| 219 | + k = 8 |
| 220 | + lp = np.full(k, -math.log(k), dtype=np.float32) |
| 221 | + # Tiny monotonic perturbation — enough to clear the 1e-9 |
| 222 | + # uniformity threshold without meaningfully changing the |
| 223 | + # entropy. |
| 224 | + lp += np.linspace(-1e-5, 1e-5, k, dtype=np.float32) |
| 225 | + broad = TokenDist( |
| 226 | + token_ids=np.arange(k, dtype=np.int64), |
| 227 | + logprobs=lp, |
| 228 | + vocab_size=1000, |
| 229 | + ) |
| 230 | + sharp = TokenDist( |
| 231 | + token_ids=np.arange(k, dtype=np.int64), |
| 232 | + logprobs=np.array([-0.1] + [-5.0] * (k - 1), dtype=np.float32), |
| 233 | + vocab_size=1000, |
| 234 | + ) |
| 235 | +        # No exception — and the JS divergence is finite and positive. |
| 236 | + result = js(*aligned_probs(sharp, broad)) |
| 237 | + assert math.isfinite(result) |
| 238 | + assert result > 0.0 |
| 239 | + |
| 240 | + def test_single_token_dist_not_rejected(self) -> None: |
| 241 | + """A distribution with only one token can't be "uniform" — |
| 242 | + there's no spread to compute. The guard must short-circuit.""" |
| 243 | + one = TokenDist( |
| 244 | + token_ids=np.array([0], dtype=np.int64), |
| 245 | + logprobs=np.array([0.0], dtype=np.float32), |
| 246 | + vocab_size=1000, |
| 247 | + ) |
| 248 | + # Must not raise (``aligned_probs`` handles single-token dists |
| 249 | + # fine; the degenerate check short-circuits at ``size < 2``). |
| 250 | + aligned_probs(one, one) |
| 251 | + |
| 252 | + |
| 197 | 253 | # ---- Hypothesis property tests ------------------------------------ |
| 198 | 254 | |
| 199 | 255 | |