1 """In-memory backend for unit tests.
2
3 Deterministic, torchless, and trivially fast. Tests pass canned responses
4 and canned score tables keyed by ``(mode, prompt, completion)``. The same
5 backend instance serves as both ``as_base`` and ``as_finetuned`` — it
6 switches an internal mode flag.
7
8 Use it to drive every probe's unit test without loading a real model.
9 For integration tests against a real PEFT adapter, see
10 :class:`~dlm_sway.backends.hf.HuggingFaceDifferentialBackend`.
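
A minimal usage sketch (prompt and score values are illustrative)::

    base = DummyResponses(
        generations={"2+2=": "5"},
        logprobs={("2+2=", "4"): -8.0},
    )
    ft = DummyResponses(
        generations={"2+2=": "4"},
        logprobs={("2+2=", "4"): -1.0},
    )
    backend = DummyDifferentialBackend(base=base, ft=ft)
    with backend.as_base() as view:
        view.logprob_of("2+2=", "4")       # -8.0 (canned)
        view.logprob_of("2+2=", "unseen")  # -10.0 (default)
    with backend.as_finetuned() as view:
        view.logprob_of("2+2=", "4")       # -1.0 (canned)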
11 """
12
13 from __future__ import annotations
14
15 import hashlib
16 import math
17 from collections.abc import Iterator, Sequence
18 from contextlib import contextmanager
19 from dataclasses import dataclass, field
20 from typing import Literal
21
22 import numpy as np
23
24 from dlm_sway.backends._instrumentation import BackendInstrumentation
25 from dlm_sway.core.scoring import RollingLogprob, TokenDist
26
27 Mode = Literal["base", "ft"]
28
29
30 @dataclass(slots=True)
31 class DummyResponses:
32 """Canned data for one mode (base or ft).
33
34 Callers populate one of these per mode and hand both to
35 :class:`DummyDifferentialBackend`.
36 """
37
38 generations: dict[str, str] = field(default_factory=dict)
39 """Prompt → canned completion. Lookup is exact-match."""
40 logprobs: dict[tuple[str, str], float] = field(default_factory=dict)
41 """``(prompt, completion) → sum logprob``. Default ``-10.0`` if missing."""
42 rolling: dict[str, RollingLogprob] = field(default_factory=dict)
43 """Text → canned :class:`RollingLogprob`."""
44 token_dists: dict[str, TokenDist] = field(default_factory=dict)
45 """Prompt → canned :class:`TokenDist`."""
46
47
48 class _DummyView:
49 """The per-mode view yielded by ``as_base`` / ``as_finetuned``.
50
51 Implements :class:`~dlm_sway.core.model.Model` *and*
52 :class:`~dlm_sway.core.scoring.ScoringBackend` — i.e. the
53 ``ScoringModel`` intersection.
54 """
55
56 def __init__(
57 self,
58 mode: Mode,
59 responses: DummyResponses,
60 inst: BackendInstrumentation | None = None,
61 ) -> None:
62 # ``id`` is what surfaces in cache keys and trace events; widened
63 # to ``str`` so scaled / null views can override it with
64 # e.g. ``"scaled_0.50"`` / ``"null_42"`` without a cast dance.
65 self.id: str = mode
66 self._mode: Mode = mode
67 self._r = responses
68 # Private instrumentation when the caller didn't supply one —
69 # keeps direct ``_DummyView("base", DummyResponses())``
70 # constructions (common in probe tests that grab a view
71 # without going through the differential backend) working
72 # transparently.
73 self._inst: BackendInstrumentation = inst if inst is not None else BackendInstrumentation()
74
75 # -- Model ---------------------------------------------------------
76 def generate(
77 self,
78 prompt: str,
79 *,
80 max_new_tokens: int,
81 temperature: float = 0.0,
82 top_p: float = 1.0,
83 seed: int = 0,
84 ) -> str:
85 del max_new_tokens, temperature, top_p, seed # canned; decoding is trivial.
86 try:
87 return self._r.generations[prompt]
88 except KeyError as exc:
89 raise KeyError(
90 f"dummy backend ({self._mode}): no canned generation for prompt {prompt!r}"
91 ) from exc
92
93 def close(self) -> None:
94 return None
95
96 # -- ScoringBackend ------------------------------------------------
97 def logprob_of(self, prompt: str, completion: str) -> float:
98 return self._inst.cached(
99 "logprob_of",
100 self.id,
101 f"{prompt}\x00{completion}",
102 0,
103 lambda: self._r.logprobs.get((prompt, completion), -10.0),
104 )
105
106 def rolling_logprob(self, text: str) -> RollingLogprob:
107 return self._inst.cached(
108 "rolling_logprob",
109 self.id,
110 text,
111 0,
112 lambda: self._compute_rolling_logprob(text),
113 )
114
115 def _compute_rolling_logprob(self, text: str) -> RollingLogprob:
116 if text in self._r.rolling:
117 return self._r.rolling[text]
118 # Synthesize a plausible rolling logprob so probes that just
119 # want a non-trivial value work without per-text configuration.
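        # For example (illustrative): in base mode a three-token text gets
        # per-token logprob -2.0, so logprobs == [-2.0, -2.0] (n - 1 entries)
        # and total_logprob == -4.0; in ft mode the per-token value is -1.5.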
        tokens = text.split()
        n = max(len(tokens), 1)
        per_tok = -2.0 if self._mode == "base" else -1.5
        return RollingLogprob(
            token_ids=np.arange(n, dtype=np.int64),
            logprobs=np.full(max(n - 1, 0), per_tok, dtype=np.float32),
            num_tokens=n,
            total_logprob=per_tok * max(n - 1, 0),
        )

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        return self._inst.cached(
            "next_token_dist",
            self.id,
            prompt,
            top_k,
            lambda: self._compute_next_token_dist(prompt, top_k=top_k),
        )

    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
        # Dummy backend has no real forward to batch against. Loop
        # over ``self.next_token_dist`` so subclasses (``_NullView``,
        # ``_InterpolatedView``) that override the per-prompt method
        # get their per-prompt semantics preserved on the batched
        # code path. Batching counters stay at zero on this backend —
        # real amortization lives in the HF backend's override.
        return [self.next_token_dist(p, top_k=top_k) for p in prompts]

    def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        del top_k
        if prompt in self._r.token_dists:
            return self._r.token_dists[prompt]
        # Synthesize a sharp base / broad ft distribution so divergence
        # probes see a non-zero signal without hand-rolled data.
        vocab = 1000
        k = 8
        if self._mode == "base":
            lp = np.array([-0.1] + [-5.0] * (k - 1), dtype=np.float32)
        else:
            # More uniform mass across the top-k tokens. A *literally*
            # uniform distribution looks like a broken lm_head to
            # ``_divergence``'s degenerate-distribution guard, so we
            # add a tiny monotonic perturbation — the spread stays
            # small enough that the ft dist still looks "broad" to
            # KL/JS but clears the 1e-9 uniformity threshold a real
            # model would also clear via fp32 accumulation noise.
            lp = np.full(k, -math.log(k), dtype=np.float32)
            lp += np.linspace(-1e-4, 1e-4, k, dtype=np.float32)
        # B6: tail_logprob=None means "no measurable tail" (k covers vocab
        # or residual underflowed); reserve floats for measurable mass.
        residual = 1.0 - float(np.exp(lp).sum())
        tail_lp = math.log(residual) if residual > 1e-12 else None
        return TokenDist(
            token_ids=np.arange(k, dtype=np.int64),
            logprobs=lp,
            vocab_size=vocab,
            tail_logprob=tail_lp,
        )


class _NullView(_DummyView):
    """A dummy view that perturbs the base distribution with seeded noise.

    Used by :meth:`DummyDifferentialBackend.as_null_adapter`. The
    perturbation is small (matches an ``init_scale=0.02`` adapter) so
    the null-vs-base divergence stays well below real-adapter territory
    in probe tests.

    ``rank_scale`` simulates changing the effective LoRA rank: the
    output variance of ``A·B`` scales linearly with rank, so the noise
    std carries a ``sqrt(rank_scale)`` factor. Default 1.0 preserves
    pre-S10 behavior exactly.
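
    For example, ``rank_scale=4.0`` doubles the perturbation std:
    ``init_scale * sqrt(rank_scale) == 0.02 * 2.0 == 0.04``.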
192 """
193
194 def __init__(
195 self,
196 base_responses: DummyResponses,
197 seed: int,
198 init_scale: float,
199 rank_scale: float = 1.0,
200 ) -> None:
201 super().__init__("base", base_responses)
202 self._seed = seed
203 self._init_scale = init_scale
204 self._rank_scale = rank_scale
205
206 def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
207 base_dist = super().next_token_dist(prompt, top_k=top_k)
208 # F08 — Python's built-in ``hash(str)`` is salted per-process via
209 # ``PYTHONHASHSEED``, so the same ``(self._seed, prompt)`` pair
210 # produced different RNG streams across interpreter invocations.
211 # That violated the README's determinism contract and meant the
212 # null-stats disk cache (~/.dlm-sway/null-stats/) could serve
213 # stale values on a restart. ``hashlib.md5`` is stable; taking
214 # the first 8 hex digits keeps the arithmetic in 32-bit range.
215 prompt_hash = int(hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8], 16)
216 rng = np.random.default_rng(self._seed + prompt_hash % 1_000_003)
217 effective_scale = self._init_scale * math.sqrt(self._rank_scale)
218 noise = rng.normal(0.0, effective_scale, size=base_dist.logprobs.shape).astype(np.float32)
219 new_lp = base_dist.logprobs + noise
220 # Re-normalize (within the top-k slice) so a valid distribution comes back.
221 max_lp = new_lp.max()
222 new_probs = np.exp(new_lp - max_lp)
223 new_probs /= new_probs.sum()
224 return TokenDist(
225 token_ids=base_dist.token_ids,
226 logprobs=np.log(new_probs).astype(np.float32),
227 vocab_size=base_dist.vocab_size,
228 tail_logprob=base_dist.tail_logprob,
229 )
230
231
232 class _InterpolatedView(_DummyView):
233 """A dummy view where logits/dists are a lam-blend of base and ft.
234
235 Used by :meth:`DummyDifferentialBackend.as_scaled_adapter`.
236 Generation falls back to the ft view at lam>=0.5, base otherwise —
237 rounded because the dummy backend's generations are canned strings
238 with no notion of "how much".
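
    For example (illustrative values), with ``lam=0.25`` and canned sum
    logprobs of ``-8.0`` (base) and ``-2.0`` (ft), ``logprob_of`` returns
    ``0.75 * -8.0 + 0.25 * -2.0 == -6.5``.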
239 """
240
241 def __init__(
242 self,
243 base_responses: DummyResponses,
244 ft_responses: DummyResponses,
245 lam: float,
246 ) -> None:
247 super().__init__(
248 "ft" if lam >= 0.5 else "base", ft_responses if lam >= 0.5 else base_responses
249 )
250 self._base_r = base_responses
251 self._ft_r = ft_responses
252 self._lam = lam
253
254 def logprob_of(self, prompt: str, completion: str) -> float:
255 base_v = self._base_r.logprobs.get((prompt, completion), -10.0)
256 ft_v = self._ft_r.logprobs.get((prompt, completion), -10.0)
257 return (1 - self._lam) * base_v + self._lam * ft_v
258
    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        base_dist = _DummyView("base", self._base_r).next_token_dist(prompt, top_k=top_k)
        ft_dist = _DummyView("ft", self._ft_r).next_token_dist(prompt, top_k=top_k)
        # When both dists come from the synthetic fallback (no canned
        # ``token_dists`` entry) they share the same support; blend their
        # logprobs via log-space linear interpolation, a log-linear
        # "tempered" mix that keeps normalization close enough.
        lam = self._lam
        blended_lp = (1 - lam) * base_dist.logprobs + lam * ft_dist.logprobs
        return type(base_dist)(
            token_ids=base_dist.token_ids,
            logprobs=blended_lp,
            vocab_size=base_dist.vocab_size,
            tail_logprob=base_dist.tail_logprob,
        )


class DummyDifferentialBackend:
    """Dummy implementation of
    :class:`~dlm_sway.core.scoring.DifferentialBackend`.

    Construction takes one :class:`DummyResponses` per mode. The two
    modes are mutually exclusive — the backend enforces that callers
    exit one view before entering the other, catching bugs in probes
    that hold a stale view across a toggle.

    Dummy declares ``safe_for_concurrent_views = False`` to mirror the
    shipped backends' posture; tests that want to exercise the concurrent
    scheduling path can subclass and set it ``True``.

    Also implements
    :class:`~dlm_sway.core.scoring.ScalableDifferentialBackend` with a
    linear-blend between base and ft responses, so probes that need
    ``as_scaled_adapter`` (N2 AdapterAblation) are unit-testable.
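
    A minimal sketch of the view toggle (values are illustrative)::

        backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
        with backend.as_base() as view:
            view.next_token_dist("hello")      # synthetic sharp base dist
            # Entering ``as_finetuned()`` here would raise RuntimeError;
            # the active view must be exited first.
        with backend.as_scaled_adapter(0.5) as view:
            view.logprob_of("hello", "world")  # 0.5 * -10.0 + 0.5 * -10.0 == -10.0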
292 """
293
294 safe_for_concurrent_views: bool = False
295
296 def __init__(self, *, base: DummyResponses, ft: DummyResponses) -> None:
297 self._base_r = base
298 self._ft_r = ft
299 # Sprint 07: one shared cache + trace + stats instance that
300 # every view yielded by this backend reads/writes. Tests can
301 # peek at ``backend._inst.stats`` to assert cache behavior.
302 self._inst = BackendInstrumentation()
303 self._base = _DummyView("base", base, inst=self._inst)
304 self._ft = _DummyView("ft", ft, inst=self._inst)
305 self._active: str | None = None
306
307 @contextmanager
308 def as_base(self) -> Iterator[_DummyView]:
309 self._enter("base")
310 try:
311 yield self._base
312 finally:
313 self._exit()
314
315 @contextmanager
316 def as_finetuned(self) -> Iterator[_DummyView]:
317 self._enter("ft")
318 try:
319 yield self._ft
320 finally:
321 self._exit()
322
323 @contextmanager
324 def as_scaled_adapter(self, lam: float) -> Iterator[_DummyView]:
325 self._enter(f"scaled({lam})")
326 try:
327 view = _InterpolatedView(self._base_r, self._ft_r, lam)
328 view._inst = self._inst
329 view.id = f"scaled_{lam:.2f}"
330 yield view
331 finally:
332 self._exit()
333
334 @contextmanager
335 def as_null_adapter(
336 self,
337 seed: int,
338 *,
339 init_scale: float = 0.02,
340 rank_scale: float = 1.0,
341 ) -> Iterator[_DummyView]:
342 if rank_scale <= 0.0 or not math.isfinite(rank_scale):
343 raise ValueError(f"rank_scale must be positive and finite; got {rank_scale!r}")
344 label = f"null({seed})" if rank_scale == 1.0 else f"null({seed},rank={rank_scale:.2f})"
345 view_id = f"null_{seed}" if rank_scale == 1.0 else f"null_{seed}_rank{rank_scale:.2f}"
346 self._enter(label)
347 try:
348 view = _NullView(self._base_r, seed=seed, init_scale=init_scale, rank_scale=rank_scale)
349 view._inst = self._inst
350 view.id = view_id
351 yield view
352 finally:
353 self._exit()
354
355 def preflight_finite_check(self) -> tuple[bool, str]:
356 """Smoke a single forward pass per view; reject non-finite logits.

        For the dummy backend the canned data is finite by construction
        unless tests deliberately seed NaN-laden ``TokenDist`` entries —
        which is exactly what S01 tests do to verify the runner gate.
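
        A caller might gate a run like this (sketch)::

            ok, reason = backend.preflight_finite_check()
            if not ok:
                raise RuntimeError(f"preflight failed: {reason}")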
361 """
362 prompt = "preflight"
363 try:
364 with self.as_base() as base_view:
365 base_dist = base_view.next_token_dist(prompt, top_k=8)
366 with self.as_finetuned() as ft_view:
367 ft_dist = ft_view.next_token_dist(prompt, top_k=8)
368 except Exception as exc: # noqa: BLE001
369 return False, f"preflight raised {type(exc).__name__}: {exc}"
370
371 for label, dist in (("base", base_dist), ("ft", ft_dist)):
372 if not np.all(np.isfinite(dist.logprobs)):
373 n_bad = int((~np.isfinite(dist.logprobs)).sum())
374 return (
375 False,
376 f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite "
377 f"logprob(s) on prompt {prompt!r}",
378 )
379 return True, ""
380
381 def _enter(self, mode: str) -> None:
382 if self._active is not None:
383 raise RuntimeError(
384 f"DifferentialBackend view already active ({self._active!r}); "
385 f"exit the current view before entering {mode!r}."
386 )
387 self._active = mode
388
389 def _exit(self) -> None:
390 self._active = None