# dlm_sway/core/scoring.py — scoring protocol definitions
1 """Scoring protocols: logprobs, next-token distributions, differential toggling.
2
3 Scoring is **separate** from generation because not every backend can
4 provide logits. Every numeric sway probe depends on at least one of
5 three operations:
6
7 1. ``logprob_of(prompt, completion)`` — score a completion against a
8 prompt (A1, B2, B3, C2, …).
9 2. ``rolling_logprob(text)`` — perplexity over a piece of text (B1,
10 C2).
11 3. ``next_token_dist(prompt, top_k)`` — the raw next-token distribution
12 at a single position (A1, N2).
13
14 The :class:`DifferentialBackend` is the key performance primitive:
15 both base and fine-tuned views share the same loaded weights and KV
16 cache layout, toggled via PEFT's :meth:`set_adapter` /
17 :meth:`disable_adapter`. A naive "load twice" implementation would
18 double memory and halve throughput.
19 """
20
21 from __future__ import annotations
22
23 from collections.abc import Sequence
24 from contextlib import AbstractContextManager
25 from dataclasses import dataclass, field
26 from typing import Protocol, runtime_checkable
27
28 import numpy as np
29 from numpy.typing import NDArray
30
31 from dlm_sway.core.model import Model
32
33
34 @dataclass(frozen=True, slots=True)
35 class RollingLogprob:
36 """Per-token logprobs over a piece of text, plus summary stats.
37
38 Attributes
39 ----------
40 token_ids:
41 The tokenizer output for ``text``. Length ``N``.
42 logprobs:
43 ``log p(token_i | token_<i)`` for each position i ≥ 1. Length
44 ``N-1``.
45 num_tokens:
46 ``N`` — included for convenience; ``len(token_ids)``.
47 total_logprob:
48 Sum of :attr:`logprobs`.
49 """
50
51 token_ids: NDArray[np.int64]
52 logprobs: NDArray[np.float32]
53 num_tokens: int
54 total_logprob: float
55
56 @property
57 def mean_logprob(self) -> float:
58 n = self.logprobs.size
59 return float(self.total_logprob / n) if n else 0.0
60
61 @property
62 def perplexity(self) -> float:
63 """``exp(-mean_logprob)``. Base-e, natural perplexity."""
64 return float(np.exp(-self.mean_logprob))
65
66
67 @dataclass(frozen=True, slots=True)
68 class TokenDist:
69 """A (possibly top-k truncated) next-token probability distribution.
70
71 For KL / JS divergence probes sway needs matched distributions
72 across base and fine-tuned views. The runner is responsible for
73 aligning ``top_k`` token slices between two ``TokenDist`` objects
74 before handing them to divergence math.
75 """
76
77 token_ids: NDArray[np.int64]
78 """Token ids, descending by probability. Length ``k``."""
79 logprobs: NDArray[np.float32]
80 """Log-probabilities for :attr:`token_ids`. Length ``k``."""
81 vocab_size: int
82 """Full vocab size — needed to renormalize top-k truncated slices."""
83 tail_logprob: float | None = field(default=None)
84 """Log of the residual mass outside the top-k slice (B6).
85
86 Three states the consumer must distinguish:
87
88 - ``None`` — the top-k slice already covers the full vocabulary
89 (``k == vocab_size``) or the residual underflowed below the
90 backend's reportable floor. Treat as "no tail to redistribute."
91 - ``0.0`` — the residual mass is *exactly* zero in fp32. A real
92 tail exists in theory but isn't measurable above the backend's
93 epsilon. Equivalent to ``None`` for divergence math but kept
94 separate so backends with extra precision can opt in.
95 - ``float`` (negative) — measurable tail mass. Divergence helpers
96 redistribute this evenly across the vocab tokens not in either
97 side's top-k.
98 """
99
100
@runtime_checkable
class ScoringBackend(Protocol):
    """Structural interface for logit-level access to a loaded model."""

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Summed log-probability (nats) of ``completion`` given ``prompt``.

        Only the completion tokens contribute; the prompt itself is
        never scored. Because the sum grows monotonically more
        negative with completion length, callers wanting a rate
        normalize by length themselves.
        """
        ...

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Per-token logprobs over the whole of ``text``.

        The equivalent of lm-eval's ``loglikelihood_rolling``; backs
        perplexity comparison on held-out content (B1 SIS, C2).
        """
        ...

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Distribution over the next token following ``prompt``.

        Truncated to the ``top_k`` most likely tokens to bound memory;
        divergence math over the slice accepts the (typically
        negligible) error relative to full-vocab KL.
        """
        ...

    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
        """Batched :meth:`next_token_dist`: one result per prompt, in order.

        Backends with genuine batching (HF, MLX) amortize
        kernel-launch and memory-transfer overhead across the batch —
        roughly a 3-5× win for KL-style probes.

        This default simply iterates :meth:`next_token_dist`; backends
        that gain nothing from batching (``dummy``, ``api``) keep it.
        Overrides MUST reproduce the per-prompt path within a tight
        float32 tolerance — the S07 cache is consulted per-prompt
        before the batch is assembled, so a single call may mix cached
        and uncached prompts without surprise drift.
        """
        return [self.next_token_dist(prompt, top_k=top_k) for prompt in prompts]
150
151
@runtime_checkable
class DifferentialBackend(Protocol):
    """Holds base and fine-tuned views over one loaded model.

    Typical usage::

        with backend.as_base() as base_view:
            p_base = base_view.next_token_dist(prompt)
        with backend.as_finetuned() as ft_view:
            p_ft = ft_view.next_token_dist(prompt)

    Implementations flip PEFT adapters on and off via
    :meth:`peft.PeftModel.set_adapter` / :meth:`disable_adapter`.

    Invariant: the two views are mutually exclusive — never usable at
    the same time. A caller still holding a ``base_view`` after
    entering the ``as_finetuned`` context is a programmer error that
    implementations MUST detect and raise on.
    """

    def as_base(self) -> AbstractContextManager[_ScoringModel]: ...

    def as_finetuned(self) -> AbstractContextManager[_ScoringModel]: ...
175
176
@runtime_checkable
class ScalableDifferentialBackend(DifferentialBackend, Protocol):
    """Differential backend whose LoRA additive term can be rescaled.

    LoRA perturbs a base weight matrix as ``W + (alpha/r) · B @ A``.
    The context manager below multiplies that additive term by ``lam``
    for everything inside the ``with`` block:

    - ``lam = 0.0`` reproduces :meth:`as_base`;
    - ``lam = 1.0`` reproduces :meth:`as_finetuned`;
    - ``lam = 1.25`` overshoots — used by N2 AdapterAblation's
      response-curve measurement.

    Only the HF backend ships an implementation in v0.1. Probes that
    need scaling check ``isinstance(backend, ScalableDifferentialBackend)``
    at runtime and SKIP gracefully when it is unavailable.
    """

    def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]: ...
196
197
@runtime_checkable
class PreflightCheckable(Protocol):
    """Backend capable of validating itself before any probe executes.

    One forward pass per view against a fixed sentinel prompt, checking
    that both the base and fine-tuned distributions contain finite
    logits; the result is ``(ok, reason)``.

    The runner invokes this once at suite start. On failure it aborts
    with a single synthetic ERROR probe describing the problem, which
    keeps a NaN-weighted adapter from ever producing a false PASS
    verdict (the +11639σ class of bug from Audit 01).

    Opt-in: backends that don't implement it simply run unchecked (the
    runner skips with a NOTE-level log entry). Every backend shipped in
    this version implements it; custom backends are encouraged to
    follow suit.
    """

    def preflight_finite_check(self) -> tuple[bool, str]: ...
218
219
@runtime_checkable
class NullCalibratedBackend(DifferentialBackend, Protocol):
    """Differential backend able to swap in a "null adapter" view.

    A null adapter mirrors the real adapter's structure (rank, alpha,
    target modules) but draws its weights from a zero-mean Gaussian.
    Probing against it measures how much signal pure random noise
    produces — the baseline distribution that forms the denominator of
    every numeric probe's z-score.

    The ``seed`` parameter makes calibration runs reproducible and
    lets callers draw multiple independent null samples to estimate
    ``std``.

    Implementations MUST restore the real adapter on exit — including
    exceptional exit — so callers can freely interleave null and real
    calibrations within the same backend lifetime.

    ``rank_scale`` simulates a null adapter of a different effective
    rank without reshaping the underlying PEFT tensors: the output
    variance of the LoRA product ``A·B`` grows linearly with rank, so
    a faithful rank-``r · rank_scale`` null is approximated by scaling
    each factor's noise std by ``sqrt(rank_scale)``. Implementations
    MUST fold ``sqrt(rank_scale)`` into ``init_scale`` internally and
    reject negative or zero values.
    """

    def as_null_adapter(
        self, seed: int, *, init_scale: float = 0.02, rank_scale: float = 1.0
    ) -> AbstractContextManager[_ScoringModel]: ...
250
251
252 # Helper Protocol for type-checking the yielded context object: it
253 # must satisfy both Model and ScoringBackend. mypy doesn't support
254 # intersection types, so we spell it out explicitly.
@runtime_checkable
class _ScoringModel(Model, ScoringBackend, Protocol):
    """Intersection protocol: a :class:`Model` that is also a :class:`ScoringBackend`."""
260
261
ScoringModel = _ScoringModel
"""Public name for the ``Model & ScoringBackend`` intersection.

Backend and probe implementations import this to annotate variables
that must satisfy both protocols at once.
"""