| 1 | """Tests for :mod:`dlm_sway.core.scoring`.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import math |
| 6 | |
| 7 | import numpy as np |
| 8 | |
| 9 | from dlm_sway.core.scoring import ( |
| 10 | DifferentialBackend, |
| 11 | RollingLogprob, |
| 12 | ScoringBackend, |
| 13 | TokenDist, |
| 14 | ) |
| 15 | |
| 16 | |
class TestRollingLogprob:
    def test_empty_sequence(self) -> None:
        """With no transition logprobs, stats collapse to neutral values."""
        rolled = RollingLogprob(
            token_ids=np.array([42], dtype=np.int64),
            logprobs=np.array([], dtype=np.float32),
            num_tokens=1,
            total_logprob=0.0,
        )
        assert rolled.mean_logprob == 0.0
        assert rolled.perplexity == 1.0

    def test_mean_and_perplexity(self) -> None:
        """Two transitions totalling -4.0 give mean -2.0 and ppl e^2."""
        rolled = RollingLogprob(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.array([-1.5, -2.5], dtype=np.float32),
            num_tokens=3,
            total_logprob=-4.0,
        )
        expected_mean = -2.0
        assert math.isclose(rolled.mean_logprob, expected_mean, rel_tol=1e-6)
        # Perplexity is exp of the negated mean logprob.
        assert math.isclose(
            rolled.perplexity, math.exp(-expected_mean), rel_tol=1e-6
        )
| 38 | |
| 39 | |
class TestTokenDist:
    def test_construction_and_defaults(self) -> None:
        """A dist built without a tail defaults to tail_logprob=None."""
        ids = np.array([1, 2, 3], dtype=np.int64)
        lps = np.array([-0.1, -1.0, -3.0], dtype=np.float32)
        dist = TokenDist(token_ids=ids, logprobs=lps, vocab_size=50_257)
        # B6: None means "no tail recorded" — distinct from 0.0, which
        # means the tail exists but underflowed to zero.
        assert dist.tail_logprob is None
        assert dist.token_ids.shape == (3,)

    def test_explicit_tail_distinguishes_zero_from_none(self) -> None:
        """B6: 0.0 means measurable-but-tiny; None means no tail at all."""

        def make(tail):  # type: ignore[no-untyped-def]
            # Build a minimal single-token dist with the given tail value.
            return TokenDist(
                token_ids=np.array([1], dtype=np.int64),
                logprobs=np.array([0.0], dtype=np.float32),
                vocab_size=1,
                tail_logprob=tail,
            )

        assert make(None).tail_logprob is None
        assert make(0.0).tail_logprob == 0.0
| 68 | |
| 69 | |
class TestProtocols:
    def test_scoring_backend_runtime_checkable(self) -> None:
        """A structural stand-in passes the runtime_checkable isinstance."""

        class FakeScoring:
            def logprob_of(self, prompt: str, completion: str) -> float:
                return 0.0

            def rolling_logprob(self, text: str) -> RollingLogprob:
                # Minimal one-token result; values are irrelevant to the
                # structural check.
                return RollingLogprob(
                    token_ids=np.array([0], dtype=np.int64),
                    logprobs=np.array([], dtype=np.float32),
                    num_tokens=1,
                    total_logprob=0.0,
                )

            def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
                return TokenDist(
                    token_ids=np.array([0], dtype=np.int64),
                    logprobs=np.array([0.0], dtype=np.float32),
                    vocab_size=1,
                )

            def next_token_dist_batch(
                self,
                prompts,  # type: ignore[no-untyped-def]
                *,
                top_k: int = 256,
            ) -> list[TokenDist]:
                # S23 — the runtime_checkable Protocol also requires the
                # batched entry point; delegating to the per-prompt path
                # is enough to satisfy the isinstance check.
                return [
                    self.next_token_dist(prompt, top_k=top_k)
                    for prompt in prompts
                ]

        assert isinstance(FakeScoring(), ScoringBackend)

    def test_differential_backend_runtime_checkable(self) -> None:
        """A two-method fake satisfies DifferentialBackend structurally."""
        import contextlib

        class FakeDiff:
            def as_base(self):  # type: ignore[no-untyped-def]
                return contextlib.nullcontext(object())

            def as_finetuned(self):  # type: ignore[no-untyped-def]
                return contextlib.nullcontext(object())

        assert isinstance(FakeDiff(), DifferentialBackend)