tenseleyflow/sway / 1f0f0e8

Browse files

sway(backends): DummyDifferentialBackend for unit tests

Authored by espadonne
SHA
1f0f0e8bdc05b392a933bea6649c19adf4f30842
Parents
0371e46
Tree
3455588

2 changed files

Status · File · Lines added · Lines removed
A src/dlm_sway/backends/dummy.py 160 0
A tests/unit/test_backend_dummy.py 102 0
src/dlm_sway/backends/dummy.py (added)
@@ -0,0 +1,160 @@
1
+"""In-memory backend for unit tests.
2
+
3
+Deterministic, torchless, and trivially fast. Tests pass canned responses
4
+and canned score tables keyed by ``(mode, prompt, completion)``. The same
5
+backend instance serves as both ``as_base`` and ``as_finetuned`` — it
6
+switches an internal mode flag.
7
+
8
+Use it to drive every probe's unit test without loading a real model.
9
+For integration tests against a real PEFT adapter, see
10
+:class:`~dlm_sway.backends.hf.HuggingFaceDifferentialBackend`.
11
+"""
12
+
13
+from __future__ import annotations
14
+
15
+import math
16
+from collections.abc import Iterator
17
+from contextlib import contextmanager
18
+from dataclasses import dataclass, field
19
+from typing import Literal
20
+
21
+import numpy as np
22
+
23
+from dlm_sway.core.scoring import RollingLogprob, TokenDist
24
+
25
+Mode = Literal["base", "ft"]
26
+
27
+
28
+@dataclass(slots=True)
29
+class DummyResponses:
30
+    """Canned data for one mode (base or ft).
31
+
32
+    Callers populate one of these per mode and hand both to
33
+    :class:`DummyDifferentialBackend`.
34
+    """
35
+
36
+    generations: dict[str, str] = field(default_factory=dict)
37
+    """Prompt → canned completion. Lookup is exact-match."""
38
+    logprobs: dict[tuple[str, str], float] = field(default_factory=dict)
39
+    """``(prompt, completion) → sum logprob``. Default ``-10.0`` if missing."""
40
+    rolling: dict[str, RollingLogprob] = field(default_factory=dict)
41
+    """Text → canned :class:`RollingLogprob`."""
42
+    token_dists: dict[str, TokenDist] = field(default_factory=dict)
43
+    """Prompt → canned :class:`TokenDist`."""
44
+
45
+
46
+class _DummyView:
47
+    """The per-mode view yielded by ``as_base`` / ``as_finetuned``.
48
+
49
+    Implements :class:`~dlm_sway.core.model.Model` *and*
50
+    :class:`~dlm_sway.core.scoring.ScoringBackend` — i.e. the
51
+    ``ScoringModel`` intersection.
52
+    """
53
+
54
+    def __init__(self, mode: Mode, responses: DummyResponses) -> None:
55
+        self.id = mode
56
+        self._mode: Mode = mode
57
+        self._r = responses
58
+
59
+    # -- Model ---------------------------------------------------------
60
+    def generate(
61
+        self,
62
+        prompt: str,
63
+        *,
64
+        max_new_tokens: int,
65
+        temperature: float = 0.0,
66
+        top_p: float = 1.0,
67
+        seed: int = 0,
68
+    ) -> str:
69
+        del max_new_tokens, temperature, top_p, seed  # canned; decoding is trivial.
70
+        try:
71
+            return self._r.generations[prompt]
72
+        except KeyError as exc:
73
+            raise KeyError(
74
+                f"dummy backend ({self._mode}): no canned generation for prompt {prompt!r}"
75
+            ) from exc
76
+
77
+    def close(self) -> None:
78
+        return None
79
+
80
+    # -- ScoringBackend ------------------------------------------------
81
+    def logprob_of(self, prompt: str, completion: str) -> float:
82
+        return self._r.logprobs.get((prompt, completion), -10.0)
83
+
84
+    def rolling_logprob(self, text: str) -> RollingLogprob:
85
+        if text in self._r.rolling:
86
+            return self._r.rolling[text]
87
+        # Synthesize a plausible rolling logprob so probes that just
88
+        # want a non-trivial value work without per-text configuration.
89
+        tokens = text.split()
90
+        n = max(len(tokens), 1)
91
+        per_tok = -2.0 if self._mode == "base" else -1.5
92
+        return RollingLogprob(
93
+            token_ids=np.arange(n, dtype=np.int64),
94
+            logprobs=np.full(max(n - 1, 0), per_tok, dtype=np.float32),
95
+            num_tokens=n,
96
+            total_logprob=per_tok * max(n - 1, 0),
97
+        )
98
+
99
+    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
100
+        del top_k
101
+        if prompt in self._r.token_dists:
102
+            return self._r.token_dists[prompt]
103
+        # Synthesize a sharp base / broad ft distribution so divergence
104
+        # probes see a non-zero signal without hand-rolled data.
105
+        vocab = 1000
106
+        k = 8
107
+        if self._mode == "base":
108
+            lp = np.array([-0.1] + [-5.0] * (k - 1), dtype=np.float32)
109
+        else:
110
+            # More uniform mass across the top-k tokens.
111
+            lp = np.full(k, -math.log(k), dtype=np.float32)
112
+        return TokenDist(
113
+            token_ids=np.arange(k, dtype=np.int64),
114
+            logprobs=lp,
115
+            vocab_size=vocab,
116
+            tail_logprob=math.log1p(-float(np.exp(lp).sum())) if np.exp(lp).sum() < 1 else 0.0,
117
+        )
118
+
119
+
120
class DummyDifferentialBackend:
    """Dummy implementation of
    :class:`~dlm_sway.core.scoring.DifferentialBackend`.

    Takes one :class:`DummyResponses` per mode at construction time. The
    two views are mutually exclusive: callers must exit the active view
    before entering the other, which surfaces bugs in probes that hold a
    stale view across a toggle.
    """

    def __init__(self, *, base: DummyResponses, ft: DummyResponses) -> None:
        self._base = _DummyView("base", base)
        self._ft = _DummyView("ft", ft)
        self._active: Mode | None = None

    @contextmanager
    def as_base(self) -> Iterator[_DummyView]:
        """Yield the base-mode view; rejects entry while another view is open."""
        with self._exclusive("base"):
            yield self._base

    @contextmanager
    def as_finetuned(self) -> Iterator[_DummyView]:
        """Yield the ft-mode view; rejects entry while another view is open."""
        with self._exclusive("ft"):
            yield self._ft

    @contextmanager
    def _exclusive(self, mode: Mode) -> Iterator[None]:
        # Shared guard: exactly one view may be active at a time.
        if self._active is not None:
            raise RuntimeError(
                f"DifferentialBackend view already active ({self._active!r}); "
                f"exit the current view before entering {mode!r}."
            )
        self._active = mode
        try:
            yield
        finally:
            self._active = None
tests/unit/test_backend_dummy.py (added)
@@ -0,0 +1,102 @@
1
+"""Tests for :class:`dlm_sway.backends.dummy.DummyDifferentialBackend`.
2
+
3
+The dummy backend is used by every downstream probe unit test, so it
4
+gets a thorough own-right test here. Also verifies the view-exclusion
5
+invariant that catches stale-view bugs in probes.
6
+"""
7
+
8
+from __future__ import annotations
9
+
10
+import numpy as np
11
+import pytest
12
+
13
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
14
+from dlm_sway.core.model import Model
15
+from dlm_sway.core.scoring import DifferentialBackend, ScoringBackend
16
+
17
+
18
@pytest.fixture
def backend() -> DummyDifferentialBackend:
    """Backend with distinct canned data per mode, shared by the tests below."""
    base_responses = DummyResponses(
        generations={"hi": "hello"},
        logprobs={("q", "a"): -3.0},
    )
    ft_responses = DummyResponses(
        generations={"hi": "greetings, traveler"},
        logprobs={("q", "a"): -1.2},
    )
    return DummyDifferentialBackend(base=base_responses, ft=ft_responses)
29
+
30
+
31
class TestViews:
    """Per-mode generation and logprob lookups."""

    def test_as_base_and_as_ft_yield_distinct_generations(
        self, backend: DummyDifferentialBackend
    ) -> None:
        with backend.as_base() as base_view:
            base_out = base_view.generate("hi", max_new_tokens=5)
        with backend.as_finetuned() as ft_view:
            ft_out = ft_view.generate("hi", max_new_tokens=5)
        assert base_out == "hello"
        assert ft_out == "greetings, traveler"

    def test_logprob_differs_between_modes(self, backend: DummyDifferentialBackend) -> None:
        scores: dict[str, float] = {}
        with backend.as_base() as view:
            scores["base"] = view.logprob_of("q", "a")
        with backend.as_finetuned() as view:
            scores["ft"] = view.logprob_of("q", "a")
        assert scores == {"base": -3.0, "ft": -1.2}

    def test_missing_generation_raises_keyerror(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base() as view:
            with pytest.raises(KeyError, match="no canned generation"):
                view.generate("unconfigured", max_new_tokens=1)

    def test_missing_logprob_default(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base() as view:
            fallback = view.logprob_of("nonexistent", "target")
        assert fallback == -10.0
55
+
56
+
57
class TestRollingLogprob:
    """Synthesized rolling-logprob behavior of the dummy views."""

    def test_synthesized_when_not_preseeded(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base() as view:
            result = view.rolling_logprob("a quick brown fox jumps")
        assert result.num_tokens == 5
        assert result.logprobs.size == 4  # first token carries no logprob
        assert np.all(result.logprobs == -2.0)

    def test_ft_perplexity_lower_than_base(self, backend: DummyDifferentialBackend) -> None:
        text = "a quick brown fox"
        with backend.as_base() as view:
            base_ppl = view.rolling_logprob(text).perplexity
        with backend.as_finetuned() as view:
            ft_ppl = view.rolling_logprob(text).perplexity
        # The synthesized ft view assigns higher per-token logprobs,
        # so its perplexity must come out lower.
        assert ft_ppl < base_ppl
72
+
73
+
74
class TestTokenDist:
    """Synthesized next-token distributions of the dummy views."""

    def test_dists_differ_between_modes(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base() as view:
            base_logprobs = view.next_token_dist("any prompt").logprobs
        with backend.as_finetuned() as view:
            ft_logprobs = view.next_token_dist("any prompt").logprobs
        assert not np.array_equal(base_logprobs, ft_logprobs)
81
+
82
+
83
class TestInvariants:
    """Protocol conformance and the one-active-view invariant."""

    def test_protocol_satisfaction(self, backend: DummyDifferentialBackend) -> None:
        assert isinstance(backend, DifferentialBackend)
        with backend.as_base() as view:
            assert isinstance(view, Model)
            assert isinstance(view, ScoringBackend)

    def test_nested_views_rejected(self, backend: DummyDifferentialBackend) -> None:
        with backend.as_base():
            with pytest.raises(RuntimeError, match="view already active"):
                with backend.as_finetuned():
                    pass

    def test_sequential_views_fine(self, backend: DummyDifferentialBackend) -> None:
        # Re-entering after a clean exit must work — probes toggle often.
        for enter_view in (backend.as_base, backend.as_finetuned, backend.as_base):
            with enter_view() as view:
                view.logprob_of("q", "a")