| 1 | """Scoring protocols: logprobs, next-token distributions, differential toggling. |
| 2 | |
| 3 | Scoring is **separate** from generation because not every backend can |
| 4 | provide logits. Every numeric sway probe depends on at least one of |
| 5 | three operations: |
| 6 | |
| 7 | 1. ``logprob_of(prompt, completion)`` — score a completion against a |
| 8 | prompt (A1, B2, B3, C2, …). |
| 9 | 2. ``rolling_logprob(text)`` — perplexity over a piece of text (B1, |
| 10 | C2). |
| 11 | 3. ``next_token_dist(prompt, top_k)`` — the raw next-token distribution |
| 12 | at a single position (A1, N2). |
| 13 | |
| 14 | The :class:`DifferentialBackend` is the key performance primitive: |
| 15 | both base and fine-tuned views share the same loaded weights and KV |
| 16 | cache layout, toggled via PEFT's :meth:`set_adapter` / |
| 17 | :meth:`disable_adapter`. A naive "load twice" implementation would |
| 18 | double memory and halve throughput. |
| 19 | """ |
| 20 | |
| 21 | from __future__ import annotations |
| 22 | |
| 23 | from contextlib import AbstractContextManager |
| 24 | from dataclasses import dataclass, field |
| 25 | from typing import Protocol, runtime_checkable |
| 26 | |
| 27 | import numpy as np |
| 28 | from numpy.typing import NDArray |
| 29 | |
| 30 | from dlm_sway.core.model import Model |
| 31 | |
| 32 | |
| 33 | @dataclass(frozen=True, slots=True) |
| 34 | class RollingLogprob: |
| 35 | """Per-token logprobs over a piece of text, plus summary stats. |
| 36 | |
| 37 | Attributes |
| 38 | ---------- |
| 39 | token_ids: |
| 40 | The tokenizer output for ``text``. Length ``N``. |
| 41 | logprobs: |
| 42 | ``log p(token_i | token_<i)`` for each position i ≥ 1. Length |
| 43 | ``N-1``. |
| 44 | num_tokens: |
| 45 | ``N`` — included for convenience; ``len(token_ids)``. |
| 46 | total_logprob: |
| 47 | Sum of :attr:`logprobs`. |
| 48 | """ |
| 49 | |
| 50 | token_ids: NDArray[np.int64] |
| 51 | logprobs: NDArray[np.float32] |
| 52 | num_tokens: int |
| 53 | total_logprob: float |
| 54 | |
| 55 | @property |
| 56 | def mean_logprob(self) -> float: |
| 57 | n = self.logprobs.size |
| 58 | return float(self.total_logprob / n) if n else 0.0 |
| 59 | |
| 60 | @property |
| 61 | def perplexity(self) -> float: |
| 62 | """``exp(-mean_logprob)``. Base-e, natural perplexity.""" |
| 63 | return float(np.exp(-self.mean_logprob)) |
| 64 | |
| 65 | |
| 66 | @dataclass(frozen=True, slots=True) |
| 67 | class TokenDist: |
| 68 | """A (possibly top-k truncated) next-token probability distribution. |
| 69 | |
| 70 | For KL / JS divergence probes sway needs matched distributions |
| 71 | across base and fine-tuned views. The runner is responsible for |
| 72 | aligning ``top_k`` token slices between two ``TokenDist`` objects |
| 73 | before handing them to divergence math. |
| 74 | """ |
| 75 | |
| 76 | token_ids: NDArray[np.int64] |
| 77 | """Token ids, descending by probability. Length ``k``.""" |
| 78 | logprobs: NDArray[np.float32] |
| 79 | """Log-probabilities for :attr:`token_ids`. Length ``k``.""" |
| 80 | vocab_size: int |
| 81 | """Full vocab size — needed to renormalize top-k truncated slices.""" |
| 82 | tail_logprob: float = field(default=0.0) |
| 83 | """log of (1 - sum of exp(logprobs[:k])); 0 if top_k covers the full vocab.""" |
| 84 | |
| 85 | |
@runtime_checkable
class ScoringBackend(Protocol):
    """Logit-level access to a loaded model.

    Structural (duck-typed) protocol: any backend exposing these three
    operations satisfies it, and ``isinstance`` checks work at runtime
    thanks to :func:`runtime_checkable`. Backends that cannot surface
    logits simply do not implement this protocol (see module docstring).
    """

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Sum of log-probabilities of ``completion`` tokens given ``prompt``.

        The prompt is *not* scored; only the completion contributes. The
        value is in nats (natural log). Longer completions are
        monotonically more negative — callers normalize by length if
        they need a rate.
        """
        ...

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Compute per-token logprobs for the whole of ``text``.

        Equivalent to lm-eval's ``loglikelihood_rolling``. Used for
        perplexity comparison on held-out content (B1 SIS, C2).
        """
        ...

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Next-token distribution at the position after ``prompt``.

        Truncated to ``top_k`` for memory; callers doing divergence math
        over the top-k slice accept the (typically negligible) error vs
        full-vocab KL.
        """
        ...
| 116 | |
| 117 | |
@runtime_checkable
class DifferentialBackend(Protocol):
    """A backend that holds base + fine-tuned views on a single loaded model.

    The idiomatic usage is::

        with backend.as_base() as base_view:
            p_base = base_view.next_token_dist(prompt)
        with backend.as_finetuned() as ft_view:
            p_ft = ft_view.next_token_dist(prompt)

    Implementations toggle PEFT adapters via
    :meth:`peft.PeftModel.set_adapter` / :meth:`disable_adapter`.

    Invariant: the two views MUST NOT be usable simultaneously. A
    caller holding a ``base_view`` after entering the ``as_finetuned``
    context is a programmer error and implementations MUST detect and
    raise.
    """

    def as_base(self) -> AbstractContextManager[_ScoringModel]:
        """Context manager yielding a scoring view over the *base* weights."""
        ...

    def as_finetuned(self) -> AbstractContextManager[_ScoringModel]:
        """Context manager yielding a scoring view over the *fine-tuned* weights."""
        ...
| 141 | |
| 142 | |
@runtime_checkable
class ScalableDifferentialBackend(DifferentialBackend, Protocol):
    """A differential backend that can also scale the LoRA additive term.

    LoRA applies ``W + (alpha/r) · B @ A`` to a base weight matrix. This
    protocol exposes a context manager that temporarily multiplies that
    additive term by ``lam`` for everything inside the ``with`` block.

    ``lam = 0.0`` is equivalent to :meth:`as_base`.
    ``lam = 1.0`` is equivalent to :meth:`as_finetuned`.
    ``lam = 1.25`` overshoots — useful for N2 AdapterAblation's
    response-curve measurement.

    Only the HF backend ships an implementation in v0.1. Probes that
    need scaling check via ``isinstance(backend, ScalableDifferentialBackend)``
    at runtime and SKIP gracefully when unavailable.
    """

    def as_scaled_adapter(self, lam: float) -> AbstractContextManager[_ScoringModel]:
        """Context manager yielding a view with the LoRA delta scaled by ``lam``."""
        ...
| 162 | |
| 163 | |
@runtime_checkable
class NullCalibratedBackend(DifferentialBackend, Protocol):
    """A differential backend that can produce a "null adapter" view.

    A null adapter has the *same structure* (rank, alpha, target modules)
    as the real adapter but with weights drawn from a zero-mean Gaussian.
    Running probes against this view yields the baseline "how much
    signal does random noise produce" distribution — the denominator in
    every numeric probe's z-score.

    The context manager takes a ``seed`` so calibration runs can be
    reproduced and multiple independent null samples can be drawn to
    estimate ``std``.

    Implementations MUST restore the real adapter on exit, including
    on exceptions, so a caller can freely interleave null and real
    calibrations within the same backend lifetime.
    """

    def as_null_adapter(
        self, seed: int, *, init_scale: float = 0.02
    ) -> AbstractContextManager[_ScoringModel]:
        """Context manager yielding a view running under a random null adapter.

        ``seed`` makes the random weights reproducible. ``init_scale``
        presumably sets the spread (std) of the zero-mean Gaussian
        init — confirm against the concrete implementation.
        """
        ...
| 186 | |
| 187 | |
# Helper Protocol for type-checking the yielded context object: it
# must satisfy both Model and ScoringBackend. mypy doesn't support
# intersection types, so we spell it out explicitly.
@runtime_checkable
class _ScoringModel(Model, ScoringBackend, Protocol):
    """A Model that also exposes ScoringBackend.

    Structural: any object satisfying both parent interfaces matches;
    nothing subclasses this nominally.
    """

    ...
| 196 | |
| 197 | |
# Re-export under a public name so callers need not touch the
# underscore-prefixed helper.
ScoringModel = _ScoringModel
"""Public alias for the intersection ``Model & ScoringBackend``.

Exported for backend and probe implementations that need to annotate
variables of this combined type.
"""