| 1 | """Tests for ``preflight_finite_check`` on the shipped backends. |
| 2 | |
| 3 | The HF backend's check is exercised in the integration suite (it needs |
| 4 | a real model). Here we verify the dummy backend's contract and the |
| 5 | PreflightCheckable Protocol shape. |
| 6 | """ |
| 7 | |
| 8 | from __future__ import annotations |
| 9 | |
| 10 | import math |
| 11 | |
| 12 | import numpy as np |
| 13 | |
| 14 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 15 | from dlm_sway.core.scoring import PreflightCheckable, TokenDist |
| 16 | |
| 17 | |
class TestProtocolShape:
    """Structural check: the dummy backend exposes the preflight Protocol."""

    def test_dummy_satisfies_preflight_protocol(self) -> None:
        # Two independent, default-configured response sets — one per view.
        base_view = DummyResponses()
        ft_view = DummyResponses()
        pair = DummyDifferentialBackend(base=base_view, ft=ft_view)
        # ``PreflightCheckable`` is used with ``isinstance`` here, so it is
        # presumably declared ``runtime_checkable`` — confirm in core.scoring.
        assert isinstance(pair, PreflightCheckable)
| 22 | |
| 23 | |
class TestDummyDefaultIsFinite:
    """An unconfigured dummy backend must pass the preflight check cleanly."""

    def test_default_dummy_passes_preflight(self) -> None:
        base_view = DummyResponses()
        ft_view = DummyResponses()
        pair = DummyDifferentialBackend(base=base_view, ft=ft_view)

        passed, message = pair.preflight_finite_check()

        # Identity check (not truthiness) pins the flag to the bool ``True``.
        assert passed is True
        # A passing check carries no diagnostic text.
        assert message == ""
| 30 | |
| 31 | |
class TestDummyNanDetection:
    """``preflight_finite_check`` must flag non-finite logprobs in either view.

    Each test plants a ``TokenDist`` under the ``"preflight"`` key of one
    view's ``DummyResponses``. It then asserts that the check fails and that
    the reason both names the offending view and says "non-finite".
    """

    @staticmethod
    def _dist(logprobs: list[float]) -> TokenDist:
        """Build a ``TokenDist`` over ids ``1..len(logprobs)`` with *logprobs*."""
        return TokenDist(
            token_ids=np.arange(1, len(logprobs) + 1, dtype=np.int64),
            logprobs=np.array(logprobs, dtype=np.float32),
            vocab_size=100,
        )

    def test_nan_in_ft_token_dist_caught(self) -> None:
        ft = DummyResponses(
            token_dists={"preflight": self._dist([-0.1, math.nan, -2.0])}
        )
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "ft view" in reason
        assert "non-finite" in reason

    def test_nan_in_base_token_dist_caught(self) -> None:
        base = DummyResponses(
            token_dists={"preflight": self._dist([math.nan, math.nan])}
        )
        backend = DummyDifferentialBackend(base=base, ft=DummyResponses())
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "base view" in reason
        # Pin the same diagnostic wording as the ft-view test.
        assert "non-finite" in reason

    def test_inf_caught(self) -> None:
        # A logprob can never be +inf (log p <= 0). It must be flagged just
        # like NaN and attributed to the ft view, where it was planted.
        ft = DummyResponses(
            token_dists={"preflight": self._dist([-0.5, math.inf])}
        )
        backend = DummyDifferentialBackend(base=DummyResponses(), ft=ft)
        ok, reason = backend.preflight_finite_check()
        assert ok is False
        assert "ft view" in reason
        assert "non-finite" in reason