1 """HuggingFace + PEFT differential backend.
2
3 Loads the base once, attaches the LoRA adapter once, and toggles between
4 "base" and "fine-tuned" views on the same module via PEFT's
5 :meth:`~peft.PeftModel.disable_adapter` / :meth:`~peft.PeftModel.set_adapter`.
6
7 This is the single most important backend in sway. Every numeric probe
8 benefits from the shared-weights toggle — memory is halved compared to
9 loading two copies, and KV-cache layouts stay aligned so pairwise KL math
10 is straight-forward.
11
12 Heavy imports (``torch``, ``transformers``, ``peft``) are deferred until
13 ``HuggingFaceDifferentialBackend`` is actually instantiated so
14 ``import dlm_sway`` stays light for users of the dummy backend or spec
15 validation.
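
A typical round trip looks like this minimal sketch (illustrative only:
``spec`` and the adapter path are placeholders, not values shipped with
this repository)::

    from pathlib import Path

    backend = HuggingFaceDifferentialBackend(
        base_spec=spec,  # a validated dlm_sway.core.model.ModelSpec
        adapter_path=Path("adapters/my-lora"),
    )
    with backend.as_base() as base:
        p_base = base.next_token_dist("2 + 2 =", top_k=64)
    with backend.as_finetuned() as ft:
        p_ft = ft.next_token_dist("2 + 2 =", top_k=64)
    backend.close()

Views are entered strictly one at a time; nesting ``as_finetuned`` inside
``as_base`` raises ``RuntimeError`` (see ``_enter``).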
16 """
17
18 from __future__ import annotations
19
20 from collections.abc import Iterator, Sequence
21 from contextlib import contextmanager
22 from dataclasses import dataclass
23 from pathlib import Path
24 from typing import TYPE_CHECKING, Any, Literal
25
26 import numpy as np
27
28 from dlm_sway.backends._instrumentation import BackendInstrumentation
29 from dlm_sway.core.errors import BackendNotAvailableError, ProbeError
30 from dlm_sway.core.model import ModelSpec
31 from dlm_sway.core.scoring import RollingLogprob, TokenDist
32
33 if TYPE_CHECKING:
34 from transformers import PreTrainedModel, PreTrainedTokenizerBase
35
36
37 Device = Literal["cuda", "mps", "cpu"]
38
39
40 def _detect_device() -> Device:
41 try:
42 import torch
43 except ImportError as exc:
44 raise BackendNotAvailableError("hf", extra="hf") from exc
45 if torch.cuda.is_available():
46 return "cuda"
47 if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
48 return "mps"
49 return "cpu"
50
51
52 def _resolve_dtype(requested: str, device: Device) -> Any:
53 """Map the user's ``dtype`` preference to a torch dtype."""
54 import torch # noqa: PLC0415 — lazy
55
56 if requested == "fp16":
57 return torch.float16
58 if requested == "bf16":
59 return torch.bfloat16
60 if requested == "fp32":
61 return torch.float32
62 # auto: bf16 on CUDA (Ampere+) / MPS; fp32 on CPU for numerical stability.
63 if device == "cuda" and torch.cuda.is_bf16_supported():
64 return torch.bfloat16
65 if device == "mps":
66 return torch.float16
67 return torch.float32
68
69
70 def _require_hf() -> tuple[Any, Any, Any]:
71 """Import torch + transformers + peft, raising a friendly error if missing."""
72 try:
73 import torch
74 import transformers
75 except ImportError as exc:
76 raise BackendNotAvailableError("hf", extra="hf") from exc
77 try:
78 import peft
79 except ImportError as exc:
80 raise BackendNotAvailableError(
81 "hf", extra="hf", hint="peft is required for the adapter toggle."
82 ) from exc
83 return torch, transformers, peft
84
85
86 # --- internal helpers -----------------------------------------------------
87
88
89 def _topk_to_token_dist(log_probs: Any, *, top_k: int) -> TokenDist:
90 """Build a :class:`TokenDist` from a 1-D ``log_probs`` torch Tensor.
91
92 Shared by the single-prompt and batched paths so the tail-mass
93 accounting (B6) stays identical across both.
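
    Worked example with made-up numbers: if the kept top-``k`` entries
    cover 99.9% of the probability mass, ``exp(logprobs).sum()`` is
    roughly 0.999 and ``tail_logprob`` comes out near ``log(0.001) ≈ -6.9``;
    when ``top_k >= vocab`` there is no tail and ``tail_logprob`` is
    ``None`` (the B6 sentinel).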
94 """
95 import torch
96
97 vocab = int(log_probs.shape[0])
98 k = min(top_k, vocab)
99 top = torch.topk(log_probs, k=k)
100 if k == vocab:
101 tail_logprob: float | None = None
102 else:
103 tail_mass = float(1.0 - torch.exp(top.values).sum().item())
104 tail_logprob = float(np.log(tail_mass)) if tail_mass > 1e-12 else 0.0
105 return TokenDist(
106 token_ids=top.indices.cpu().numpy().astype(np.int64),
107 logprobs=top.values.cpu().numpy().astype(np.float32),
108 vocab_size=vocab,
109 tail_logprob=tail_logprob,
110 )
111
112
113 # --- the view object ------------------------------------------------------
114
115
116 @dataclass(slots=True)
117 class _HFView:
118 """One side (base or ft) of a :class:`HuggingFaceDifferentialBackend`.
119
120 Both sides reuse the same underlying module; the difference is
121 whether the adapter is active. Scoring calls route through
122 ``_inst`` (the backend's shared cache + tracer + stats) so repeated
123 forward passes on the same ``(view_id, prompt, top_k)`` are served
124 from the LRU instead of re-executed — the Sprint 07 performance win.
125 """
126
127 id: str
128 _model: Any
129 _tokenizer: Any
130 _device: str
131 _pad_token_id: int
132 _inst: BackendInstrumentation
133
134 # -- Model ---------------------------------------------------------
135 def generate(
136 self,
137 prompt: str,
138 *,
139 max_new_tokens: int,
140 temperature: float = 0.0,
141 top_p: float = 1.0,
142 seed: int = 0,
143 ) -> str:
144 # Generation is intentionally *not* cached — probes call it
145 # through (prompt, max_new_tokens, temperature, seed) tuples
146 # that rarely collide across a suite, and cache hits on sampled
147 # output would hide seed bugs behind stale strings.
148 import torch
149
150 torch.manual_seed(seed)
151 inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device)
152 do_sample = temperature > 0.0
153 gen_kwargs: dict[str, Any] = {
154 "max_new_tokens": max_new_tokens,
155 "do_sample": do_sample,
156 "pad_token_id": self._pad_token_id,
157 }
158 if do_sample:
159 gen_kwargs["temperature"] = temperature
160 gen_kwargs["top_p"] = top_p
161 with torch.inference_mode():
162 out_ids = self._model.generate(**inputs, **gen_kwargs)
163 new_tokens = out_ids[0, inputs["input_ids"].shape[1] :]
164 return str(self._tokenizer.decode(new_tokens, skip_special_tokens=True))
165
166 def close(self) -> None:
167 return None
168
169 # -- ScoringBackend ------------------------------------------------
170 def logprob_of(self, prompt: str, completion: str) -> float:
171 # Fold (prompt, completion) into one cache-key string so a repeat
172 # (q, a) pair hits the cache without the completion args needing
173 # their own slot in ``ForwardCache``.
174 key_prompt = f"{prompt}\x00{completion}"
175 return self._inst.cached(
176 "logprob_of",
177 self.id,
178 key_prompt,
179 0,
180 lambda: self._compute_logprob_of(prompt, completion),
181 )
182
183 def _compute_logprob_of(self, prompt: str, completion: str) -> float:
184 import torch
185 import torch.nn.functional as F
186
187 prompt_ids = self._tokenizer(prompt, return_tensors="pt").input_ids.to(self._device)
188 full_ids = self._tokenizer(prompt + completion, return_tensors="pt").input_ids.to(
189 self._device
190 )
191 if full_ids.shape[1] <= prompt_ids.shape[1]:
192 raise ProbeError(
193 "logprob_of",
194 f"completion tokenized to zero tokens (prompt={prompt!r}, completion={completion!r})",
195 )
196 target_ids = full_ids[:, prompt_ids.shape[1] :]
197 with torch.inference_mode():
198 logits = self._model(full_ids).logits # (1, T, V)
199 # Align: logit at position t predicts token at t+1. We want
200 # predictions for the completion slice.
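            # Worked example with hypothetical sizes: a 5-token prompt and a
            # 3-token completion give T = 8; the targets are full_ids[:, 5:8],
            # and their predictors are the logits at positions 4..6, i.e.
            # exactly the prompt_ids.shape[1] - 1 : -1 slice taken below
            # (shape (1, 3, V)).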
            shift_logits = logits[:, prompt_ids.shape[1] - 1 : -1, :]  # (1, C, V)
            log_probs = F.log_softmax(shift_logits.float(), dim=-1)
            gathered = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        return float(gathered.sum().item())

    def rolling_logprob(self, text: str) -> RollingLogprob:
        return self._inst.cached(
            "rolling_logprob",
            self.id,
            text,
            0,
            lambda: self._compute_rolling_logprob(text),
        )

    def _compute_rolling_logprob(self, text: str) -> RollingLogprob:
        import torch
        import torch.nn.functional as F

        ids = self._tokenizer(text, return_tensors="pt").input_ids.to(self._device)
        if ids.shape[1] < 2:
            return RollingLogprob(
                token_ids=ids[0].cpu().numpy().astype(np.int64),
                logprobs=np.array([], dtype=np.float32),
                num_tokens=int(ids.shape[1]),
                total_logprob=0.0,
            )
        with torch.inference_mode():
            logits = self._model(ids).logits  # (1, T, V)
            log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)  # predicts tokens 1..T
            gathered = log_probs.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1).squeeze(0)
        return RollingLogprob(
            token_ids=ids[0].cpu().numpy().astype(np.int64),
            logprobs=gathered.cpu().numpy().astype(np.float32),
            num_tokens=int(ids.shape[1]),
            total_logprob=float(gathered.sum().item()),
        )

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        return self._inst.cached(
            "next_token_dist",
            self.id,
            prompt,
            top_k,
            lambda: self._compute_next_token_dist(prompt, top_k=top_k),
        )

    def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        import torch
        import torch.nn.functional as F

        ids = self._tokenizer(prompt, return_tensors="pt").input_ids.to(self._device)
        with torch.inference_mode():
            logits = self._model(ids).logits[:, -1, :]  # (1, V)
            log_probs = F.log_softmax(logits.float(), dim=-1).squeeze(0)
        return _topk_to_token_dist(log_probs, top_k=top_k)

    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
        """Batched forward via tokenizer left-padding.

        Decoder-only LMs need left-padding because the last-token
        position (what :meth:`next_token_dist` reads) must line up with
        the actual end of each sequence. Right-padding would put pad
        tokens at the end, so the last position would hold garbage
        logits. We set ``padding_side="left"`` on the tokenizer call
        explicitly — independent of any instance-wide setting the
        tokenizer may carry for other code paths.
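
        Schematic layouts for a hypothetical width-4 batch row (``<p>``
        is the pad token)::

            left-padded : [<p>, <p>, t1, t2]   # index -1 -> t2, the real last token
            right-padded: [t1, t2, <p>, <p>]   # index -1 -> a pad token

        Only the left-padded layout lets the single ``logits[:, -1, :]``
        read below return every row's true next-token distribution.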

        Cache-per-prompt: the instrumentation's :meth:`cached_batch`
        checks the LRU for every prompt and only forwards the misses
        through the batch. That means a mixed workload (e.g. a second
        ``as_base`` view over prompts the run has already seen) pays
        near-zero cost.
        """
        if not prompts:
            return []

        def compute_misses(miss_indices: list[int]) -> list[TokenDist]:
            import torch
            import torch.nn.functional as F

            miss_prompts = [prompts[i] for i in miss_indices]
            tokens = self._tokenizer(
                miss_prompts,
                return_tensors="pt",
                padding=True,
                padding_side="left",
            ).to(self._device)
            input_ids = tokens["input_ids"]
            with torch.inference_mode():
                logits = self._model(
                    input_ids=input_ids,
                    attention_mask=tokens.get("attention_mask"),
                ).logits[:, -1, :]  # (B, V) — left-pad makes "last" always the real last token
                log_probs = F.log_softmax(logits.float(), dim=-1)  # (B, V)
            return [
                _topk_to_token_dist(log_probs[row], top_k=top_k) for row in range(len(miss_prompts))
            ]

        return self._inst.cached_batch(
            "next_token_dist", self.id, list(prompts), top_k, compute_misses
        )


# --- the backend -----------------------------------------------------------


class HuggingFaceDifferentialBackend:
    """A :class:`~dlm_sway.core.scoring.DifferentialBackend` for HF+PEFT.

    The adapter toggle relies on
    :meth:`peft.PeftModel.disable_adapter` producing a context where the
    forward pass skips the LoRA deltas, and
    :meth:`peft.PeftModel.set_adapter` (or just exiting the disable
    context) re-enabling them. A dedicated sanity test asserts that
    these actually change logits on a fixture.
    """

    #: B19 — the shared-weights toggle is not thread-safe. The runner
    #: treats ``spec.defaults.concurrent_probes > 1`` as a no-op when
    #: this attribute is ``False``. See ``.docs/design/backend-concurrency.md``.
    safe_for_concurrent_views: bool = False

    def __init__(self, *, base_spec: ModelSpec, adapter_path: Path) -> None:
        torch, transformers, peft = _require_hf()
        self._torch = torch
        self._spec = base_spec
        # Path normalization lives in ``ModelSpec.adapter`` (B22). When
        # the backend is constructed via ``backends.build``, the value
        # is already absolute. Direct constructions (some tests) may
        # pass a relative path, so ``Path(...).resolve()`` stays as a
        # cheap idempotent fallback.
        self._adapter_path = Path(adapter_path).resolve()

        device_str: Device = (
            _detect_device() if base_spec.device == "auto" else base_spec.device  # type: ignore[assignment]
        )
        self._device: str = device_str
        dtype = _resolve_dtype(base_spec.dtype, device_str)

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            str(self._adapter_path)
            if (self._adapter_path / "tokenizer_config.json").exists()
            else base_spec.base,
            trust_remote_code=base_spec.trust_remote_code,
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        base_model = transformers.AutoModelForCausalLM.from_pretrained(
            base_spec.base,
            torch_dtype=dtype,
            trust_remote_code=base_spec.trust_remote_code,
        )
        base_model.to(self._device)
        peft_model = peft.PeftModel.from_pretrained(
            base_model,
            str(self._adapter_path),
            is_trainable=False,
        )
        peft_model.eval()

        self._tokenizer: PreTrainedTokenizerBase = tokenizer
        self._peft_model: PreTrainedModel = peft_model
        self._pad_token_id: int = int(tokenizer.pad_token_id)
        self._active: str | None = None
        # Shared cache + trace + stats (Sprint 07). The instrumentation
        # instance outlives individual view entries — entering/exiting
        # ``as_base()`` doesn't reset the cache, because each view id
        # (``"base"`` / ``"ft"`` / ``"scaled_0.50"`` / ``"null_42"``)
        # is part of the cache key. A future toggle bug that fails to
        # flip the view would surface as a cache hit returning the
        # wrong side's value — integration tests cover this.
        self._inst = BackendInstrumentation()

    # -- DifferentialBackend -------------------------------------------

    @contextmanager
    def as_base(self) -> Iterator[_HFView]:
        self._enter("base")
        try:
            # peft.PeftModel.disable_adapter is a context manager; newer
            # transformers builds ship stubs that mis-type it as a Tensor,
            # so we suppress the operator check on the `with` line. The
            # ``unused-ignore`` tag makes the suppression itself a no-op
            # when [hf] isn't installed (mypy can't see the conflicting
            # stub and would otherwise flag the ignore as redundant).
            with self._peft_model.disable_adapter():  # type: ignore[operator,unused-ignore]
                yield self._make_view("base")
        finally:
            self._exit()

    @contextmanager
    def as_finetuned(self) -> Iterator[_HFView]:
        self._enter("ft")
        try:
            yield self._make_view("ft")
        finally:
            self._exit()

    @contextmanager
    def as_scaled_adapter(self, lam: float) -> Iterator[_HFView]:
        """Temporarily multiply every LoRA layer's scaling factor by ``lam``.

        Works by walking the PEFT module tree and mutating each
        ``LoraLayer.scaling[adapter_name]`` in place. The original
        scalings are restored when the context exits — or when an
        exception propagates, to keep the model in a sane state.
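
        Schematically (PEFT's LoRA forward, ignoring dropout), each
        wrapped layer computes ``h = W x + scaling * B(A(x))`` with
        ``scaling = alpha / r``, so multiplying ``scaling`` by ``lam``
        interpolates the view between the base model (``lam = 0``) and
        the fine-tuned model (``lam = 1``), and extrapolates beyond it
        for ``lam > 1``.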
408 """
409 self._enter(f"scaled({lam})")
410 # ``module`` is dynamic (peft LoraLayer subclass) — Any avoids
411 # mypy treating its ``.scaling`` as a Tensor when peft is loaded.
412 saved: list[tuple[Any, str, float]] = []
413 try:
414 import peft # noqa: PLC0415 — already a hard dep of this backend
415
416 lora_cls = getattr(peft.tuners.lora, "LoraLayer", None)
417 if lora_cls is None:
418 raise RuntimeError("peft.tuners.lora.LoraLayer not found; check peft>=0.13 pin")
419 for module in self._peft_model.modules():
420 if not isinstance(module, lora_cls):
421 continue
422 scaling = getattr(module, "scaling", None)
423 if not isinstance(scaling, dict):
424 continue
425 for key, original in scaling.items():
426 saved.append((module, key, float(original)))
427 scaling[key] = float(original) * lam
428 yield self._make_view(f"scaled_{lam:.2f}")
429 finally:
430 for module, key, original in saved:
431 module.scaling[key] = original
432 self._exit()
433
434 @contextmanager
435 def as_null_adapter(
436 self, seed: int, *, init_scale: float = 0.02, rank_scale: float = 1.0
437 ) -> Iterator[_HFView]:
438 """Temporarily replace every LoRA ``A``/``B`` tensor with random noise.
439
440 Same rank, alpha, and target modules as the real adapter — only
441 the weights differ. This is the denominator in every z-score
442 path: "how much signal does structural noise produce?"
443
444 ``rank_scale`` simulates a null adapter at a different effective
445 rank without reshaping any tensor. The LoRA output ``A·B`` is a
446 sum of ``r`` rank-1 outer products; its output variance scales
447 linearly with ``r``. So a ``rank_scale`` of 0.5 is equivalent to
448 halving the rank in output-variance terms, which we get by
449 scaling both factors' noise std by ``sqrt(rank_scale)``. The
450 PEFT tensors keep their original shapes; no surgery on
451 alpha/scaling; no model reload.
452
453 Implementation walks the PEFT module tree for ``lora_A``/``lora_B``
454 parameters, saves a clone of each current value, overwrites in
455 place with a zero-mean Gaussian at ``init_scale *
456 sqrt(rank_scale)``, and restores on exit (including on exception).
457 """
458 import math
459
460 import torch
461
462 if rank_scale <= 0.0 or not math.isfinite(rank_scale):
463 raise ValueError(f"rank_scale must be positive and finite; got {rank_scale!r}")
464
465 effective_scale = init_scale * math.sqrt(rank_scale)
466 view_id = f"null_{seed}" if rank_scale == 1.0 else f"null_{seed}_rank{rank_scale:.2f}"
467
468 self._enter(view_id)
469 gen = torch.Generator(device="cpu").manual_seed(int(seed))
470 saved: list[tuple[torch.nn.Parameter, torch.Tensor]] = []
471 try:
472 for pname, param in self._peft_model.named_parameters():
473 if not any(key in pname for key in ("lora_A", "lora_B")):
474 continue
475 saved.append((param, param.detach().clone()))
476 with torch.no_grad():
477 noise = torch.randn(
478 *param.shape,
479 generator=gen,
480 dtype=torch.float32,
481 ).to(dtype=param.dtype, device=param.device)
482 param.copy_(noise * effective_scale)
483 yield self._make_view(view_id)
484 finally:
485 with torch.no_grad():
486 for param, original in saved:
487 param.copy_(original)
488 self._exit()
489
490 def close(self) -> None:
491 """Release GPU memory + flush the trace writer. Safe to call more than once."""
492 inst = getattr(self, "_inst", None)
493 if inst is not None:
494 inst.close()
495 if getattr(self, "_peft_model", None) is not None:
496 del self._peft_model
497 if self._torch.cuda.is_available():
498 self._torch.cuda.empty_cache()
499
500 # -- PreflightCheckable -------------------------------------------
501
502 _PREFLIGHT_PROMPT = "hello"
503 _PREFLIGHT_TOP_K = 8
504
505 def cache_identity(self) -> str:
506 """Stable string identifying this backend for on-disk caching.
507
508 The base model id + the adapter's resolved absolute path is
509 enough to key a null-calibration cache: swapping either
510 invalidates the previously-computed stats.
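
        Example with hypothetical values: ``base="org/tiny-llm"`` and an
        adapter resolved to ``/data/adapters/my-lora`` yield
        ``"hf:org/tiny-llm:/data/adapters/my-lora"``.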
511 """
512 return f"hf:{self._spec.base}:{self._adapter_path}"
513
514 def preflight_finite_check(self) -> tuple[bool, str]:
515 """One forward pass per view; assert both produce finite logits.
516
517 Catches the +11639σ class of bug at suite-load time: a NaN-weighted
518 adapter would produce non-finite logprobs here, the runner sees
519 ``ok=False``, and the suite aborts with a single synthetic ERROR
520 probe — never reaching a probe that would pass on garbage.
521 """
522 import math
523
524 try:
525 with self.as_base() as base_view:
526 base_dist = base_view.next_token_dist(
527 self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K
528 )
529 with self.as_finetuned() as ft_view:
530 ft_dist = ft_view.next_token_dist(
531 self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K
532 )
533 except Exception as exc: # noqa: BLE001 — backend may raise anything
534 return False, f"preflight forward pass raised {type(exc).__name__}: {exc}"
535
536 for label, dist in (("base", base_dist), ("ft", ft_dist)):
537 n_bad = int((~np.isfinite(dist.logprobs)).sum())
538 if n_bad > 0:
539 return (
540 False,
541 f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite "
542 f"logprob(s) on prompt {self._PREFLIGHT_PROMPT!r} — adapter is "
543 f"likely broken (NaN/inf weights). sway refuses to score a model "
544 f"producing non-finite outputs.",
545 )
546 tail = dist.tail_logprob
547 # B6: ``None`` is a sentinel for "k covered the whole vocab,"
548 # not a numeric value to range-check.
549 if tail is not None and not math.isfinite(tail):
550 return (
551 False,
552 f"{label} view produced non-finite tail_logprob = {tail}",
553 )
554
555 return True, ""
556
557 # -- internals -----------------------------------------------------
558
559 def _make_view(self, mode: str) -> _HFView:
560 return _HFView(
561 id=mode,
562 _model=self._peft_model,
563 _tokenizer=self._tokenizer,
564 _device=self._device,
565 _pad_token_id=self._pad_token_id,
566 _inst=self._inst,
567 )
568
569 def _enter(self, mode: str) -> None:
570 if self._active is not None:
571 raise RuntimeError(
572 f"HuggingFaceDifferentialBackend view {self._active!r} already active; "
573 f"exit it before entering {mode!r}."
574 )
575 self._active = mode
576
577 def _exit(self) -> None:
578 self._active = None
579
580
581 __all__ = ["HuggingFaceDifferentialBackend"]