tenseleyflow/sway / b85197c

Browse files

tests: cover B6 None vs 0.0 distinction; widen integration assertion

Authored by espadonne
SHA
b85197c3ad9384911b1bcceba437a62b106b049f
Parents
c906f5a
Tree
b469612

2 changed files

StatusFile+-
M tests/integration/test_hf_scoring.py 3 1
M tests/unit/test_scoring.py 20 1
tests/integration/test_hf_scoring.pymodified
@@ -166,7 +166,9 @@ class TestNextTokenDist:
166166
         # Top-k must arrive in descending probability order.
167167
         assert np.all(np.diff(d.logprobs) <= 1e-7)
168168
         assert d.vocab_size > 64
169
-        assert math.isfinite(d.tail_logprob) or d.tail_logprob == 0.0
169
+        # B6: tail_logprob is None (k covers vocab — won't happen here),
170
+        # 0.0 (underflow), or a finite negative log-prob.
171
+        assert d.tail_logprob is None or math.isfinite(d.tail_logprob)
170172
 
171173
     def test_dist_changes_under_adapter(self, hf_backend: HuggingFaceDifferentialBackend) -> None:
172174
         prompt = "the adapter influences"
tests/unit/test_scoring.pymodified
@@ -44,9 +44,28 @@ class TestTokenDist:
4444
             logprobs=np.array([-0.1, -1.0, -3.0], dtype=np.float32),
4545
             vocab_size=50_257,
4646
         )
47
-        assert dist.tail_logprob == 0.0
47
+        # B6: default tail_logprob is None ("no tail recorded"), not
48
+        # 0.0 (which now means "tail underflowed to zero, but exists").
49
+        assert dist.tail_logprob is None
4850
         assert dist.token_ids.shape == (3,)
4951
 
52
+    def test_explicit_tail_distinguishes_zero_from_none(self) -> None:
53
+        """B6: 0.0 means measurable-but-tiny; None means no tail at all."""
54
+        d_no_tail = TokenDist(
55
+            token_ids=np.array([1], dtype=np.int64),
56
+            logprobs=np.array([0.0], dtype=np.float32),
57
+            vocab_size=1,
58
+            tail_logprob=None,
59
+        )
60
+        d_underflow = TokenDist(
61
+            token_ids=np.array([1], dtype=np.int64),
62
+            logprobs=np.array([0.0], dtype=np.float32),
63
+            vocab_size=1,
64
+            tail_logprob=0.0,
65
+        )
66
+        assert d_no_tail.tail_logprob is None
67
+        assert d_underflow.tail_logprob == 0.0
68
+
5069
 
5170
 class TestProtocols:
5271
     def test_scoring_backend_runtime_checkable(self) -> None: