Python · 2518 bytes Raw Blame History
1 """Tests for ``preflight_finite_check`` on the shipped backends.
2
3 The HF backend's check is exercised in the integration suite (it needs
4 a real model). Here we verify the dummy backend's contract and the
5 PreflightCheckable Protocol shape.
6 """
7
8 from __future__ import annotations
9
10 import math
11
12 import numpy as np
13
14 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
15 from dlm_sway.core.scoring import PreflightCheckable, TokenDist
16
17
class TestProtocolShape:
    """Structural conformance of the dummy backend to ``PreflightCheckable``."""

    def test_dummy_satisfies_preflight_protocol(self) -> None:
        base_view = DummyResponses()
        ft_view = DummyResponses()
        candidate = DummyDifferentialBackend(base=base_view, ft=ft_view)
        # isinstance() against a Protocol only works when the Protocol is
        # @runtime_checkable, and such checks verify that the required members
        # exist -- not that their signatures match.
        assert isinstance(candidate, PreflightCheckable)
22
23
class TestDummyDefaultIsFinite:
    """A dummy backend built from default responses should pass preflight."""

    def test_default_dummy_passes_preflight(self) -> None:
        outcome = DummyDifferentialBackend(
            base=DummyResponses(),
            ft=DummyResponses(),
        ).preflight_finite_check()
        passed, message = outcome
        assert passed is True
        # On success the reason string is expected to be empty.
        assert message == ""
30
31
class TestDummyNanDetection:
    """Non-finite logprobs in either view must fail the preflight check.

    Each test plants a ``TokenDist`` under the ``"preflight"`` key of one
    view's ``DummyResponses`` and checks that the failure is reported,
    attributed to the offending view, and described as non-finite.
    """

    @staticmethod
    def _dist(logprobs: list[float]) -> TokenDist:
        """Build a small ``TokenDist`` with token ids ``1..len(logprobs)``."""
        return TokenDist(
            token_ids=np.arange(1, len(logprobs) + 1, dtype=np.int64),
            logprobs=np.array(logprobs, dtype=np.float32),
            vocab_size=100,
        )

    def test_nan_in_ft_token_dist_caught(self) -> None:
        nan_dist = self._dist([-0.1, math.nan, -2.0])
        ft = DummyResponses(token_dists={"preflight": nan_dist})
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "ft view" in reason
        assert "non-finite" in reason

    def test_nan_in_base_token_dist_caught(self) -> None:
        nan_dist = self._dist([math.nan, math.nan])
        base = DummyResponses(token_dists={"preflight": nan_dist})
        backend = DummyDifferentialBackend(base=base, ft=DummyResponses())
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "base view" in reason
        # Same wording as the ft-side failure; only the view label differs.
        assert "non-finite" in reason

    def test_inf_caught(self) -> None:
        inf_dist = self._dist([-0.5, math.inf])
        ft = DummyResponses(token_dists={"preflight": inf_dist})
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        # Infinities must be attributed to the right view just like NaNs;
        # previously the reason was discarded, so a wrong label went unnoticed.
        assert "ft view" in reason
        assert "non-finite" in reason
67 assert ok is False