1 """Tests for :mod:`dlm_sway.core.scoring`."""
2
3 from __future__ import annotations
4
5 import math
6
7 import numpy as np
8
9 from dlm_sway.core.scoring import (
10 DifferentialBackend,
11 RollingLogprob,
12 ScoringBackend,
13 TokenDist,
14 )
15
16
class TestRollingLogprob:
    def test_empty_sequence(self) -> None:
        """A single token has no transitions: mean logprob 0, perplexity 1."""
        rolled = RollingLogprob(
            token_ids=np.array([42], dtype=np.int64),
            logprobs=np.array([], dtype=np.float32),
            num_tokens=1,
            total_logprob=0.0,
        )
        assert rolled.mean_logprob == 0.0
        assert rolled.perplexity == 1.0

    def test_mean_and_perplexity(self) -> None:
        """Two transition logprobs summing to -4.0 over three tokens → mean -2.0."""
        rolled = RollingLogprob(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.array([-1.5, -2.5], dtype=np.float32),
            num_tokens=3,
            total_logprob=-4.0,
        )
        assert math.isclose(rolled.mean_logprob, -2.0, rel_tol=1e-6)
        # perplexity should be exp(-mean_logprob) = exp(2.0).
        assert math.isclose(rolled.perplexity, math.exp(2.0), rel_tol=1e-6)
38
39
class TestTokenDist:
    def test_construction_and_defaults(self) -> None:
        """B6: the default tail_logprob is None ("no tail recorded") — 0.0
        now carries the distinct meaning "tail underflowed to zero, but exists"."""
        dist = TokenDist(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.array([-0.1, -1.0, -3.0], dtype=np.float32),
            vocab_size=50_257,
        )
        assert dist.tail_logprob is None
        assert dist.token_ids.shape == (3,)

    def test_explicit_tail_distinguishes_zero_from_none(self) -> None:
        """B6: 0.0 means measurable-but-tiny; None means no tail at all."""
        common = dict(
            token_ids=np.array([1], dtype=np.int64),
            logprobs=np.array([0.0], dtype=np.float32),
            vocab_size=1,
        )
        without_tail = TokenDist(tail_logprob=None, **common)
        underflowed = TokenDist(tail_logprob=0.0, **common)
        assert without_tail.tail_logprob is None
        assert underflowed.tail_logprob == 0.0
68
69
class TestProtocols:
    def test_scoring_backend_runtime_checkable(self) -> None:
        """A structurally conforming class satisfies the runtime isinstance check."""

        class StubScoring:
            def logprob_of(self, prompt: str, completion: str) -> float:
                return 0.0

            def rolling_logprob(self, text: str) -> RollingLogprob:
                return RollingLogprob(
                    token_ids=np.array([0], dtype=np.int64),
                    logprobs=np.array([], dtype=np.float32),
                    num_tokens=1,
                    total_logprob=0.0,
                )

            def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
                return TokenDist(
                    token_ids=np.array([0], dtype=np.int64),
                    logprobs=np.array([0.0], dtype=np.float32),
                    vocab_size=1,
                )

            def next_token_dist_batch(
                self,
                prompts,  # type: ignore[no-untyped-def]
                *,
                top_k: int = 256,
            ) -> list[TokenDist]:
                # S23 — the runtime_checkable Protocol also requires the
                # batched entry point. Routing each prompt through the
                # single-prompt path is enough to pass isinstance.
                return [self.next_token_dist(p, top_k=top_k) for p in prompts]

        assert isinstance(StubScoring(), ScoringBackend)

    def test_differential_backend_runtime_checkable(self) -> None:
        """Any object exposing as_base/as_finetuned context factories qualifies."""
        from contextlib import nullcontext

        class StubDiff:
            def as_base(self):  # type: ignore[no-untyped-def]
                return nullcontext(object())

            def as_finetuned(self):  # type: ignore[no-untyped-def]
                return nullcontext(object())

        assert isinstance(StubDiff(), DifferentialBackend)