1 """Scoring protocols: logprobs, next-token distributions, differential toggling.
2
3 Scoring is **separate** from generation because not every backend can
4 provide logits. Every numeric sway probe depends on at least one of
5 three operations:
6
7 1. ``logprob_of(prompt, completion)`` — score a completion against a
8 prompt (A1, B2, B3, C2, …).
9 2. ``rolling_logprob(text)`` — perplexity over a piece of text (B1,
10 C2).
11 3. ``next_token_dist(prompt, top_k)`` — the raw next-token distribution
12 at a single position (A1, N2).
13
14 The :class:`DifferentialBackend` is the key performance primitive:
15 both base and fine-tuned views share the same loaded weights and KV
16 cache layout, toggled via PEFT's :meth:`set_adapter` /
17 :meth:`disable_adapter`. A naive "load twice" implementation would
18 double memory and halve throughput.
19 """
20
21 from __future__ import annotations
22
23 from contextlib import AbstractContextManager
24 from dataclasses import dataclass, field
25 from typing import Protocol, runtime_checkable
26
27 import numpy as np
28 from numpy.typing import NDArray
29
30 from dlm_sway.core.model import Model
31
32
33 @dataclass(frozen=True, slots=True)
34 class RollingLogprob:
35 """Per-token logprobs over a piece of text, plus summary stats.
36
37 Attributes
38 ----------
39 token_ids:
40 The tokenizer output for ``text``. Length ``N``.
41 logprobs:
42 ``log p(token_i | token_<i)`` for each position i ≥ 1. Length
43 ``N-1``.
44 num_tokens:
45 ``N`` — included for convenience; ``len(token_ids)``.
46 total_logprob:
47 Sum of :attr:`logprobs`.
48 """
49
50 token_ids: NDArray[np.int64]
51 logprobs: NDArray[np.float32]
52 num_tokens: int
53 total_logprob: float
54
55 @property
56 def mean_logprob(self) -> float:
57 n = self.logprobs.size
58 return float(self.total_logprob / n) if n else 0.0
59
60 @property
61 def perplexity(self) -> float:
62 """``exp(-mean_logprob)``. Base-e, natural perplexity."""
63 return float(np.exp(-self.mean_logprob))
64
65
66 @dataclass(frozen=True, slots=True)
67 class TokenDist:
68 """A (possibly top-k truncated) next-token probability distribution.
69
70 For KL / JS divergence probes sway needs matched distributions
71 across base and fine-tuned views. The runner is responsible for
72 aligning ``top_k`` token slices between two ``TokenDist`` objects
73 before handing them to divergence math.
74 """
75
76 token_ids: NDArray[np.int64]
77 """Token ids, descending by probability. Length ``k``."""
78 logprobs: NDArray[np.float32]
79 """Log-probabilities for :attr:`token_ids`. Length ``k``."""
80 vocab_size: int
81 """Full vocab size — needed to renormalize top-k truncated slices."""
82 tail_logprob: float = field(default=0.0)
83 """log of (1 - sum of exp(logprobs[:k])); 0 if top_k covers the full vocab."""
84
85
@runtime_checkable
class ScoringBackend(Protocol):
    """Structural interface for logit-level access to a loaded model."""

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Return the summed log-probability (nats) of ``completion`` given ``prompt``.

        Only the completion tokens are scored; the prompt contributes
        nothing. Because the value is a sum, longer completions are
        monotonically more negative — callers divide by length if they
        want a per-token rate.
        """
        ...

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Score every token of ``text`` against its left context.

        The analogue of lm-eval's ``loglikelihood_rolling``; used for
        perplexity comparison on held-out content (B1 SIS, C2).
        """
        ...

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Return the next-token distribution immediately after ``prompt``.

        Truncated to the ``top_k`` most likely tokens to bound memory;
        divergence math over the slice accepts the (typically
        negligible) error relative to full-vocab KL.
        """
        ...
116
117
@runtime_checkable
class DifferentialBackend(Protocol):
    """Holds base and fine-tuned views over one set of loaded weights.

    Typical usage::

        with backend.as_base() as base_view:
            p_base = base_view.next_token_dist(prompt)
        with backend.as_finetuned() as ft_view:
            p_ft = ft_view.next_token_dist(prompt)

    Implementations flip PEFT adapters with
    :meth:`peft.PeftModel.set_adapter` / :meth:`disable_adapter`, so
    both views share the loaded weights and KV-cache layout instead of
    paying for a second model load.

    Invariant: the two views must never be usable simultaneously.
    Holding a ``base_view`` after entering the ``as_finetuned``
    context is a programmer error; implementations MUST detect it and
    raise.
    """

    def as_base(self) -> AbstractContextManager[_ScoringModel]: ...

    def as_finetuned(self) -> AbstractContextManager[_ScoringModel]: ...
141
142
@runtime_checkable
class ScalableDifferentialBackend(DifferentialBackend, Protocol):
    """Differential backend whose LoRA additive term can be rescaled.

    LoRA contributes ``(alpha/r) · B @ A`` on top of a base weight
    ``W``. :meth:`as_scaled_adapter` multiplies that contribution by
    ``lam`` for everything inside the ``with`` block:

    - ``lam = 0.0`` matches :meth:`as_base`,
    - ``lam = 1.0`` matches :meth:`as_finetuned`,
    - ``lam = 1.25`` overshoots — used by N2 AdapterAblation to
      trace a response curve.

    Only the HF backend implements this in v0.1. Probes that need
    scaling test ``isinstance(backend, ScalableDifferentialBackend)``
    at runtime and SKIP gracefully when it is unavailable.
    """

    def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]: ...
162
163
@runtime_checkable
class NullCalibratedBackend(DifferentialBackend, Protocol):
    """Differential backend able to expose a "null adapter" view.

    A null adapter copies the real adapter's structure (rank, alpha,
    target modules) but fills the weights with zero-mean Gaussian
    noise. Probing that view measures how much signal pure noise
    produces — the denominator of every numeric probe's z-score.

    The ``seed`` argument makes calibration runs reproducible and lets
    callers draw several independent null samples to estimate ``std``.

    Implementations MUST restore the real adapter on exit — including
    on exceptions — so null and real calibrations can be interleaved
    freely within one backend lifetime.
    """

    def as_null_adapter(
        self, seed: int, *, init_scale: float = 0.02
    ) -> AbstractContextManager[_ScoringModel]: ...
186
187
188 # Helper Protocol for type-checking the yielded context object: it
189 # must satisfy both Model and ScoringBackend. mypy doesn't support
190 # intersection types, so we spell it out explicitly.
@runtime_checkable
class _ScoringModel(Model, ScoringBackend, Protocol):
    """Intersection type: a :class:`Model` that also satisfies :class:`ScoringBackend`."""
196
197
ScoringModel = _ScoringModel
"""Public alias for the ``Model & ScoringBackend`` intersection.

Backend and probe implementations import this to annotate variables
that must satisfy both protocols at once.
"""