| 1 | """In-memory backend for unit tests. |
| 2 | |
Deterministic, torchless, and trivially fast. Tests supply canned
completions and score tables, keyed per mode by prompt or by
``(prompt, completion)``, and the backend replays them. One backend
instance serves both ``as_base`` and ``as_finetuned``, yielding a
per-mode view and enforcing that only one is active at a time.
| 7 | |
| 8 | Use it to drive every probe's unit test without loading a real model. |
| 9 | For integration tests against a real PEFT adapter, see |
| 10 | :class:`~dlm_sway.backends.hf.HuggingFaceDifferentialBackend`. |
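
Example (a minimal sketch; every name below is defined in this module)::

    >>> base = DummyResponses(generations={"2+2=": " 5"})
    >>> ft = DummyResponses(generations={"2+2=": " 4"})
    >>> backend = DummyDifferentialBackend(base=base, ft=ft)
    >>> with backend.as_base() as view:
    ...     view.generate("2+2=", max_new_tokens=4)
    ' 5'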
| 11 | """ |
| 12 | |
| 13 | from __future__ import annotations |
| 14 | |
| 15 | import hashlib |
| 16 | import math |
| 17 | from collections.abc import Iterator, Sequence |
| 18 | from contextlib import contextmanager |
| 19 | from dataclasses import dataclass, field |
| 20 | from typing import Literal |
| 21 | |
| 22 | import numpy as np |
| 23 | |
| 24 | from dlm_sway.backends._instrumentation import BackendInstrumentation |
| 25 | from dlm_sway.core.scoring import RollingLogprob, TokenDist |
| 26 | |
| 27 | Mode = Literal["base", "ft"] |
| 28 | |
| 29 | |
| 30 | @dataclass(slots=True) |
| 31 | class DummyResponses: |
| 32 | """Canned data for one mode (base or ft). |
| 33 | |
| 34 | Callers populate one of these per mode and hand both to |
| 35 | :class:`DummyDifferentialBackend`. |
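
    Illustrative population (the values here are arbitrary)::

        DummyResponses(
            generations={"2+2=": " 4"},
            logprobs={("2+2=", " 4"): -0.7},
        )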
| 36 | """ |
| 37 | |
| 38 | generations: dict[str, str] = field(default_factory=dict) |
| 39 | """Prompt → canned completion. Lookup is exact-match.""" |
| 40 | logprobs: dict[tuple[str, str], float] = field(default_factory=dict) |
| 41 | """``(prompt, completion) → sum logprob``. Default ``-10.0`` if missing.""" |
| 42 | rolling: dict[str, RollingLogprob] = field(default_factory=dict) |
| 43 | """Text → canned :class:`RollingLogprob`.""" |
| 44 | token_dists: dict[str, TokenDist] = field(default_factory=dict) |
| 45 | """Prompt → canned :class:`TokenDist`.""" |
| 46 | |
| 47 | |
| 48 | class _DummyView: |
| 49 | """The per-mode view yielded by ``as_base`` / ``as_finetuned``. |
| 50 | |
| 51 | Implements :class:`~dlm_sway.core.model.Model` *and* |
| 52 | :class:`~dlm_sway.core.scoring.ScoringBackend` — i.e. the |
| 53 | ``ScoringModel`` intersection. |
| 54 | """ |
| 55 | |
| 56 | def __init__( |
| 57 | self, |
| 58 | mode: Mode, |
| 59 | responses: DummyResponses, |
| 60 | inst: BackendInstrumentation | None = None, |
| 61 | ) -> None: |
| 62 | # ``id`` is what surfaces in cache keys and trace events; widened |
| 63 | # to ``str`` so scaled / null views can override it with |
| 64 | # e.g. ``"scaled_0.50"`` / ``"null_42"`` without a cast dance. |
| 65 | self.id: str = mode |
| 66 | self._mode: Mode = mode |
| 67 | self._r = responses |
| 68 | # Private instrumentation when the caller didn't supply one — |
| 69 | # keeps direct ``_DummyView("base", DummyResponses())`` |
| 70 | # constructions (common in probe tests that grab a view |
| 71 | # without going through the differential backend) working |
| 72 | # transparently. |
| 73 | self._inst: BackendInstrumentation = inst if inst is not None else BackendInstrumentation() |
| 74 | |
| 75 | # -- Model --------------------------------------------------------- |
| 76 | def generate( |
| 77 | self, |
| 78 | prompt: str, |
| 79 | *, |
| 80 | max_new_tokens: int, |
| 81 | temperature: float = 0.0, |
| 82 | top_p: float = 1.0, |
| 83 | seed: int = 0, |
| 84 | ) -> str: |
| 85 | del max_new_tokens, temperature, top_p, seed # canned; decoding is trivial. |
| 86 | try: |
| 87 | return self._r.generations[prompt] |
| 88 | except KeyError as exc: |
| 89 | raise KeyError( |
| 90 | f"dummy backend ({self._mode}): no canned generation for prompt {prompt!r}" |
| 91 | ) from exc |
| 92 | |
| 93 | def close(self) -> None: |
| 94 | return None |
| 95 | |
| 96 | # -- ScoringBackend ------------------------------------------------ |
| 97 | def logprob_of(self, prompt: str, completion: str) -> float: |
| 98 | return self._inst.cached( |
| 99 | "logprob_of", |
| 100 | self.id, |
| 101 | f"{prompt}\x00{completion}", |
| 102 | 0, |
| 103 | lambda: self._r.logprobs.get((prompt, completion), -10.0), |
| 104 | ) |
| 105 | |
| 106 | def rolling_logprob(self, text: str) -> RollingLogprob: |
| 107 | return self._inst.cached( |
| 108 | "rolling_logprob", |
| 109 | self.id, |
| 110 | text, |
| 111 | 0, |
| 112 | lambda: self._compute_rolling_logprob(text), |
| 113 | ) |
| 114 | |
| 115 | def _compute_rolling_logprob(self, text: str) -> RollingLogprob: |
| 116 | if text in self._r.rolling: |
| 117 | return self._r.rolling[text] |
| 118 | # Synthesize a plausible rolling logprob so probes that just |
| 119 | # want a non-trivial value work without per-text configuration. |
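        # e.g. base mode, text "a b c": n=3, logprobs=[-2.0, -2.0],
        # total_logprob=-4.0; a single-token text gets an empty logprob
        # array and total 0.0.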
| 120 | tokens = text.split() |
| 121 | n = max(len(tokens), 1) |
| 122 | per_tok = -2.0 if self._mode == "base" else -1.5 |
| 123 | return RollingLogprob( |
| 124 | token_ids=np.arange(n, dtype=np.int64), |
| 125 | logprobs=np.full(max(n - 1, 0), per_tok, dtype=np.float32), |
| 126 | num_tokens=n, |
| 127 | total_logprob=per_tok * max(n - 1, 0), |
| 128 | ) |
| 129 | |
| 130 | def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist: |
| 131 | return self._inst.cached( |
| 132 | "next_token_dist", |
| 133 | self.id, |
| 134 | prompt, |
| 135 | top_k, |
| 136 | lambda: self._compute_next_token_dist(prompt, top_k=top_k), |
| 137 | ) |
| 138 | |
| 139 | def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]: |
| 140 | # Dummy backend has no real forward to batch against. Loop |
| 141 | # over ``self.next_token_dist`` so subclasses (``_NullView``, |
| 142 | # ``_InterpolatedView``) that override the per-prompt method |
| 143 | # get their per-prompt semantics preserved on the batched |
| 144 | # code path. Batching counters stay at zero on this backend — |
| 145 | # real amortization lives in the HF backend's override. |
| 146 | return [self.next_token_dist(p, top_k=top_k) for p in prompts] |
| 147 | |
| 148 | def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist: |
| 149 | del top_k |
| 150 | if prompt in self._r.token_dists: |
| 151 | return self._r.token_dists[prompt] |
| 152 | # Synthesize a sharp base / broad ft distribution so divergence |
| 153 | # probes see a non-zero signal without hand-rolled data. |
| 154 | vocab = 1000 |
| 155 | k = 8 |
| 156 | if self._mode == "base": |
| 157 | lp = np.array([-0.1] + [-5.0] * (k - 1), dtype=np.float32) |
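            # exp(-0.1) ~ 0.905 plus 7 * exp(-5) ~ 0.047: top-k mass ~ 0.952,
            # so ~0.048 lands in the tail bucket computed below.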
| 158 | else: |
| 159 | # More uniform mass across the top-k tokens. A *literally* |
| 160 | # uniform distribution looks like a broken lm_head to |
| 161 | # ``_divergence``'s degenerate-distribution guard, so we |
| 162 | # add a tiny monotonic perturbation — the spread stays |
| 163 | # small enough that the ft dist still looks "broad" to |
| 164 | # KL/JS but clears the 1e-9 uniformity threshold a real |
| 165 | # model would also clear via fp32 accumulation noise. |
| 166 | lp = np.full(k, -math.log(k), dtype=np.float32) |
| 167 | lp += np.linspace(-1e-4, 1e-4, k, dtype=np.float32) |
| 168 | # B6: tail_logprob=None means "no measurable tail" (k covers vocab |
| 169 | # or residual underflowed); reserve floats for measurable mass. |
| 170 | residual = 1.0 - float(np.exp(lp).sum()) |
| 171 | tail_lp = math.log(residual) if residual > 1e-12 else None |
| 172 | return TokenDist( |
| 173 | token_ids=np.arange(k, dtype=np.int64), |
| 174 | logprobs=lp, |
| 175 | vocab_size=vocab, |
| 176 | tail_logprob=tail_lp, |
| 177 | ) |
| 178 | |
| 179 | |
| 180 | class _NullView(_DummyView): |
| 181 | """A dummy view that perturbs the base distribution with seeded noise. |
| 182 | |
| 183 | Used by :meth:`DummyDifferentialBackend.as_null_adapter`. The |
| 184 | perturbation is small (matches an ``init_scale=0.02`` adapter) so |
| 185 | the null-vs-base divergence stays well below real-adapter territory |
| 186 | in probe tests. |
| 187 | |
| 188 | ``rank_scale`` simulates changing the effective LoRA rank: the |
| 189 | output variance of ``A·B`` scales linearly with rank, so the noise |
| 190 | std carries a ``sqrt(rank_scale)`` factor. Default 1.0 preserves |
| 191 | pre-S10 behavior exactly. |
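
    E.g. ``rank_scale=4.0`` doubles the noise std:
    ``0.02 * sqrt(4.0) = 0.04`` at the default ``init_scale``.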
| 192 | """ |
| 193 | |
| 194 | def __init__( |
| 195 | self, |
| 196 | base_responses: DummyResponses, |
| 197 | seed: int, |
| 198 | init_scale: float, |
| 199 | rank_scale: float = 1.0, |
| 200 | ) -> None: |
| 201 | super().__init__("base", base_responses) |
| 202 | self._seed = seed |
| 203 | self._init_scale = init_scale |
| 204 | self._rank_scale = rank_scale |
| 205 | |
| 206 | def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist: |
| 207 | base_dist = super().next_token_dist(prompt, top_k=top_k) |
| 208 | # F08 — Python's built-in ``hash(str)`` is salted per-process via |
| 209 | # ``PYTHONHASHSEED``, so the same ``(self._seed, prompt)`` pair |
| 210 | # produced different RNG streams across interpreter invocations. |
| 211 | # That violated the README's determinism contract and meant the |
| 212 | # null-stats disk cache (~/.dlm-sway/null-stats/) could serve |
| 213 | # stale values on a restart. ``hashlib.md5`` is stable; taking |
| 214 | # the first 8 hex digits keeps the arithmetic in 32-bit range. |
| 215 | prompt_hash = int(hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8], 16) |
| 216 | rng = np.random.default_rng(self._seed + prompt_hash % 1_000_003) |
| 217 | effective_scale = self._init_scale * math.sqrt(self._rank_scale) |
| 218 | noise = rng.normal(0.0, effective_scale, size=base_dist.logprobs.shape).astype(np.float32) |
| 219 | new_lp = base_dist.logprobs + noise |
        # Re-normalize so the noisy top-k slice plus the (unchanged) tail
        # mass still sum to one: softmax the noisy logprobs back onto the
        # probability mass the base view assigned to its top-k slice.
        topk_mass = (
            1.0 if base_dist.tail_logprob is None else 1.0 - math.exp(base_dist.tail_logprob)
        )
        max_lp = new_lp.max()
        new_probs = np.exp(new_lp - max_lp)
        new_probs *= topk_mass / new_probs.sum()
| 224 | return TokenDist( |
| 225 | token_ids=base_dist.token_ids, |
| 226 | logprobs=np.log(new_probs).astype(np.float32), |
| 227 | vocab_size=base_dist.vocab_size, |
| 228 | tail_logprob=base_dist.tail_logprob, |
| 229 | ) |
| 230 | |
| 231 | |
| 232 | class _InterpolatedView(_DummyView): |
| 233 | """A dummy view where logits/dists are a lam-blend of base and ft. |
| 234 | |
| 235 | Used by :meth:`DummyDifferentialBackend.as_scaled_adapter`. |
| 236 | Generation falls back to the ft view at lam>=0.5, base otherwise — |
| 237 | rounded because the dummy backend's generations are canned strings |
| 238 | with no notion of "how much". |
| 239 | """ |
| 240 | |
| 241 | def __init__( |
| 242 | self, |
| 243 | base_responses: DummyResponses, |
| 244 | ft_responses: DummyResponses, |
| 245 | lam: float, |
| 246 | ) -> None: |
| 247 | super().__init__( |
| 248 | "ft" if lam >= 0.5 else "base", ft_responses if lam >= 0.5 else base_responses |
| 249 | ) |
| 250 | self._base_r = base_responses |
| 251 | self._ft_r = ft_responses |
| 252 | self._lam = lam |
| 253 | |
| 254 | def logprob_of(self, prompt: str, completion: str) -> float: |
| 255 | base_v = self._base_r.logprobs.get((prompt, completion), -10.0) |
| 256 | ft_v = self._ft_r.logprobs.get((prompt, completion), -10.0) |
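        # Linear blend, e.g. lam=0.25, base=-10.0, ft=-2.0:
        # 0.75 * -10.0 + 0.25 * -2.0 = -8.0.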
| 257 | return (1 - self._lam) * base_v + self._lam * ft_v |
| 258 | |
    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        base_dist = _DummyView("base", self._base_r).next_token_dist(prompt, top_k=top_k)
        ft_dist = _DummyView("ft", self._ft_r).next_token_dist(prompt, top_k=top_k)
        # When neither mode has a canned ``TokenDist`` for the prompt, both
        # synthesized dists share the same support; blend their logprobs via
        # log-space linear interpolation, a log-linear "tempered" mix that
        # keeps normalization close enough for test purposes.
        lam = self._lam
        blended_lp = (1 - lam) * base_dist.logprobs + lam * ft_dist.logprobs
        return TokenDist(
            token_ids=base_dist.token_ids,
            logprobs=blended_lp,
            vocab_size=base_dist.vocab_size,
            tail_logprob=base_dist.tail_logprob,
        )
| 273 | |
| 274 | |
| 275 | class DummyDifferentialBackend: |
| 276 | """Dummy implementation of |
| 277 | :class:`~dlm_sway.core.scoring.DifferentialBackend`. |
| 278 | |
| 279 | Construction takes one :class:`DummyResponses` per mode. The two |
| 280 | modes are mutually exclusive — the backend enforces that callers |
| 281 | exit one view before entering the other, catching bugs in probes |
| 282 | that hold a stale view across a toggle. |
| 283 | |
| 284 | Dummy declares ``safe_for_concurrent_views = False`` to mirror the |
| 285 | shipped backends' posture; tests that want to exercise the concurrent |
| 286 | scheduling path can subclass and set it ``True``. |
| 287 | |
| 288 | Also implements |
| 289 | :class:`~dlm_sway.core.scoring.ScalableDifferentialBackend` with a |
| 290 | linear-blend between base and ft responses, so probes that need |
| 291 | ``as_scaled_adapter`` (N2 AdapterAblation) are unit-testable. |
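
    Entering a second view before exiting the first raises::

        >>> backend = DummyDifferentialBackend(
        ...     base=DummyResponses(), ft=DummyResponses()
        ... )
        >>> with backend.as_base():
        ...     with backend.as_finetuned():
        ...         pass
        Traceback (most recent call last):
            ...
        RuntimeError: DifferentialBackend view already active ('base'); exit the current view before entering 'ft'.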
| 292 | """ |
| 293 | |
| 294 | safe_for_concurrent_views: bool = False |
| 295 | |
| 296 | def __init__(self, *, base: DummyResponses, ft: DummyResponses) -> None: |
| 297 | self._base_r = base |
| 298 | self._ft_r = ft |
| 299 | # Sprint 07: one shared cache + trace + stats instance that |
| 300 | # every view yielded by this backend reads/writes. Tests can |
| 301 | # peek at ``backend._inst.stats`` to assert cache behavior. |
| 302 | self._inst = BackendInstrumentation() |
| 303 | self._base = _DummyView("base", base, inst=self._inst) |
| 304 | self._ft = _DummyView("ft", ft, inst=self._inst) |
| 305 | self._active: str | None = None |
| 306 | |
| 307 | @contextmanager |
| 308 | def as_base(self) -> Iterator[_DummyView]: |
| 309 | self._enter("base") |
| 310 | try: |
| 311 | yield self._base |
| 312 | finally: |
| 313 | self._exit() |
| 314 | |
| 315 | @contextmanager |
| 316 | def as_finetuned(self) -> Iterator[_DummyView]: |
| 317 | self._enter("ft") |
| 318 | try: |
| 319 | yield self._ft |
| 320 | finally: |
| 321 | self._exit() |
| 322 | |
| 323 | @contextmanager |
| 324 | def as_scaled_adapter(self, lam: float) -> Iterator[_DummyView]: |
| 325 | self._enter(f"scaled({lam})") |
| 326 | try: |
| 327 | view = _InterpolatedView(self._base_r, self._ft_r, lam) |
| 328 | view._inst = self._inst |
| 329 | view.id = f"scaled_{lam:.2f}" |
| 330 | yield view |
| 331 | finally: |
| 332 | self._exit() |
| 333 | |
| 334 | @contextmanager |
| 335 | def as_null_adapter( |
| 336 | self, |
| 337 | seed: int, |
| 338 | *, |
| 339 | init_scale: float = 0.02, |
| 340 | rank_scale: float = 1.0, |
| 341 | ) -> Iterator[_DummyView]: |
| 342 | if rank_scale <= 0.0 or not math.isfinite(rank_scale): |
| 343 | raise ValueError(f"rank_scale must be positive and finite; got {rank_scale!r}") |
| 344 | label = f"null({seed})" if rank_scale == 1.0 else f"null({seed},rank={rank_scale:.2f})" |
| 345 | view_id = f"null_{seed}" if rank_scale == 1.0 else f"null_{seed}_rank{rank_scale:.2f}" |
| 346 | self._enter(label) |
| 347 | try: |
| 348 | view = _NullView(self._base_r, seed=seed, init_scale=init_scale, rank_scale=rank_scale) |
| 349 | view._inst = self._inst |
| 350 | view.id = view_id |
| 351 | yield view |
| 352 | finally: |
| 353 | self._exit() |
| 354 | |
| 355 | def preflight_finite_check(self) -> tuple[bool, str]: |
| 356 | """Smoke a single forward pass per view; reject non-finite logits. |
| 357 | |
| 358 | For the dummy backend the canned data is finite by construction |
| 359 | unless tests deliberately seed NaN-laden ``TokenDist`` entries — |
| 360 | which is exactly what S01 tests do to verify the runner gate. |
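
        With empty canned data the synthesized dists are finite, so the
        check passes::

            >>> backend = DummyDifferentialBackend(
            ...     base=DummyResponses(), ft=DummyResponses()
            ... )
            >>> backend.preflight_finite_check()
            (True, '')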
| 361 | """ |
| 362 | prompt = "preflight" |
| 363 | try: |
| 364 | with self.as_base() as base_view: |
| 365 | base_dist = base_view.next_token_dist(prompt, top_k=8) |
| 366 | with self.as_finetuned() as ft_view: |
| 367 | ft_dist = ft_view.next_token_dist(prompt, top_k=8) |
| 368 | except Exception as exc: # noqa: BLE001 |
| 369 | return False, f"preflight raised {type(exc).__name__}: {exc}" |
| 370 | |
| 371 | for label, dist in (("base", base_dist), ("ft", ft_dist)): |
| 372 | if not np.all(np.isfinite(dist.logprobs)): |
| 373 | n_bad = int((~np.isfinite(dist.logprobs)).sum()) |
| 374 | return ( |
| 375 | False, |
| 376 | f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite " |
| 377 | f"logprob(s) on prompt {prompt!r}", |
| 378 | ) |
| 379 | return True, "" |
| 380 | |
| 381 | def _enter(self, mode: str) -> None: |
| 382 | if self._active is not None: |
| 383 | raise RuntimeError( |
| 384 | f"DifferentialBackend view already active ({self._active!r}); " |
| 385 | f"exit the current view before entering {mode!r}." |
| 386 | ) |
| 387 | self._active = mode |
| 388 | |
| 389 | def _exit(self) -> None: |
| 390 | self._active = None |