Python · 4013 bytes Raw Blame History
"""Tests for :class:`dlm_sway.backends.dummy.DummyDifferentialBackend`.

Every downstream probe unit test relies on the dummy backend, so it gets a
thorough test suite of its own here. The suite also verifies the
view-exclusion invariant, which catches stale-view bugs in probes.
"""
7
8 from __future__ import annotations
9
10 import numpy as np
11 import pytest
12
13 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
14 from dlm_sway.core.model import Model
15 from dlm_sway.core.scoring import DifferentialBackend, ScoringBackend
16
17
@pytest.fixture
def backend() -> DummyDifferentialBackend:
    """Build a differential backend whose base and fine-tuned views disagree.

    Both views are seeded with a generation for the prompt ``"hi"`` and a
    log-probability for the ``("q", "a")`` pair, using distinct values per
    view so tests can tell which view answered.
    """
    base_responses = DummyResponses(
        generations={"hi": "hello"},
        logprobs={("q", "a"): -3.0},
    )
    finetuned_responses = DummyResponses(
        generations={"hi": "greetings, traveler"},
        logprobs={("q", "a"): -1.2},
    )
    return DummyDifferentialBackend(base=base_responses, ft=finetuned_responses)
28 return DummyDifferentialBackend(base=base, ft=ft)
29
30
class TestViews:
    """Each view context manager exposes the responses seeded for that view."""

    def test_as_base_and_as_ft_yield_distinct_generations(
        self, backend: DummyDifferentialBackend
    ) -> None:
        # The same prompt must be answered differently by each view.
        with backend.as_base() as base_view:
            assert base_view.generate("hi", max_new_tokens=5) == "hello"
        with backend.as_finetuned() as ft_view:
            assert ft_view.generate("hi", max_new_tokens=5) == "greetings, traveler"

    def test_logprob_differs_between_modes(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base() as base_view:
            base_logprob = base_view.logprob_of("q", "a")
        with backend.as_finetuned() as ft_view:
            ft_logprob = ft_view.logprob_of("q", "a")
        # Each view reports exactly the value seeded for it in the fixture.
        assert base_logprob == -3.0
        assert ft_logprob == -1.2

    def test_missing_generation_raises_keyerror(self, backend: DummyDifferentialBackend) -> None:
        # A prompt with no seeded generation is an error, not a silent default.
        with backend.as_base() as base_view:
            with pytest.raises(KeyError, match="no canned generation"):
                base_view.generate("unconfigured", max_new_tokens=1)

    def test_missing_logprob_default(self, backend: DummyDifferentialBackend) -> None:
        # An unseeded (prompt, target) pair falls back to a fixed score.
        with backend.as_base() as base_view:
            fallback = base_view.logprob_of("nonexistent", "target")
        assert fallback == -10.0
55
56
class TestRollingLogprob:
    """Rolling log-probabilities synthesized when nothing was pre-seeded."""

    def test_synthesized_when_not_preseeded(self, backend: DummyDifferentialBackend) -> None:
        # The five-word text is expected to yield five tokens and four scored
        # positions, every one at the synthesized base value of -2.0.
        with backend.as_base() as base_view:
            rolling = base_view.rolling_logprob("a quick brown fox jumps")
        assert rolling.num_tokens == 5
        assert rolling.logprobs.size == 4
        assert (rolling.logprobs == -2.0).all()

    def test_ft_perplexity_lower_than_base(self, backend: DummyDifferentialBackend) -> None:
        sample = "a quick brown fox"
        with backend.as_base() as base_view:
            base_ppl = base_view.rolling_logprob(sample).perplexity
        with backend.as_finetuned() as ft_view:
            ft_ppl = ft_view.rolling_logprob(sample).perplexity
        # The synthesized fine-tuned view should be less perplexed by the text.
        assert ft_ppl < base_ppl
71 assert pf < pb # synthesized ft is less perplexed → lower PPL
72
73
class TestTokenDist:
    """Next-token distributions must depend on which view is active."""

    def test_dists_differ_between_modes(self, backend: DummyDifferentialBackend) -> None:
        prompt = "any prompt"
        with backend.as_base() as base_view:
            base_dist = base_view.next_token_dist(prompt)
        with backend.as_finetuned() as ft_view:
            ft_dist = ft_view.next_token_dist(prompt)
        # Identical arrays would mean both views share one distribution.
        same = np.array_equal(base_dist.logprobs, ft_dist.logprobs)
        assert not same
80 assert not np.array_equal(base_dist.logprobs, ft_dist.logprobs)
81
82
class TestInvariants:
    """Structural guarantees that every probe relies on."""

    def test_protocol_satisfaction(self, backend: DummyDifferentialBackend) -> None:
        # The backend, and each view it hands out, must pass the isinstance
        # checks for the interfaces probes are written against.
        assert isinstance(backend, DifferentialBackend)
        with backend.as_base() as view:
            assert isinstance(view, Model)
            assert isinstance(view, ScoringBackend)

    def test_nested_views_rejected(self, backend: DummyDifferentialBackend) -> None:
        # Only one view may be active at a time; entering a second is an error.
        with backend.as_base(), pytest.raises(RuntimeError, match="view already active"):
            with backend.as_finetuned():
                pass

    def test_sequential_views_fine(self, backend: DummyDifferentialBackend) -> None:
        # Must be able to re-enter after exiting — common pattern in probes.
        # Check the returned scores too, not just that the calls succeed: a
        # stale view (e.g. a re-entered base view still answering with the
        # fine-tuned responses) would otherwise pass silently.
        with backend.as_base() as b:
            assert b.logprob_of("q", "a") == -3.0
        with backend.as_finetuned() as f:
            assert f.logprob_of("q", "a") == -1.2
        with backend.as_base() as b:
            assert b.logprob_of("q", "a") == -3.0
102 b.logprob_of("q", "a")