tenseleyflow/sway / bbaff71

Browse files

sway(backends): dummy as_scaled_adapter via linear blend of base and ft

Authored by espadonne
SHA
bbaff719c57c4d7e3d39b426af69d7774ee1bbc1
Parents
4cde232
Tree
3b4c448

1 changed file

StatusFile+-
M src/dlm_sway/backends/dummy.py 60 2
src/dlm_sway/backends/dummy.pymodified
@@ -117,6 +117,49 @@ class _DummyView:
117117
         )
118118
 
119119
 
120
class _InterpolatedView(_DummyView):
    """Dummy view whose scores lie between the base and ft canned responses.

    Built by :meth:`DummyDifferentialBackend.as_scaled_adapter`. ``lam`` is
    the blend weight: ``logprob_of`` and ``next_token_dist`` both return
    ``(1 - lam) * base + lam * ft`` computed on log-probabilities.

    Anything not overridden here (e.g. generation) is inherited from a
    plain :class:`_DummyView` that is initialised with the ft responses
    when ``lam >= 0.5`` and with the base responses otherwise. The choice
    is rounded because the dummy backend's generations are canned strings
    with no notion of "how much".
    """

    def __init__(
        self,
        base_responses: DummyResponses,
        ft_responses: DummyResponses,
        lam: float,
    ) -> None:
        """Store both response sets and the blend weight ``lam``.

        The parent view is set up as the "ft" view when ``lam >= 0.5`` and
        as the "base" view otherwise.
        """
        if lam >= 0.5:
            super().__init__("ft", ft_responses)
        else:
            super().__init__("base", base_responses)
        self._lam = lam
        self._base_r = base_responses
        self._ft_r = ft_responses

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Return the lam-weighted mix of the base and ft log-probabilities.

        A ``(prompt, completion)`` pair missing from either side's
        ``logprobs`` mapping contributes ``-10.0`` for that side.
        """
        key = (prompt, completion)
        missing = -10.0
        base_lp = self._base_r.logprobs.get(key, missing)
        ft_lp = self._ft_r.logprobs.get(key, missing)
        weight = self._lam
        return (1 - weight) * base_lp + weight * ft_lp

    def next_token_dist(self, prompt: str, *, top_k: int = 256):  # type: ignore[no-untyped-def]
        """Return a next-token distribution blended between base and ft.

        Fresh ``_DummyView`` instances are queried for each side, and their
        ``logprobs`` are linearly interpolated in log space. The result is
        an instance of the same type as the base distribution; its
        ``token_ids``, ``vocab_size`` and ``tail_logprob`` are copied from
        the base distribution unchanged.
        """
        base_side = _DummyView("base", self._base_r).next_token_dist(prompt, top_k=top_k)
        ft_side = _DummyView("ft", self._ft_r).next_token_dist(prompt, top_k=top_k)
        # Per the original author's note: both sides share the same
        # synthetic support when unseeded, so an element-wise mix of the
        # logprobs is a log-linear ("tempered") blend. The result is not
        # renormalised; it is only expected to stay close to normalised.
        weight = self._lam
        mixed = (1 - weight) * base_side.logprobs + weight * ft_side.logprobs
        dist_cls = type(base_side)
        return dist_cls(
            token_ids=base_side.token_ids,
            logprobs=mixed,
            vocab_size=base_side.vocab_size,
            tail_logprob=base_side.tail_logprob,
        )
161
+
162
+
120163
 class DummyDifferentialBackend:
121164
     """Dummy implementation of
122165
     :class:`~dlm_sway.core.scoring.DifferentialBackend`.
@@ -125,12 +168,19 @@ class DummyDifferentialBackend:
125168
     modes are mutually exclusive — the backend enforces that callers
126169
     exit one view before entering the other, catching bugs in probes
127170
     that hold a stale view across a toggle.
171
+
172
+    Also implements
173
+    :class:`~dlm_sway.core.scoring.ScalableDifferentialBackend` with a
174
+    linear-blend between base and ft responses, so probes that need
175
+    ``as_scaled_adapter`` (N2 AdapterAblation) are unit-testable.
128176
     """
129177
 
130178
     def __init__(self, *, base: DummyResponses, ft: DummyResponses) -> None:
179
+        self._base_r = base
180
+        self._ft_r = ft
131181
         self._base = _DummyView("base", base)
132182
         self._ft = _DummyView("ft", ft)
133
-        self._active: Mode | None = None
183
+        self._active: str | None = None
134184
 
135185
     @contextmanager
136186
     def as_base(self) -> Iterator[_DummyView]:
@@ -148,7 +198,15 @@ class DummyDifferentialBackend:
148198
         finally:
149199
             self._exit()
150200
 
151
-    def _enter(self, mode: Mode) -> None:
201
+    @contextmanager
202
+    def as_scaled_adapter(self, lam: float) -> Iterator[_DummyView]:
203
+        self._enter(f"scaled({lam})")
204
+        try:
205
+            yield _InterpolatedView(self._base_r, self._ft_r, lam)
206
+        finally:
207
+            self._exit()
208
+
209
+    def _enter(self, mode: str) -> None:
152210
         if self._active is not None:
153211
             raise RuntimeError(
154212
                 f"DifferentialBackend view already active ({self._active!r}); "