1 """External-perplexity-gap probe (S09, audit §F3).
2
3 Measures how much the adapter shifted the model's behavior on *held-out
4 natural prose* — text the model has seen a lot of during pretraining
5 and that has nothing to do with the training document. This is the
6 complement to :mod:`dlm_sway.probes.calibration_drift`:
7
8 - ``calibration_drift`` asks "did the adapter regress specific
9 factual Q/A items?"
10 - ``external_perplexity`` asks "did the adapter raise the model's
11 perplexity on natural English prose in general?"
12
A healthy, targeted fine-tune shifts the model toward the document's
content; it should leave the model's fluency on unrelated natural
prose roughly intact. An over-fit fine-tune (too many steps, too high
a learning rate, too small a training set) drifts the whole language
model toward the document's register and raises perplexity on
everything else — a shift that is often invisible to
``calibration_drift`` when the degradation is diffuse (all items
nudged slightly, none crossing the regression threshold).
Metric: ``mean_delta_nats`` is the mean over chunks of the per-chunk,
per-token logprob delta ``logprob_ft / ft_tokens - logprob_base /
base_tokens``. Positive values mean ft assigns higher probability to
external prose than base did (rare but possible, e.g. a multilingual
adapter that incidentally improved English modeling). Negative values
mean ft's perplexity rose (forgetting). The metric is higher-is-better,
so the raw z-score against a null-adapter distribution maps directly
onto the shared ``z >= assert_z_gte`` rule — no sign flip: the adapter
passes when ``mean_delta`` sits at least ``assert_z_gte`` σ above the
mean of the null's ``mean_delta`` distribution on the same corpus.
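
For intuition (made-up numbers): if the base model scores a chunk at
-2.10 nats/token and the fine-tuned model at -2.25 nats/token, that
chunk's delta is -0.15 nats/token and the perplexity ratio is
``exp(0.15) ≈ 1.16``, i.e. ft's perplexity on that chunk rose by
roughly 16 %. ``mean_delta_nats`` is the average of these per-chunk
deltas.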
32 """
33
34 from __future__ import annotations
35
36 import math
37 import statistics
38 from typing import Literal
39
40 from pydantic import Field
41
42 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
43 from dlm_sway.core.stats import bootstrap_ci
44 from dlm_sway.probes._external_corpus import (
45 available_corpora,
46 chunk_corpus,
47 load_corpus,
48 )
49 from dlm_sway.probes._zscore import (
50 no_calibration_note,
51 score_from_z,
52 verdict_from_z,
53 z_score,
54 z_scores_by_rank,
55 )
56 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
57 from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank
58
59 CorpusName = Literal["public_domain_en"]
60
61
62 class ExternalPerplexitySpec(ProbeSpec):
63 """Spec for ``kind: external_perplexity``."""

    kind: Literal["external_perplexity"] = "external_perplexity"
    corpus: CorpusName = "public_domain_en"
    """Which packaged public-domain corpus to measure against. See
    :func:`dlm_sway.probes._external_corpus.available_corpora` for
    the installed set."""
    chunk_chars: int = Field(default=2048, ge=128, le=16_384)
    """Characters per chunk — controls the rolling-logprob window. At
    2048 chars each chunk fits comfortably inside a 1-2k token context
    for SmolLM2-sized models."""
    max_chunks: int = Field(default=16, ge=1, le=128)
    """Hard cap on chunks the probe processes. Each chunk is 2 forward
    passes (base + ft); 16 chunks ≈ 32 passes ≈ 8 s on CPU for a
    135 M model. Lower for faster suites."""
    assert_mean_delta_gte: float = -0.1
    """Fallback threshold when no null stats are available. The mean
    per-token logprob delta must be ≥ this value (a negative delta
    means ft got worse)."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline.
    ``mean_delta`` is higher-is-better (positive = ft is more confident
    on external prose than base), so the raw z-score is compared
    directly: the adapter must sit at least ``assert_z_gte`` σ above
    the mean of the null baseline's ``mean_delta`` distribution, i.e.
    that many σ *better than noise* on external prose fluency."""


class ExternalPerplexityProbe(Probe):
    """Diffuse-forgetting detector on held-out natural prose."""

    kind = "external_perplexity"
    spec_cls = ExternalPerplexitySpec
    category = "calibration"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ExternalPerplexitySpec | None:
        # Cheap calibration: 4 chunks × 2 views × N seeds. Each chunk
        # is the same 2 KB slice across seeds, so the S07 cache turns
        # later seeds into hits on the base side.
        del ctx
        return ExternalPerplexitySpec(
            name="_calibration",
            kind="external_perplexity",
            max_chunks=4,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, ExternalPerplexitySpec)
        if spec.corpus not in available_corpora():
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=(f"unknown corpus {spec.corpus!r}; available: {available_corpora()!r}"),
            )

        try:
            corpus_text = load_corpus(spec.corpus)
        except OSError as exc:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=f"failed to load corpus {spec.corpus!r}: {exc}",
            )

        chunks = chunk_corpus(corpus_text, chunk_chars=spec.chunk_chars, max_chunks=spec.max_chunks)
        if not chunks:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message=(
                    f"corpus {spec.corpus!r} chunked to zero pieces "
                    f"(chunk_chars={spec.chunk_chars}, max_chunks={spec.max_chunks})"
                ),
            )

        per_chunk_deltas: list[float] = []
        total_base_tokens = 0
        total_ft_tokens = 0
        total_base_lp = 0.0
        total_ft_lp = 0.0
        for chunk in chunks:
            with ctx.require_backend.as_base() as b:
                base_rl = b.rolling_logprob(chunk)
            with ctx.require_backend.as_finetuned() as f:
                ft_rl = f.rolling_logprob(chunk)
            # Per-token mean logprob for this chunk. ``logprobs.size``
            # is ``num_tokens - 1`` by the RollingLogprob contract.
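            # (e.g. a 512-token chunk yields 511 next-token logprobs).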
            base_n = max(base_rl.logprobs.size, 1)
            ft_n = max(ft_rl.logprobs.size, 1)
            base_per_tok = float(base_rl.total_logprob) / base_n
            ft_per_tok = float(ft_rl.total_logprob) / ft_n
            # A zero base_n or ft_n happens only on genuinely empty
            # text, which would be a probe bug, not an adapter signal;
            # ``max(_, 1)`` above guards the division. The finiteness
            # check below drops chunks whose logprobs came back NaN/inf.
            delta = ft_per_tok - base_per_tok
            if math.isfinite(delta):
                per_chunk_deltas.append(delta)
                total_base_tokens += base_n
                total_ft_tokens += ft_n
                total_base_lp += float(base_rl.total_logprob)
                total_ft_lp += float(ft_rl.total_logprob)

        if not per_chunk_deltas:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="every chunk produced a non-finite delta",
            )

        mean_delta = statistics.fmean(per_chunk_deltas)
        base_mean_per_tok = total_base_lp / max(total_base_tokens, 1)
        ft_mean_per_tok = total_ft_lp / max(total_ft_tokens, 1)
        ci_95 = bootstrap_ci(per_chunk_deltas, seed=ctx.seed)

        # Null calibration is the preferred path. ``mean_delta`` is
        # higher-is-better (positive = ft assigns higher probability to
        # external prose than base did), so the raw z-score already
        # reads as "σ better than noise" — no sign flip.
        stats = get_null_stats(ctx, spec.kind)
        z = z_score(mean_delta, stats)
        z_by_rank = z_scores_by_rank(mean_delta, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
            message = (
                f"external_ppl delta={mean_delta:+.3f} nats/tok, "
                f"z={z:+.2f}σ vs null (higher-is-better)"
            )
        else:
            verdict = Verdict.PASS if mean_delta >= spec.assert_mean_delta_gte else Verdict.FAIL
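            # Fallback score: linear map of the delta, centered so a
            # delta of 0 scores 0.5 and clipped to [0, 1] (saturates at
            # ±0.5 nats/token).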
            score = max(0.0, min(1.0, 0.5 + mean_delta))
            message = (
                f"external_ppl delta={mean_delta:+.3f} nats/tok "
                f"({'≥' if verdict == Verdict.PASS else '<'} "
                f"{spec.assert_mean_delta_gte}) {no_calibration_note(spec.kind)}"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=mean_delta,
            z_score=z,
            base_value=base_mean_per_tok,
            ft_value=ft_mean_per_tok,
            evidence={
                "corpus": spec.corpus,
                "chunk_chars": spec.chunk_chars,
                "num_chunks": len(per_chunk_deltas),
                "per_chunk_delta": per_chunk_deltas,
                "base_mean_logprob_per_tok": base_mean_per_tok,
                "ft_mean_logprob_per_tok": ft_mean_per_tok,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )