1 """Null-adapter baseline probe — per-kind calibration matrix (S02).
2
3 Every numeric primitive reports its raw metric *and* a z-score against
4 a null-adapter distribution. This probe is the runtime engine that
5 establishes that distribution — for **every** numeric probe kind the
6 user has downstream in the suite, not just one.
7
8 How it works:
9
10 1. The runner populates ``ctx.downstream_kinds`` with every probe kind
11 that appears after this one in the suite.
12 2. For each target kind, we ask its probe class for a
13 :meth:`~dlm_sway.probes.base.Probe.calibrate_spec` — a small spec
14 suitable for null calibration. A probe that returns ``None`` opts
15 out (typically because its inputs can't be synthesized, e.g.
16 ``adapter_revert`` without an embedder, or ``adapter_ablation``
17 which needs ``as_scaled_adapter`` that the proxy doesn't expose).
18 3. For each calibrating kind × seed, we run the probe through a
19 :class:`~dlm_sway.probes._null_proxy.NullCalibrationBackendProxy`
20 which makes ``as_finetuned()`` yield ``as_null_adapter(seed)`` —
21 so the probe's own math is computing "what does my metric look
22 like when the fine-tune is structural noise?".
23 4. We harvest each run's ``raw`` value, aggregate to ``(mean, std, n)``
24 per kind, and publish under ``evidence["null_stats"]``.
25 5. The runner threads ``null_stats`` into ``RunContext`` for every
26 subsequent probe, which then prefers the z-score path over the
27 fixed-threshold path (see :mod:`dlm_sway.probes._zscore`).
28
29 Backends that don't implement
30 :class:`~dlm_sway.core.scoring.NullCalibratedBackend` cause this probe
31 to ``Verdict.SKIP``; every downstream probe falls back to fixed
32 thresholds and surfaces ``(no calibration)`` in the report.
33 """
34
35 from __future__ import annotations
36
37 import math
38 import statistics
39 from collections.abc import Mapping
40 from typing import Any, Literal
41
42 from pydantic import Field
43
44 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
45 from dlm_sway.core.scoring import NullCalibratedBackend
46 from dlm_sway.probes._null_cache import compute_key, load, save
47 from dlm_sway.probes._null_proxy import NullCalibrationBackendProxy
48 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext, registry
49
50
class NullAdapterSpec(ProbeSpec):
    """Spec for ``kind: null_adapter``.

    Place this probe **first** in the suite so its output populates
    :attr:`RunContext.null_stats` before subsequent probes consult it.

    Calibration cost is roughly ``runs × len(rank_multipliers)`` null
    evaluations per calibrated kind, so both knobs multiply directly
    into suite runtime — keep them small.
    """

    kind: Literal["null_adapter"] = "null_adapter"
    runs: int = Field(default=3, ge=1, le=10)
    """Number of independent null adapters to evaluate. Three is the
    smallest that yields a usable std; more is better but quickly
    dominates suite runtime."""
    init_scale: float = 0.02
    """Stddev of the zero-mean Gaussian used to fill lora_A/lora_B."""
    seed_base: int = 1000
    """First seed; successive runs use ``seed_base + run_idx``."""
    calibrate_kinds: list[str] = Field(default_factory=list)
    """Which probe kinds to calibrate. Empty = auto-populate from
    ``ctx.downstream_kinds`` (the kinds that appear after this probe
    in the suite). Set explicitly to force calibration of specific
    kinds regardless of suite order."""
    cache: bool = True
    """Read / write the on-disk calibration cache under
    ``~/.dlm-sway/null-stats``. Keyed by backend identity + calibration
    params. Disable to force a fresh calibration (e.g. when you suspect
    the cached stats are stale)."""
    rank_multipliers: list[float] = Field(default_factory=lambda: [1.0])
    """Rank multipliers at which to calibrate. Each value scales the
    null-adapter noise std by ``sqrt(multiplier)`` — mathematically
    equivalent to rank-scaling the LoRA output variance. Default
    ``[1.0]`` preserves pre-S10 single-rank behavior byte-for-byte.

    Three-point profiles like ``[0.5, 1.0, 2.0]`` let users read
    "how rank-saturated is my adapter?" off the report:

    - A healthy adapter's z-score is stable across multipliers.
    - An adapter that's barely above noise at its own rank but
      solidly above noise at ``0.5x`` is rank-saturated — a smaller
      rank would have yielded a sharper signal.

    Per-rank stats land in ``evidence["null_stats_by_rank"]`` keyed
    by ``f"rank_{mult:.2f}"``; the 1.0x group (when present) also
    lands under ``evidence["null_stats"]`` for back-compat with
    probes that consume a single calibration level.
    """
96
97
class NullAdapterProbe(Probe):
    """Populate ``ctx.null_stats`` with per-kind null distributions.

    The probe itself reports ``Verdict.PASS`` on success — its job is
    calibration, not judgment. If the backend can't support null-view
    substitution, reports ``Verdict.SKIP`` with a clear message; every
    downstream numeric probe then falls back to fixed thresholds.
    """

    kind = "null_adapter"
    spec_cls = NullAdapterSpec
    category = "baseline"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Calibrate every eligible downstream kind and publish the stats.

        Returns PASS with ``evidence["null_stats"]`` /
        ``evidence["null_stats_by_rank"]`` on success, SKIP when the
        backend can't do null-view substitution, and ERROR when
        ``rank_multipliers`` contains a non-positive or non-finite value.
        """
        assert isinstance(spec, NullAdapterSpec)
        if not isinstance(ctx.backend, NullCalibratedBackend):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    "backend does not implement NullCalibratedBackend — "
                    "numeric probes will fall back to fixed thresholds"
                ),
            )

        registered = registry()

        # Decide which kinds to calibrate. Explicit spec field wins;
        # otherwise auto-populate from downstream_kinds.
        target_kinds: list[str] = list(spec.calibrate_kinds)
        if not target_kinds:
            target_kinds = [k for k in ctx.downstream_kinds if k and k != spec.kind]
        # De-dupe while preserving order; drop self and unregistered.
        seen: set[str] = set()
        filtered: list[str] = []
        for k in target_kinds:
            if k == spec.kind or k in seen or k not in registered:
                continue
            seen.add(k)
            filtered.append(k)
        target_kinds = filtered

        # Validate rank multipliers up front; empty list is nonsensical.
        rank_multipliers = list(spec.rank_multipliers) or [1.0]
        for mult in rank_multipliers:
            if mult <= 0.0 or not math.isfinite(mult):
                return ProbeResult(
                    name=spec.name,
                    kind=spec.kind,
                    verdict=Verdict.ERROR,
                    score=None,
                    message=f"rank_multipliers must be positive and finite; got {mult!r}",
                )

        # Cache lookup: backends can opt in by providing a
        # ``cache_identity()`` method returning a stable string. The
        # key incorporates both that identity and the calibration
        # parameters that actually influence the output — including
        # the sorted rank-multiplier tuple so multi-rank caches don't
        # collide with single-rank.
        #
        # Fix: the key used to be computed even when
        # ``_backend_identity`` returned ``None``, so every backend
        # without a ``cache_identity()`` method shared cache entries
        # keyed on calibration params alone — one backend could be
        # served another's stale null stats. ``_backend_identity``'s
        # documented contract is that caching is skipped for such
        # backends; honor that here by leaving ``cache_key`` as None.
        cache_key: str | None = None
        if spec.cache:
            backend_identity = _backend_identity(ctx.backend)
            if backend_identity is not None:
                cache_key = compute_key(
                    backend_identity=backend_identity,
                    params={
                        "runs": spec.runs,
                        "init_scale": spec.init_scale,
                        "seed_base": spec.seed_base,
                        "top_k": ctx.top_k,
                        "kinds": sorted(target_kinds),
                        "rank_multipliers": sorted(rank_multipliers),
                    },
                )
                cached = load(cache_key)
                if cached is not None and "null_stats_by_rank" in cached:
                    return _pass_from_cache(spec, cached)
                # Pre-S10 cache entries only have ``null_stats``
                # (implicit single-rank). Promote them into the new
                # shape so repeated runs benefit from the existing
                # cache.
                #
                # DC3 (Audit 02) — this branch is legacy compatibility
                # and will be removed in the next minor version bump.
                # After S10 shipped, every newly-written cache carries
                # ``null_stats_by_rank``; promoted entries get
                # rewritten on the first full recalibration. We keep
                # this shim for one release cycle to avoid stranding
                # users whose cache files predate S10.
                if cached is not None and "null_stats" in cached:
                    promoted = dict(cached)
                    promoted["null_stats_by_rank"] = {
                        _rank_key(1.0): cached["null_stats"],
                    }
                    return _pass_from_cache(spec, promoted)

        null_stats_by_rank: dict[str, dict[str, dict[str, float]]] = {}
        per_rank_skipped: dict[str, list[dict[str, str]]] = {}
        per_rank_samples: dict[str, dict[str, list[float]]] = {}

        for mult in rank_multipliers:
            rkey = _rank_key(mult)
            per_kind_stats, samples, skipped = _calibrate_at_rank(
                ctx=ctx,
                spec=spec,
                target_kinds=target_kinds,
                registered=registered,
                rank_scale=mult,
            )
            null_stats_by_rank[rkey] = per_kind_stats
            per_rank_samples[rkey] = samples
            per_rank_skipped[rkey] = skipped

        # Back-compat surface: ``null_stats`` is the 1.0x group when
        # present, else the first multiplier's stats (so older probes
        # that only read the single-rank dict still get *something*).
        primary_rkey = _rank_key(1.0)
        if primary_rkey in null_stats_by_rank:
            primary_stats = null_stats_by_rank[primary_rkey]
            primary_skipped = per_rank_skipped[primary_rkey]
            primary_samples = per_rank_samples[primary_rkey]
        else:
            first_rkey = _rank_key(rank_multipliers[0])
            primary_stats = null_stats_by_rank[first_rkey]
            primary_skipped = per_rank_skipped[first_rkey]
            primary_samples = per_rank_samples[first_rkey]

        evidence: dict[str, Any] = {
            "null_stats": primary_stats,
            "null_stats_by_rank": null_stats_by_rank,
            "per_kind_raw_samples": primary_samples,
            "skipped_kinds": primary_skipped,
            "calibrated_kinds": list(primary_stats.keys()),
            "runs": spec.runs,
            "init_scale": spec.init_scale,
            "seed_base": spec.seed_base,
            "rank_multipliers": rank_multipliers,
            "weight": spec.weight,
            "from_cache": False,
        }

        # Only write back when we actually computed a key (cache on and
        # backend identifiable).
        if cache_key is not None:
            save(
                cache_key,
                {
                    "null_stats": primary_stats,
                    "null_stats_by_rank": null_stats_by_rank,
                    "runs": spec.runs,
                    "init_scale": spec.init_scale,
                    "seed_base": spec.seed_base,
                    "rank_multipliers": rank_multipliers,
                    "calibrated_kinds": list(primary_stats.keys()),
                },
            )

        if len(rank_multipliers) == 1:
            message = (
                f"null calibration: {len(primary_stats)} kinds calibrated over {spec.runs} seeds"
            )
        else:
            mults_str = ", ".join(f"{m:g}x" for m in rank_multipliers)
            message = (
                f"null calibration: {len(primary_stats)} kinds × "
                f"{len(rank_multipliers)} ranks [{mults_str}] over {spec.runs} seeds"
            )
        if primary_skipped:
            message += f" ({len(primary_skipped)} opted out)"

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=Verdict.PASS,
            score=1.0,
            evidence=evidence,
            message=message,
        )
275
276
277 def _rank_key(mult: float) -> str:
278 """Canonical string key for a rank multiplier. Stable across runs."""
279 return f"rank_{mult:.2f}"
280
281
def _calibrate_at_rank(
    *,
    ctx: RunContext,
    spec: NullAdapterSpec,
    target_kinds: list[str],
    registered: dict[str, type[Probe]],
    rank_scale: float,
) -> tuple[dict[str, dict[str, float]], dict[str, list[float]], list[dict[str, str]]]:
    """Execute the kind × seed calibration matrix for one rank multiplier.

    Returns ``(per_kind_stats, per_kind_samples, skipped)``.
    """
    stats: dict[str, dict[str, float]] = {}
    raw_samples: dict[str, list[float]] = {}
    skipped: list[dict[str, str]] = []

    for kind in target_kinds:
        cls = registered[kind]
        try:
            cal_spec = cls.calibrate_spec(ctx)
        except Exception as exc:  # noqa: BLE001 — defensive
            skipped.append({"kind": kind, "reason": f"calibrate_spec raised: {exc}"})
            continue
        if cal_spec is None:
            skipped.append(
                {"kind": kind, "reason": "probe opted out (calibrate_spec returned None)"}
            )
            continue

        probe = cls()
        observed: list[float] = []
        failures: list[str] = []
        for idx in range(spec.runs):
            seed = spec.seed_base + idx
            proxy = NullCalibrationBackendProxy(
                ctx.backend,  # type: ignore[arg-type]
                seed=seed,
                init_scale=spec.init_scale,
                rank_scale=rank_scale,
            )
            cal_ctx = RunContext(
                backend=proxy,
                seed=seed,
                top_k=ctx.top_k,
                sections=ctx.sections,
                doc_text=ctx.doc_text,
                null_stats={},  # fixed thresholds during calibration — avoids recursion
                downstream_kinds=(),
            )
            try:
                outcome = probe.run(cal_spec, cal_ctx)
            except Exception as exc:  # noqa: BLE001
                failures.append(f"seed={seed}: {type(exc).__name__}: {exc}")
                continue
            raw = outcome.raw
            if raw is not None and math.isfinite(raw):
                observed.append(float(raw))
            elif outcome.verdict == Verdict.ERROR:
                failures.append(f"seed={seed}: probe ERROR — {outcome.message}")

        if not observed:
            reason = "no finite raws across all seeds"
            if failures:
                reason += f" ({failures[0]})"
            skipped.append({"kind": kind, "reason": reason})
            continue

        mean = statistics.fmean(observed)
        spread = statistics.pstdev(observed) if len(observed) > 1 else 0.0
        # F02 (Audit 03) — flag the degenerate case (``runs: 1``, or every
        # seed yielding the *exact* same raw) as a first-class property of
        # the stats dict. Earlier code hid it behind ``max(std, 1e-6)``,
        # which collided with :data:``_zscore.MIN_STD`` and let the
        # z-score path fire on a std synthetically lifted from ``0.0`` —
        # the path behind the ``+290,766σ`` observation in the audit. A
        # multi-seed run with genuinely tiny variance (e.g. 5e-7 on a
        # low-noise dummy) is NOT degenerate; the 1e-6 floor stays for
        # that case so valid-but-tight calibrations still z-score.
        # ``z_score`` inspects both the ``degenerate`` flag and the
        # ``std < MIN_STD`` threshold.
        degenerate = len(observed) <= 1 or spread == 0.0
        stats[kind] = {
            "mean": mean,
            "std": max(spread, 1e-6),
            "n": float(len(observed)),
            "degenerate": 1.0 if degenerate else 0.0,
        }
        raw_samples[kind] = observed

    return stats, raw_samples, skipped
372
373
def _pass_from_cache(spec: NullAdapterSpec, cached: dict[str, Any]) -> ProbeResult:
    """Reconstruct a PASS result from a previously cached evidence dict."""
    by_rank: dict[str, dict[str, dict[str, float]]] = dict(
        cached.get("null_stats_by_rank") or {}
    )
    # The explicit 1.0x group wins; otherwise use the legacy ``null_stats``.
    primary = by_rank.get(_rank_key(1.0), cached.get("null_stats", {}))
    evidence: dict[str, Any] = dict(cached)
    evidence["null_stats"] = primary
    evidence["null_stats_by_rank"] = by_rank
    evidence.setdefault("skipped_kinds", [])
    evidence.setdefault("calibrated_kinds", list(primary.keys()))
    evidence["weight"] = spec.weight
    evidence["from_cache"] = True
    kind_count = len(primary)
    rank_count = len(by_rank)
    if rank_count <= 1:
        message = f"null calibration: {kind_count} kinds (loaded from cache)"
    else:
        message = f"null calibration: {kind_count} kinds × {rank_count} ranks (loaded from cache)"
    return safe_finalize(
        name=spec.name,
        kind=spec.kind,
        verdict=Verdict.PASS,
        score=1.0,
        evidence=evidence,
        message=message,
    )
403
404
405 def _backend_identity(backend: Any) -> str | None:
406 """Ask the backend for a stable cache identity string, if it has one.
407
408 Duck-typed: backends that can't uniquely identify themselves (the
409 dummy backend in tests, for example) simply don't provide this
410 method, and caching is skipped for them.
411 """
412 fn = getattr(backend, "cache_identity", None)
413 if not callable(fn):
414 return None
415 try:
416 value = fn()
417 except Exception: # noqa: BLE001 — cache is best-effort
418 return None
419 return str(value) if value else None
420
421
def get_null_stats(ctx: RunContext, probe_kind: str) -> Mapping[str, float] | None:
    """Fetch the null-adapter stats calibrated for ``probe_kind``.

    Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for
    this kind, else ``None``. Probes treat ``None`` as "fall back to
    the fixed threshold from your spec" and surface ``(no calibration)``
    in the report.
    """
    stats_map = ctx.null_stats
    return stats_map.get(probe_kind, None)
431
432
def get_null_stats_by_rank(
    ctx: RunContext, probe_kind: str
) -> Mapping[str, Mapping[str, float]] | None:
    """Fetch per-rank null-adapter stats for ``probe_kind``.

    Returns ``{rank_key: {"mean": …, "std": …, "n": …}}`` across every
    rank multiplier the ``null_adapter`` probe calibrated. ``None`` when
    no multi-rank calibration ran (pre-S10 behavior, or S02's single-
    rank default).
    """
    per_rank = ctx.null_stats_by_rank
    if not per_rank:
        return None
    selected: dict[str, Mapping[str, float]] = {
        rkey: kind_map[probe_kind]
        for rkey, kind_map in per_rank.items()
        if probe_kind in kind_map
    }
    return selected or None