| 1 | """HuggingFace + PEFT differential backend. |
| 2 | |
| 3 | Loads the base once, attaches the LoRA adapter once, and toggles between |
| 4 | "base" and "fine-tuned" views on the same module via PEFT's |
| 5 | :meth:`~peft.PeftModel.disable_adapter` / :meth:`~peft.PeftModel.set_adapter`. |
| 6 | |
| 7 | This is the single most important backend in sway. Every numeric probe |
| 8 | benefits from the shared-weights toggle — memory is halved compared to |
| 9 | loading two copies, and KV-cache layouts stay aligned so pairwise KL math |
is straightforward.
| 11 | |
| 12 | Heavy imports (``torch``, ``transformers``, ``peft``) are deferred until |
| 13 | ``HuggingFaceDifferentialBackend`` is actually instantiated so |
| 14 | ``import dlm_sway`` stays light for users of the dummy backend or spec |
| 15 | validation. |
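
A minimal end-to-end sketch (the ``spec`` below stands in for a
:class:`~dlm_sway.core.model.ModelSpec`; only its ``base``, ``device``,
``dtype`` and ``trust_remote_code`` fields are read here)::

    from pathlib import Path

    backend = HuggingFaceDifferentialBackend(
        base_spec=spec,
        adapter_path=Path("path/to/adapter"),
    )
    with backend.as_base() as base:
        p = base.next_token_dist("2 + 2 =", top_k=8)
    with backend.as_finetuned() as ft:
        q = ft.next_token_dist("2 + 2 =", top_k=8)
    backend.close()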
| 16 | """ |
| 17 | |
| 18 | from __future__ import annotations |
| 19 | |
| 20 | from collections.abc import Iterator, Sequence |
| 21 | from contextlib import contextmanager |
| 22 | from dataclasses import dataclass |
| 23 | from pathlib import Path |
| 24 | from typing import TYPE_CHECKING, Any, Literal |
| 25 | |
| 26 | import numpy as np |
| 27 | |
| 28 | from dlm_sway.backends._instrumentation import BackendInstrumentation |
| 29 | from dlm_sway.core.errors import BackendNotAvailableError, ProbeError |
| 30 | from dlm_sway.core.model import ModelSpec |
| 31 | from dlm_sway.core.scoring import RollingLogprob, TokenDist |
| 32 | |
| 33 | if TYPE_CHECKING: |
| 34 | from transformers import PreTrainedModel, PreTrainedTokenizerBase |
| 35 | |
| 36 | |
| 37 | Device = Literal["cuda", "mps", "cpu"] |
| 38 | |
| 39 | |
| 40 | def _detect_device() -> Device: |
| 41 | try: |
| 42 | import torch |
| 43 | except ImportError as exc: |
| 44 | raise BackendNotAvailableError("hf", extra="hf") from exc |
| 45 | if torch.cuda.is_available(): |
| 46 | return "cuda" |
| 47 | if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
| 48 | return "mps" |
| 49 | return "cpu" |
| 50 | |
| 51 | |
| 52 | def _resolve_dtype(requested: str, device: Device) -> Any: |
| 53 | """Map the user's ``dtype`` preference to a torch dtype.""" |
| 54 | import torch # noqa: PLC0415 — lazy |
| 55 | |
| 56 | if requested == "fp16": |
| 57 | return torch.float16 |
| 58 | if requested == "bf16": |
| 59 | return torch.bfloat16 |
| 60 | if requested == "fp32": |
| 61 | return torch.float32 |
    # auto: bf16 on CUDA (Ampere+), fp16 on MPS; fp32 on CPU for numerical stability.
| 63 | if device == "cuda" and torch.cuda.is_bf16_supported(): |
| 64 | return torch.bfloat16 |
| 65 | if device == "mps": |
| 66 | return torch.float16 |
| 67 | return torch.float32 |
| 68 | |
| 69 | |
| 70 | def _require_hf() -> tuple[Any, Any, Any]: |
| 71 | """Import torch + transformers + peft, raising a friendly error if missing.""" |
| 72 | try: |
| 73 | import torch |
| 74 | import transformers |
| 75 | except ImportError as exc: |
| 76 | raise BackendNotAvailableError("hf", extra="hf") from exc |
| 77 | try: |
| 78 | import peft |
| 79 | except ImportError as exc: |
| 80 | raise BackendNotAvailableError( |
| 81 | "hf", extra="hf", hint="peft is required for the adapter toggle." |
| 82 | ) from exc |
| 83 | return torch, transformers, peft |
| 84 | |
| 85 | |
| 86 | # --- internal helpers ----------------------------------------------------- |
| 87 | |
| 88 | |
| 89 | def _topk_to_token_dist(log_probs: Any, *, top_k: int) -> TokenDist: |
| 90 | """Build a :class:`TokenDist` from a 1-D ``log_probs`` torch Tensor. |
| 91 | |
| 92 | Shared by the single-prompt and batched paths so the tail-mass |
| 93 | accounting (B6) stays identical across both. |
| 94 | """ |
| 95 | import torch |
| 96 | |
| 97 | vocab = int(log_probs.shape[0]) |
| 98 | k = min(top_k, vocab) |
| 99 | top = torch.topk(log_probs, k=k) |
| 100 | if k == vocab: |
| 101 | tail_logprob: float | None = None |
| 102 | else: |
| 103 | tail_mass = float(1.0 - torch.exp(top.values).sum().item()) |
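        # Guard the log against a vanishing or negative residual from rounding; fall back to 0.0.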
| 104 | tail_logprob = float(np.log(tail_mass)) if tail_mass > 1e-12 else 0.0 |
| 105 | return TokenDist( |
| 106 | token_ids=top.indices.cpu().numpy().astype(np.int64), |
| 107 | logprobs=top.values.cpu().numpy().astype(np.float32), |
| 108 | vocab_size=vocab, |
| 109 | tail_logprob=tail_logprob, |
| 110 | ) |
| 111 | |
| 112 | |
| 113 | # --- the view object ------------------------------------------------------ |
| 114 | |
| 115 | |
| 116 | @dataclass(slots=True) |
| 117 | class _HFView: |
| 118 | """One side (base or ft) of a :class:`HuggingFaceDifferentialBackend`. |
| 119 | |
| 120 | Both sides reuse the same underlying module; the difference is |
| 121 | whether the adapter is active. Scoring calls route through |
| 122 | ``_inst`` (the backend's shared cache + tracer + stats) so repeated |
| 123 | forward passes on the same ``(view_id, prompt, top_k)`` are served |
| 124 | from the LRU instead of re-executed — the Sprint 07 performance win. |
| 125 | """ |
| 126 | |
| 127 | id: str |
| 128 | _model: Any |
| 129 | _tokenizer: Any |
| 130 | _device: str |
| 131 | _pad_token_id: int |
| 132 | _inst: BackendInstrumentation |
| 133 | |
| 134 | # -- Model --------------------------------------------------------- |
| 135 | def generate( |
| 136 | self, |
| 137 | prompt: str, |
| 138 | *, |
| 139 | max_new_tokens: int, |
| 140 | temperature: float = 0.0, |
| 141 | top_p: float = 1.0, |
| 142 | seed: int = 0, |
| 143 | ) -> str: |
| 144 | # Generation is intentionally *not* cached — probes call it |
| 145 | # through (prompt, max_new_tokens, temperature, seed) tuples |
| 146 | # that rarely collide across a suite, and cache hits on sampled |
| 147 | # output would hide seed bugs behind stale strings. |
| 148 | import torch |
| 149 | |
| 150 | torch.manual_seed(seed) |
| 151 | inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device) |
| 152 | do_sample = temperature > 0.0 |
| 153 | gen_kwargs: dict[str, Any] = { |
| 154 | "max_new_tokens": max_new_tokens, |
| 155 | "do_sample": do_sample, |
| 156 | "pad_token_id": self._pad_token_id, |
| 157 | } |
| 158 | if do_sample: |
| 159 | gen_kwargs["temperature"] = temperature |
| 160 | gen_kwargs["top_p"] = top_p |
| 161 | with torch.inference_mode(): |
| 162 | out_ids = self._model.generate(**inputs, **gen_kwargs) |
| 163 | new_tokens = out_ids[0, inputs["input_ids"].shape[1] :] |
| 164 | return str(self._tokenizer.decode(new_tokens, skip_special_tokens=True)) |
| 165 | |
| 166 | def close(self) -> None: |
| 167 | return None |
| 168 | |
| 169 | # -- ScoringBackend ------------------------------------------------ |
| 170 | def logprob_of(self, prompt: str, completion: str) -> float: |
| 171 | # Fold (prompt, completion) into one cache-key string so a repeat |
| 172 | # (q, a) pair hits the cache without the completion args needing |
| 173 | # their own slot in ``ForwardCache``. |
| 174 | key_prompt = f"{prompt}\x00{completion}" |
| 175 | return self._inst.cached( |
| 176 | "logprob_of", |
| 177 | self.id, |
| 178 | key_prompt, |
| 179 | 0, |
| 180 | lambda: self._compute_logprob_of(prompt, completion), |
| 181 | ) |
| 182 | |
| 183 | def _compute_logprob_of(self, prompt: str, completion: str) -> float: |
| 184 | import torch |
| 185 | import torch.nn.functional as F |
| 186 | |
| 187 | prompt_ids = self._tokenizer(prompt, return_tensors="pt").input_ids.to(self._device) |
| 188 | full_ids = self._tokenizer(prompt + completion, return_tensors="pt").input_ids.to( |
| 189 | self._device |
| 190 | ) |
| 191 | if full_ids.shape[1] <= prompt_ids.shape[1]: |
| 192 | raise ProbeError( |
| 193 | "logprob_of", |
| 194 | f"completion tokenized to zero tokens (prompt={prompt!r}, completion={completion!r})", |
| 195 | ) |
| 196 | target_ids = full_ids[:, prompt_ids.shape[1] :] |
| 197 | with torch.inference_mode(): |
| 198 | logits = self._model(full_ids).logits # (1, T, V) |
| 199 | # Align: logit at position t predicts token at t+1. We want |
| 200 | # predictions for the completion slice. |
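        # e.g. with P prompt tokens and C completion tokens (T = P + C),
        # logits[P-1 : T-1] predict tokens full_ids[P : T].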
| 201 | shift_logits = logits[:, prompt_ids.shape[1] - 1 : -1, :] # (1, C, V) |
| 202 | log_probs = F.log_softmax(shift_logits.float(), dim=-1) |
| 203 | gathered = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) |
| 204 | return float(gathered.sum().item()) |
| 205 | |
| 206 | def rolling_logprob(self, text: str) -> RollingLogprob: |
| 207 | return self._inst.cached( |
| 208 | "rolling_logprob", |
| 209 | self.id, |
| 210 | text, |
| 211 | 0, |
| 212 | lambda: self._compute_rolling_logprob(text), |
| 213 | ) |
| 214 | |
| 215 | def _compute_rolling_logprob(self, text: str) -> RollingLogprob: |
| 216 | import torch |
| 217 | import torch.nn.functional as F |
| 218 | |
| 219 | ids = self._tokenizer(text, return_tensors="pt").input_ids.to(self._device) |
| 220 | if ids.shape[1] < 2: |
| 221 | return RollingLogprob( |
| 222 | token_ids=ids[0].cpu().numpy().astype(np.int64), |
| 223 | logprobs=np.array([], dtype=np.float32), |
| 224 | num_tokens=int(ids.shape[1]), |
| 225 | total_logprob=0.0, |
| 226 | ) |
| 227 | with torch.inference_mode(): |
| 228 | logits = self._model(ids).logits # (1, T, V) |
| 229 | log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1) # predicts tokens 1..T |
| 230 | gathered = log_probs.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1).squeeze(0) |
| 231 | return RollingLogprob( |
| 232 | token_ids=ids[0].cpu().numpy().astype(np.int64), |
| 233 | logprobs=gathered.cpu().numpy().astype(np.float32), |
| 234 | num_tokens=int(ids.shape[1]), |
| 235 | total_logprob=float(gathered.sum().item()), |
| 236 | ) |
| 237 | |
| 238 | def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist: |
| 239 | return self._inst.cached( |
| 240 | "next_token_dist", |
| 241 | self.id, |
| 242 | prompt, |
| 243 | top_k, |
| 244 | lambda: self._compute_next_token_dist(prompt, top_k=top_k), |
| 245 | ) |
| 246 | |
| 247 | def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist: |
| 248 | import torch |
| 249 | import torch.nn.functional as F |
| 250 | |
| 251 | ids = self._tokenizer(prompt, return_tensors="pt").input_ids.to(self._device) |
| 252 | with torch.inference_mode(): |
| 253 | logits = self._model(ids).logits[:, -1, :] # (1, V) |
| 254 | log_probs = F.log_softmax(logits.float(), dim=-1).squeeze(0) |
| 255 | return _topk_to_token_dist(log_probs, top_k=top_k) |
| 256 | |
| 257 | def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]: |
| 258 | """Batched forward via tokenizer left-padding. |
| 259 | |
| 260 | Decoder-only LMs need left-padding because the last-token |
| 261 | position (what :meth:`next_token_dist` reads) must line up with |
| 262 | the actual end of each sequence. Right-padding would put the |
| 263 | pad token at the end and read garbage logits. We set |
| 264 | ``padding_side="left"`` on the tokenizer call explicitly — |
| 265 | independent of any instance-wide setting the tokenizer may |
| 266 | carry for other code paths. |
| 267 | |
| 268 | Cache-per-prompt: the instrumentation's :meth:`cached_batch` |
| 269 | checks the LRU for every prompt and only forwards the misses |
        through the batch, so a repeat workload (e.g. a second
        ``as_base`` view later in the run that sees the same prompts)
        pays near-zero cost.
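
        A minimal call-shape sketch (prompts illustrative)::

            dists = view.next_token_dist_batch(
                ["The sky is", "2 + 2 ="], top_k=8
            )
            assert len(dists) == 2  # one TokenDist per prompt, in input order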
| 273 | """ |
| 274 | if not prompts: |
| 275 | return [] |
| 276 | |
| 277 | def compute_misses(miss_indices: list[int]) -> list[TokenDist]: |
| 278 | import torch |
| 279 | import torch.nn.functional as F |
| 280 | |
| 281 | miss_prompts = [prompts[i] for i in miss_indices] |
| 282 | tokens = self._tokenizer( |
| 283 | miss_prompts, |
| 284 | return_tensors="pt", |
| 285 | padding=True, |
| 286 | padding_side="left", |
| 287 | ).to(self._device) |
| 288 | input_ids = tokens["input_ids"] |
| 289 | with torch.inference_mode(): |
| 290 | logits = self._model( |
| 291 | input_ids=input_ids, |
| 292 | attention_mask=tokens.get("attention_mask"), |
| 293 | ).logits[:, -1, :] # (B, V) — left-pad makes "last" always the real last token |
| 294 | log_probs = F.log_softmax(logits.float(), dim=-1) # (B, V) |
| 295 | return [ |
| 296 | _topk_to_token_dist(log_probs[row], top_k=top_k) for row in range(len(miss_prompts)) |
| 297 | ] |
| 298 | |
| 299 | return self._inst.cached_batch( |
| 300 | "next_token_dist", self.id, list(prompts), top_k, compute_misses |
| 301 | ) |
| 302 | |
| 303 | |
| 304 | # --- the backend ----------------------------------------------------------- |
| 305 | |
| 306 | |
| 307 | class HuggingFaceDifferentialBackend: |
| 308 | """A :class:`~dlm_sway.core.scoring.DifferentialBackend` for HF+PEFT. |
| 309 | |
    The adapter toggle relies on
    :meth:`peft.PeftModel.disable_adapter` producing a context in which
    the forward pass skips the LoRA deltas, and on
    :meth:`peft.PeftModel.set_adapter` (or simply exiting the disable
    context) re-enabling them. A dedicated sanity test asserts that the
    toggle actually changes the logits on a fixture.
| 316 | """ |
| 317 | |
| 318 | #: B19 — the shared-weights toggle is not thread-safe. The runner |
| 319 | #: treats ``spec.defaults.concurrent_probes > 1`` as a no-op when |
| 320 | #: this attribute is ``False``. See ``.docs/design/backend-concurrency.md``. |
| 321 | safe_for_concurrent_views: bool = False |
| 322 | |
| 323 | def __init__(self, *, base_spec: ModelSpec, adapter_path: Path) -> None: |
| 324 | torch, transformers, peft = _require_hf() |
| 325 | self._torch = torch |
| 326 | self._spec = base_spec |
| 327 | # Path normalization lives in ``ModelSpec.adapter`` (B22). When |
| 328 | # the backend is constructed via ``backends.build``, the value |
| 329 | # is already absolute. Direct constructions (some tests) may |
| 330 | # pass a relative path, so ``Path(...).resolve()`` stays as a |
| 331 | # cheap idempotent fallback. |
| 332 | self._adapter_path = Path(adapter_path).resolve() |
| 333 | |
| 334 | device_str: Device = ( |
| 335 | _detect_device() if base_spec.device == "auto" else base_spec.device # type: ignore[assignment] |
| 336 | ) |
| 337 | self._device: str = device_str |
| 338 | dtype = _resolve_dtype(base_spec.dtype, device_str) |
| 339 | |
| 340 | tokenizer = transformers.AutoTokenizer.from_pretrained( |
| 341 | str(self._adapter_path) |
| 342 | if (self._adapter_path / "tokenizer_config.json").exists() |
| 343 | else base_spec.base, |
| 344 | trust_remote_code=base_spec.trust_remote_code, |
| 345 | ) |
| 346 | if tokenizer.pad_token_id is None: |
| 347 | tokenizer.pad_token = tokenizer.eos_token |
| 348 | |
| 349 | base_model = transformers.AutoModelForCausalLM.from_pretrained( |
| 350 | base_spec.base, |
| 351 | torch_dtype=dtype, |
| 352 | trust_remote_code=base_spec.trust_remote_code, |
| 353 | ) |
| 354 | base_model.to(self._device) |
| 355 | peft_model = peft.PeftModel.from_pretrained( |
| 356 | base_model, |
| 357 | str(self._adapter_path), |
| 358 | is_trainable=False, |
| 359 | ) |
| 360 | peft_model.eval() |
| 361 | |
| 362 | self._tokenizer: PreTrainedTokenizerBase = tokenizer |
| 363 | self._peft_model: PreTrainedModel = peft_model |
| 364 | self._pad_token_id: int = int(tokenizer.pad_token_id) |
| 365 | self._active: str | None = None |
| 366 | # Shared cache + trace + stats (Sprint 07). The instrumentation |
| 367 | # instance outlives individual view entries — entering/exiting |
| 368 | # ``as_base()`` doesn't reset the cache, because each view id |
| 369 | # (``"base"`` / ``"ft"`` / ``"scaled_0.50"`` / ``"null_42"``) |
| 370 | # is part of the cache key. A future toggle bug that fails to |
| 371 | # flip the view would surface as a cache hit returning the |
| 372 | # wrong side's value — integration tests cover this. |
| 373 | self._inst = BackendInstrumentation() |
| 374 | |
| 375 | # -- DifferentialBackend ------------------------------------------- |
| 376 | |
| 377 | @contextmanager |
| 378 | def as_base(self) -> Iterator[_HFView]: |
| 379 | self._enter("base") |
| 380 | try: |
| 381 | # peft.PeftModel.disable_adapter is a context manager; newer |
| 382 | # transformers builds ship stubs that mis-type it as a Tensor, |
| 383 | # so we suppress the operator check on the `with` line. The |
| 384 | # ``unused-ignore`` tag makes the suppression itself a no-op |
| 385 | # when [hf] isn't installed (mypy can't see the conflicting |
| 386 | # stub and would otherwise flag the ignore as redundant). |
| 387 | with self._peft_model.disable_adapter(): # type: ignore[operator,unused-ignore] |
| 388 | yield self._make_view("base") |
| 389 | finally: |
| 390 | self._exit() |
| 391 | |
| 392 | @contextmanager |
| 393 | def as_finetuned(self) -> Iterator[_HFView]: |
| 394 | self._enter("ft") |
| 395 | try: |
| 396 | yield self._make_view("ft") |
| 397 | finally: |
| 398 | self._exit() |
| 399 | |
| 400 | @contextmanager |
| 401 | def as_scaled_adapter(self, lam: float) -> Iterator[_HFView]: |
| 402 | """Temporarily multiply every LoRA layer's scaling factor by ``lam``. |
| 403 | |
| 404 | Works by walking the PEFT module tree and mutating each |
| 405 | ``LoraLayer.scaling[adapter_name]`` in place. The original |
| 406 | scalings are restored when the context exits — or when an |
| 407 | exception propagates, to keep the model in a sane state. |
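
        A strength-sweep sketch (values illustrative)::

            for lam in (0.25, 0.5, 1.0):
                with backend.as_scaled_adapter(lam) as view:
                    dist = view.next_token_dist("hello", top_k=8)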
| 408 | """ |
| 409 | self._enter(f"scaled({lam})") |
| 410 | # ``module`` is dynamic (peft LoraLayer subclass) — Any avoids |
| 411 | # mypy treating its ``.scaling`` as a Tensor when peft is loaded. |
| 412 | saved: list[tuple[Any, str, float]] = [] |
| 413 | try: |
| 414 | import peft # noqa: PLC0415 — already a hard dep of this backend |
| 415 | |
| 416 | lora_cls = getattr(peft.tuners.lora, "LoraLayer", None) |
| 417 | if lora_cls is None: |
| 418 | raise RuntimeError("peft.tuners.lora.LoraLayer not found; check peft>=0.13 pin") |
| 419 | for module in self._peft_model.modules(): |
| 420 | if not isinstance(module, lora_cls): |
| 421 | continue |
| 422 | scaling = getattr(module, "scaling", None) |
| 423 | if not isinstance(scaling, dict): |
| 424 | continue |
| 425 | for key, original in scaling.items(): |
| 426 | saved.append((module, key, float(original))) |
| 427 | scaling[key] = float(original) * lam |
| 428 | yield self._make_view(f"scaled_{lam:.2f}") |
| 429 | finally: |
| 430 | for module, key, original in saved: |
| 431 | module.scaling[key] = original |
| 432 | self._exit() |
| 433 | |
| 434 | @contextmanager |
| 435 | def as_null_adapter( |
| 436 | self, seed: int, *, init_scale: float = 0.02, rank_scale: float = 1.0 |
| 437 | ) -> Iterator[_HFView]: |
| 438 | """Temporarily replace every LoRA ``A``/``B`` tensor with random noise. |
| 439 | |
| 440 | Same rank, alpha, and target modules as the real adapter — only |
| 441 | the weights differ. This is the denominator in every z-score |
| 442 | path: "how much signal does structural noise produce?" |
| 443 | |
        ``rank_scale`` simulates a null adapter at a different effective
        rank without reshaping any tensor. The LoRA delta ``B·A`` is a
        sum of ``r`` rank-1 outer products, so its output variance scales
        linearly with ``r``. A ``rank_scale`` of 0.5 is therefore
        equivalent to halving the rank in output-variance terms, which we
        get by scaling both factors' noise std by ``rank_scale ** 0.25``
        (each rank-1 term's std then scales by ``sqrt(rank_scale)`` and
        the summed output variance by ``rank_scale``). The PEFT tensors
        keep their original shapes; no surgery on alpha/scaling; no model
        reload.
| 452 | |
        Implementation walks the PEFT module tree for ``lora_A``/``lora_B``
        parameters, saves a clone of each current value, overwrites in
        place with a zero-mean Gaussian at ``init_scale *
        rank_scale ** 0.25``, and restores on exit (including on exception).
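
        A standalone numpy check of the variance claim (illustrative
        values; not part of the backend)::

            import numpy as np

            rng = np.random.default_rng(0)
            d, r, rho, sigma = 256, 16, 0.5, 0.02

            def out_var(rank, std):
                b = rng.normal(0, std, (d, rank))
                a = rng.normal(0, std, (rank, d))
                return (b @ a).var()

            # Half the rank at std sigma roughly equals full rank at
            # std sigma * rho ** 0.25.
            print(out_var(r // 2, sigma), out_var(r, sigma * rho ** 0.25))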
| 457 | """ |
| 458 | import math |
| 459 | |
| 460 | import torch |
| 461 | |
| 462 | if rank_scale <= 0.0 or not math.isfinite(rank_scale): |
| 463 | raise ValueError(f"rank_scale must be positive and finite; got {rank_scale!r}") |
| 464 | |
        effective_scale = init_scale * rank_scale ** 0.25
| 466 | view_id = f"null_{seed}" if rank_scale == 1.0 else f"null_{seed}_rank{rank_scale:.2f}" |
| 467 | |
| 468 | self._enter(view_id) |
| 469 | gen = torch.Generator(device="cpu").manual_seed(int(seed)) |
| 470 | saved: list[tuple[torch.nn.Parameter, torch.Tensor]] = [] |
| 471 | try: |
| 472 | for pname, param in self._peft_model.named_parameters(): |
| 473 | if not any(key in pname for key in ("lora_A", "lora_B")): |
| 474 | continue |
| 475 | saved.append((param, param.detach().clone())) |
| 476 | with torch.no_grad(): |
| 477 | noise = torch.randn( |
| 478 | *param.shape, |
| 479 | generator=gen, |
| 480 | dtype=torch.float32, |
| 481 | ).to(dtype=param.dtype, device=param.device) |
| 482 | param.copy_(noise * effective_scale) |
| 483 | yield self._make_view(view_id) |
| 484 | finally: |
| 485 | with torch.no_grad(): |
| 486 | for param, original in saved: |
| 487 | param.copy_(original) |
| 488 | self._exit() |
| 489 | |
| 490 | def close(self) -> None: |
| 491 | """Release GPU memory + flush the trace writer. Safe to call more than once.""" |
| 492 | inst = getattr(self, "_inst", None) |
| 493 | if inst is not None: |
| 494 | inst.close() |
| 495 | if getattr(self, "_peft_model", None) is not None: |
| 496 | del self._peft_model |
| 497 | if self._torch.cuda.is_available(): |
| 498 | self._torch.cuda.empty_cache() |
| 499 | |
| 500 | # -- PreflightCheckable ------------------------------------------- |
| 501 | |
| 502 | _PREFLIGHT_PROMPT = "hello" |
| 503 | _PREFLIGHT_TOP_K = 8 |
| 504 | |
| 505 | def cache_identity(self) -> str: |
| 506 | """Stable string identifying this backend for on-disk caching. |
| 507 | |
| 508 | The base model id + the adapter's resolved absolute path is |
| 509 | enough to key a null-calibration cache: swapping either |
| 510 | invalidates the previously-computed stats. |
| 511 | """ |
| 512 | return f"hf:{self._spec.base}:{self._adapter_path}" |
| 513 | |
| 514 | def preflight_finite_check(self) -> tuple[bool, str]: |
| 515 | """One forward pass per view; assert both produce finite logits. |
| 516 | |
| 517 | Catches the +11639σ class of bug at suite-load time: a NaN-weighted |
| 518 | adapter would produce non-finite logprobs here, the runner sees |
| 519 | ``ok=False``, and the suite aborts with a single synthetic ERROR |
| 520 | probe — never reaching a probe that would pass on garbage. |
| 521 | """ |
| 522 | import math |
| 523 | |
| 524 | try: |
| 525 | with self.as_base() as base_view: |
| 526 | base_dist = base_view.next_token_dist( |
| 527 | self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K |
| 528 | ) |
| 529 | with self.as_finetuned() as ft_view: |
| 530 | ft_dist = ft_view.next_token_dist( |
| 531 | self._PREFLIGHT_PROMPT, top_k=self._PREFLIGHT_TOP_K |
| 532 | ) |
| 533 | except Exception as exc: # noqa: BLE001 — backend may raise anything |
| 534 | return False, f"preflight forward pass raised {type(exc).__name__}: {exc}" |
| 535 | |
| 536 | for label, dist in (("base", base_dist), ("ft", ft_dist)): |
| 537 | n_bad = int((~np.isfinite(dist.logprobs)).sum()) |
| 538 | if n_bad > 0: |
| 539 | return ( |
| 540 | False, |
| 541 | f"{label} view produced {n_bad}/{dist.logprobs.size} non-finite " |
| 542 | f"logprob(s) on prompt {self._PREFLIGHT_PROMPT!r} — adapter is " |
| 543 | f"likely broken (NaN/inf weights). sway refuses to score a model " |
| 544 | f"producing non-finite outputs.", |
| 545 | ) |
| 546 | tail = dist.tail_logprob |
| 547 | # B6: ``None`` is a sentinel for "k covered the whole vocab," |
| 548 | # not a numeric value to range-check. |
| 549 | if tail is not None and not math.isfinite(tail): |
| 550 | return ( |
| 551 | False, |
| 552 | f"{label} view produced non-finite tail_logprob = {tail}", |
| 553 | ) |
| 554 | |
| 555 | return True, "" |
| 556 | |
| 557 | # -- internals ----------------------------------------------------- |
| 558 | |
| 559 | def _make_view(self, mode: str) -> _HFView: |
| 560 | return _HFView( |
| 561 | id=mode, |
| 562 | _model=self._peft_model, |
| 563 | _tokenizer=self._tokenizer, |
| 564 | _device=self._device, |
| 565 | _pad_token_id=self._pad_token_id, |
| 566 | _inst=self._inst, |
| 567 | ) |
| 568 | |
| 569 | def _enter(self, mode: str) -> None: |
| 570 | if self._active is not None: |
| 571 | raise RuntimeError( |
| 572 | f"HuggingFaceDifferentialBackend view {self._active!r} already active; " |
| 573 | f"exit it before entering {mode!r}." |
| 574 | ) |
| 575 | self._active = mode |
| 576 | |
| 577 | def _exit(self) -> None: |
| 578 | self._active = None |
| 579 | |
| 580 | |
| 581 | __all__ = ["HuggingFaceDifferentialBackend"] |