tenseleyflow/sway / cea35e6

Browse files

tests/divergence: degenerate uniform TokenDist rejection (stronger-test #9)

Authored by espadonne
SHA
cea35e6b8f514c20037590622d5be297c51e833a
Parents
a1d1f32
Tree
5136628

1 changed file

Status | File | + | -
M tests/unit/test_divergence.py 56 0
tests/unit/test_divergence.py — modified
@@ -194,6 +194,62 @@ class TestNonFiniteRejection:
194194
         assert 0.0 <= result <= math.log(2.0) + 1e-9
195195
 
196196
 
197
class TestDegenerateUniformRejection:
    """Stronger-test #9 — a TokenDist whose top-k logprobs are all
    bit-identical must be rejected. Real models never produce exactly
    flat logits; seeing one indicates a broken lm_head or a fixture
    that zeroed the logits. Computing ``divergence`` on such a dist
    would yield a trivial constant for every prompt and pollute
    ``delta_kl`` / ``cluster_kl`` downstream.
    """

    def test_perfectly_uniform_dist_is_rejected(self) -> None:
        # Build a dist where every logprob is exactly -log(k): the
        # degenerate case the guard exists to catch.
        top_k = 8
        flat_dist = TokenDist(
            token_ids=np.arange(top_k, dtype=np.int64),
            logprobs=np.full(top_k, -math.log(top_k), dtype=np.float32),
            vocab_size=1000,
        )
        reference = _dist([1, 2], [0.9, 0.1])
        with pytest.raises(ProbeError, match="effectively-uniform"):
            aligned_probs(reference, flat_dist)

    def test_near_uniform_real_model_shape_is_accepted(self) -> None:
        """A high-entropy but not bit-flat dist — the shape an actual
        model produces — must still go through and yield a divergence."""
        top_k = 8
        # Start from exact uniformity, then nudge each entry by a tiny
        # monotonic offset: enough spread to clear the 1e-9 uniformity
        # threshold while leaving the entropy essentially unchanged.
        offsets = np.linspace(-1e-5, 1e-5, top_k, dtype=np.float32)
        logprob_vec = np.full(top_k, -math.log(top_k), dtype=np.float32)
        logprob_vec += offsets
        high_entropy = TokenDist(
            token_ids=np.arange(top_k, dtype=np.int64),
            logprobs=logprob_vec,
            vocab_size=1000,
        )
        # A sharply-peaked counterpart: one likely token, the rest deep
        # in the tail.
        peaked_lp = np.full(top_k, -5.0, dtype=np.float32)
        peaked_lp[0] = -0.1
        peaked = TokenDist(
            token_ids=np.arange(top_k, dtype=np.int64),
            logprobs=peaked_lp,
            vocab_size=1000,
        )
        # Must not raise — and the JS divergence is finite and positive.
        p, q = aligned_probs(peaked, high_entropy)
        score = js(p, q)
        assert math.isfinite(score)
        assert score > 0.0

    def test_single_token_dist_not_rejected(self) -> None:
        """With a single token there is no spread to measure, so the
        dist cannot be "uniform" — the guard must short-circuit."""
        singleton = TokenDist(
            token_ids=np.array([0], dtype=np.int64),
            logprobs=np.array([0.0], dtype=np.float32),
            vocab_size=1000,
        )
        # Must not raise: ``aligned_probs`` handles single-token dists
        # and the degenerate check bails out at ``size < 2``.
        aligned_probs(singleton, singleton)
+
197253
 # ---- Hypothesis property tests ------------------------------------
198254
 
199255