| 1 | """External-perplexity-gap probe (S09, audit §F3). |
| 2 | |
| 3 | Measures how much the adapter shifted the model's behavior on *held-out |
| 4 | natural prose* — text the model has seen a lot of during pretraining |
| 5 | and that has nothing to do with the training document. This is the |
| 6 | complement to :mod:`dlm_sway.probes.calibration_drift`: |
| 7 | |
| 8 | - ``calibration_drift`` asks "did the adapter regress specific |
| 9 | factual Q/A items?" |
| 10 | - ``external_perplexity`` asks "did the adapter raise the model's |
| 11 | perplexity on natural English prose in general?" |
| 12 | |
| 13 | A healthy, targeted fine-tune shifts the model toward the document's |
| 14 | content; it should leave the model's fluency on unrelated natural |
| 15 | prose roughly intact. An over-fit fine-tune (too many steps, too high |
| 16 | a learning rate, too small a training set) drifts the whole language |
| 17 | model toward the document's register and raises perplexity on |
| 18 | everything else — often invisibly to ``calibration_drift`` if the |
| 19 | degradation is diffuse (all items nudged slightly, none crossing the |
| 20 | regression threshold). |
| 21 | |
| 22 | Metric: ``mean_delta_nats`` is the mean of per-token logprob deltas |
| 23 | ``(logprob_ft - logprob_base) / num_tokens`` across chunks. Positive |
| 24 | values mean ft assigns higher probability to external prose than base |
| 25 | did (rare but possible on a multilingual adapter that improved English |
| 26 | modeling incidentally). Negative values mean ft's perplexity rose |
| 27 | (forgetting). The metric is higher-is-better, so the raw z-score |
| 28 | against a null-adapter distribution maps directly onto the shared |
| 29 | ``z >= assert_z_gte`` rule — no sign flip: the adapter passes when |
| 30 | ``mean_delta`` sits at least ``assert_z_gte`` σ *above* the null's |
| 31 | distribution of ``mean_delta`` on the same corpus. |
| 32 | """ |
| 33 | |
| 34 | from __future__ import annotations |
| 35 | |
| 36 | import math |
| 37 | import statistics |
| 38 | from typing import Literal |
| 39 | |
| 40 | from pydantic import Field |
| 41 | |
| 42 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 43 | from dlm_sway.core.stats import bootstrap_ci |
| 44 | from dlm_sway.probes._external_corpus import ( |
| 45 | available_corpora, |
| 46 | chunk_corpus, |
| 47 | load_corpus, |
| 48 | ) |
| 49 | from dlm_sway.probes._zscore import ( |
| 50 | no_calibration_note, |
| 51 | score_from_z, |
| 52 | verdict_from_z, |
| 53 | z_score, |
| 54 | z_scores_by_rank, |
| 55 | ) |
| 56 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 57 | from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank |
| 58 | |
# Closed set of corpus identifiers a spec may name. Each entry is expected to
# be a packaged public-domain corpus; ``run()`` still validates the choice at
# runtime against ``available_corpora()``, so an out-of-date alias fails soft.
CorpusName = Literal["public_domain_en"]
| 60 | |
| 61 | |
class ExternalPerplexitySpec(ProbeSpec):
    """Spec for ``kind: external_perplexity``.

    Both pass criteria operate on ``mean_delta`` — the mean per-chunk,
    per-token logprob delta (ft minus base) in nats, higher-is-better.
    Common fields the probe also reads (``name``, ``weight``) are
    inherited from :class:`~dlm_sway.probes.base.ProbeSpec`.
    """

    kind: Literal["external_perplexity"] = "external_perplexity"
    corpus: CorpusName = "public_domain_en"
    """Which packaged public-domain corpus to measure against. See
    :func:`dlm_sway.probes._external_corpus.available_corpora` for
    the installed set."""
    chunk_chars: int = Field(default=2048, ge=128, le=16_384)
    """Characters per chunk — controls the rolling-logprob window. At
    2048 chars each chunk fits comfortably inside a 1-2k token context
    for SmolLM2-sized models."""
    max_chunks: int = Field(default=16, ge=1, le=128)
    """Hard cap on chunks the probe processes. Each chunk is 2 forward
    passes (base + ft); 16 chunks ≈ 32 passes ≈ 8 s on CPU for a
    135 M model. Lower for faster suites."""
    # Only consulted when no null-adapter calibration stats exist for
    # this probe kind (verdict_from_z returns None).
    assert_mean_delta_gte: float = -0.1
    """Fallback threshold when no null stats are available. Mean
    per-token logprob delta must be ≥ this (negative = worse ft)."""
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline.
    ``mean_delta`` is higher-is-better (positive = ft is more confident
    on external prose than base), so the raw z-score is compared
    directly: the adapter must be at least ``assert_z_gte`` σ *above*
    the null baseline's ``mean_delta`` distribution — σ *better than
    noise* on external prose fluency."""
| 88 | |
| 89 | |
class ExternalPerplexityProbe(Probe):
    """Diffuse-forgetting detector on held-out natural prose.

    Scores the fine-tuned model's fluency on a packaged public-domain
    corpus against the base model, chunk by chunk, then judges the mean
    per-token logprob delta either against the null-adapter distribution
    (z-score path) or, absent calibration stats, against a fixed
    fallback threshold.
    """

    kind = "external_perplexity"
    spec_cls = ExternalPerplexitySpec
    category = "calibration"

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ExternalPerplexitySpec | None:
        """Return a cut-down spec for null-adapter calibration runs.

        Four chunks keep calibration cheap: every seed scores the same
        2 KB slices, so later seeds turn the base-side forward passes
        into S07 cache hits.
        """
        del ctx  # calibration does not depend on the run context
        return ExternalPerplexitySpec(
            name="_calibration",
            kind="external_perplexity",
            max_chunks=4,
        )

    @staticmethod
    def _error(spec: ProbeSpec, message: str) -> ProbeResult:
        """Build the ERROR-verdict result shared by every early exit."""
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=Verdict.ERROR,
            score=None,
            message=message,
        )

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, ExternalPerplexitySpec)
        if spec.corpus not in available_corpora():
            return self._error(
                spec,
                f"unknown corpus {spec.corpus!r}; available: {available_corpora()!r}",
            )

        try:
            text = load_corpus(spec.corpus)
        except OSError as exc:
            return self._error(spec, f"failed to load corpus {spec.corpus!r}: {exc}")

        pieces = chunk_corpus(text, chunk_chars=spec.chunk_chars, max_chunks=spec.max_chunks)
        if not pieces:
            return self._error(
                spec,
                f"corpus {spec.corpus!r} chunked to zero pieces "
                f"(chunk_chars={spec.chunk_chars}, max_chunks={spec.max_chunks})",
            )

        deltas: list[float] = []
        base_tokens = 0
        ft_tokens = 0
        base_lp_sum = 0.0
        ft_lp_sum = 0.0
        for piece in pieces:
            with ctx.require_backend.as_base() as backend:
                rl_base = backend.rolling_logprob(piece)
            with ctx.require_backend.as_finetuned() as backend:
                rl_ft = backend.rolling_logprob(piece)
            # ``logprobs.size`` is ``num_tokens - 1`` by the
            # RollingLogprob contract; the max(_, 1) guards keep the
            # per-token division safe on degenerate (empty-text) chunks.
            n_base = max(rl_base.logprobs.size, 1)
            n_ft = max(rl_ft.logprobs.size, 1)
            delta = float(rl_ft.total_logprob) / n_ft - float(rl_base.total_logprob) / n_base
            # A non-finite delta would signal a probe bug (genuinely
            # empty text), not an adapter effect — drop it from every
            # aggregate so it cannot skew the mean or the token totals.
            if math.isfinite(delta):
                deltas.append(delta)
                base_tokens += n_base
                ft_tokens += n_ft
                base_lp_sum += float(rl_base.total_logprob)
                ft_lp_sum += float(rl_ft.total_logprob)

        if not deltas:
            return self._error(spec, "every chunk produced a non-finite delta")

        mean_delta = statistics.fmean(deltas)
        base_mean = base_lp_sum / max(base_tokens, 1)
        ft_mean = ft_lp_sum / max(ft_tokens, 1)
        ci_95 = bootstrap_ci(deltas, seed=ctx.seed)

        # Preferred path: z-score against the null-adapter baseline.
        # ``mean_delta`` is higher-is-better (positive = ft assigns
        # higher probability to external prose than base did), so the
        # raw z already reads as "σ better than noise" — no sign flip.
        z = z_score(mean_delta, get_null_stats(ctx, spec.kind))
        z_by_rank = z_scores_by_rank(mean_delta, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is None:
            # Fallback path: fixed threshold on the raw delta when no
            # null calibration stats exist for this probe kind.
            verdict = Verdict.PASS if mean_delta >= spec.assert_mean_delta_gte else Verdict.FAIL
            score = min(1.0, max(0.0, 0.5 + mean_delta))
            message = (
                f"external_ppl delta={mean_delta:+.3f} nats/tok "
                f"({'≥' if verdict == Verdict.PASS else '<'} "
                f"{spec.assert_mean_delta_gte}) {no_calibration_note(spec.kind)}"
            )
        else:
            verdict = verdict_z
            maybe_score = score_from_z(z)
            score = 0.0 if maybe_score is None else maybe_score
            message = (
                f"external_ppl delta={mean_delta:+.3f} nats/tok, "
                f"z={z:+.2f}σ vs null (higher-is-better)"
            )

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=mean_delta,
            z_score=z,
            base_value=base_mean,
            ft_value=ft_mean,
            evidence={
                "corpus": spec.corpus,
                "chunk_chars": spec.chunk_chars,
                "num_chunks": len(deltas),
                "per_chunk_delta": deltas,
                "base_mean_logprob_per_tok": base_mean,
                "ft_mean_logprob_per_tok": ft_mean,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
                "raw_ci_95": list(ci_95) if ci_95 is not None else None,
            },
            message=message,
            ci_95=ci_95,
        )