@@ -117,6 +117,49 @@ class _DummyView: |
| 117 | 117 | ) |
| 118 | 118 | |
| 119 | 119 | |
| 120 | +class _InterpolatedView(_DummyView): |
| 121 | + """A dummy view where logprobs/dists are a lam-blend of base and ft. |
| 122 | + |
| 123 | + Used by :meth:`DummyDifferentialBackend.as_scaled_adapter`. |
| 124 | + Generation falls back to the ft view at lam>=0.5, base otherwise — |
| 125 | + rounded because the dummy backend's generations are canned strings |
| 126 | + with no notion of "how much". |
| 127 | + """ |
| 128 | + |
| 129 | + def __init__( |
| 130 | + self, |
| 131 | + base_responses: DummyResponses, |
| 132 | + ft_responses: DummyResponses, |
| 133 | + lam: float, |
| 134 | + ) -> None: |
| 135 | + super().__init__( |
| 136 | + "ft" if lam >= 0.5 else "base", ft_responses if lam >= 0.5 else base_responses |
| 137 | + ) |
| 138 | + self._base_r = base_responses |
| 139 | + self._ft_r = ft_responses |
| 140 | + self._lam = lam |
| 141 | + |
| 142 | + def logprob_of(self, prompt: str, completion: str) -> float: |
| 143 | + base_v = self._base_r.logprobs.get((prompt, completion), -10.0) |
| 144 | + ft_v = self._ft_r.logprobs.get((prompt, completion), -10.0) |
| 145 | + return (1 - self._lam) * base_v + self._lam * ft_v |
| 146 | + |
| 147 | + def next_token_dist(self, prompt: str, *, top_k: int = 256): # type: ignore[no-untyped-def] |
| 148 | + base_dist = _DummyView("base", self._base_r).next_token_dist(prompt, top_k=top_k) |
| 149 | + ft_dist = _DummyView("ft", self._ft_r).next_token_dist(prompt, top_k=top_k) |
| 150 | + # Both dists are on the same synthetic support when unseeded; blend |
| 151 | + # their logprobs via log-space linear interpolation, which is a |
| 152 | + # log-linear "tempered" mix and keeps normalization close enough. |
| 153 | + lam = self._lam |
| 154 | + blended_lp = (1 - lam) * base_dist.logprobs + lam * ft_dist.logprobs |
| 155 | + return type(base_dist)( |
| 156 | + token_ids=base_dist.token_ids, |
| 157 | + logprobs=blended_lp, |
| 158 | + vocab_size=base_dist.vocab_size, |
| 159 | + tail_logprob=base_dist.tail_logprob, |
| 160 | + ) |
| 161 | + |
| 162 | + |
| 120 | 163 | class DummyDifferentialBackend: |
| 121 | 164 | """Dummy implementation of |
| 122 | 165 | :class:`~dlm_sway.core.scoring.DifferentialBackend`. |
@@ -125,12 +168,19 @@ class DummyDifferentialBackend: |
| 125 | 168 | modes are mutually exclusive — the backend enforces that callers |
| 126 | 169 | exit one view before entering the other, catching bugs in probes |
| 127 | 170 | that hold a stale view across a toggle. |
| 171 | + |
| 172 | + Also implements |
| 173 | + :class:`~dlm_sway.core.scoring.ScalableDifferentialBackend` with a |
| 174 | + linear-blend between base and ft responses, so probes that need |
| 175 | + ``as_scaled_adapter`` (N2 AdapterAblation) are unit-testable. |
| 128 | 176 | """ |
| 129 | 177 | |
| 130 | 178 | def __init__(self, *, base: DummyResponses, ft: DummyResponses) -> None: |
| 179 | + self._base_r = base |
| 180 | + self._ft_r = ft |
| 131 | 181 | self._base = _DummyView("base", base) |
| 132 | 182 | self._ft = _DummyView("ft", ft) |
| 133 | | - self._active: Mode | None = None |
| 183 | + self._active: str | None = None |
| 134 | 184 | |
| 135 | 185 | @contextmanager |
| 136 | 186 | def as_base(self) -> Iterator[_DummyView]: |
@@ -148,7 +198,15 @@ class DummyDifferentialBackend: |
| 148 | 198 | finally: |
| 149 | 199 | self._exit() |
| 150 | 200 | |
| 151 | | - def _enter(self, mode: Mode) -> None: |
| 201 | + @contextmanager |
| 202 | + def as_scaled_adapter(self, lam: float) -> Iterator[_DummyView]: |
| 203 | + self._enter(f"scaled({lam})") |
| 204 | + try: |
| 205 | + yield _InterpolatedView(self._base_r, self._ft_r, lam) |
| 206 | + finally: |
| 207 | + self._exit() |
| 208 | + |
| 209 | + def _enter(self, mode: str) -> None: |
| 152 | 210 | if self._active is not None: |
| 153 | 211 | raise RuntimeError( |
| 154 | 212 | f"DifferentialBackend view already active ({self._active!r}); " |