Python · 11885 bytes Raw Blame History
1 """Null-adapter baseline probe — per-kind calibration matrix (S02).
2
3 Every numeric primitive reports its raw metric *and* a z-score against
4 a null-adapter distribution. This probe is the runtime engine that
5 establishes that distribution — for **every** numeric probe kind the
6 user has downstream in the suite, not just one.
7
8 How it works:
9
10 1. The runner populates ``ctx.downstream_kinds`` with every probe kind
11 that appears after this one in the suite.
12 2. For each target kind, we ask its probe class for a
13 :meth:`~dlm_sway.probes.base.Probe.calibrate_spec` — a small spec
14 suitable for null calibration. A probe that returns ``None`` opts
15 out (typically because its inputs can't be synthesized, e.g.
16 ``adapter_revert`` without an embedder, or ``adapter_ablation``
17 which needs ``as_scaled_adapter`` that the proxy doesn't expose).
18 3. For each calibrating kind × seed, we run the probe through a
19 :class:`~dlm_sway.probes._null_proxy.NullCalibrationBackendProxy`
20 which makes ``as_finetuned()`` yield ``as_null_adapter(seed)`` —
21 so the probe's own math is computing "what does my metric look
22 like when the fine-tune is structural noise?".
23 4. We harvest each run's ``raw`` value, aggregate to ``(mean, std, n)``
24 per kind, and publish under ``evidence["null_stats"]``.
25 5. The runner threads ``null_stats`` into ``RunContext`` for every
26 subsequent probe, which then prefers the z-score path over the
27 fixed-threshold path (see :mod:`dlm_sway.probes._zscore`).
28
29 Backends that don't implement
30 :class:`~dlm_sway.core.scoring.NullCalibratedBackend` cause this probe
31 to ``Verdict.SKIP``; every downstream probe falls back to fixed
32 thresholds and surfaces ``(no calibration)`` in the report.
33 """
34
35 from __future__ import annotations
36
37 import math
38 import statistics
39 from typing import Any, Literal
40
41 from pydantic import Field
42
43 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
44 from dlm_sway.core.scoring import NullCalibratedBackend
45 from dlm_sway.probes._null_cache import compute_key, load, save
46 from dlm_sway.probes._null_proxy import NullCalibrationBackendProxy
47 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext, registry
48
49
class NullAdapterSpec(ProbeSpec):
    """Spec for ``kind: null_adapter``.

    Place this probe **first** in the suite so its output populates
    :attr:`RunContext.null_stats` before subsequent probes consult it.
    """

    # Fixed discriminator value: Literal with a matching default pins
    # this spec to the "null_adapter" probe kind.
    kind: Literal["null_adapter"] = "null_adapter"
    runs: int = Field(default=3, ge=1, le=10)
    """Number of independent null adapters to evaluate. Three is the
    smallest that yields a usable std; more is better but quickly
    dominates suite runtime."""
    init_scale: float = 0.02
    """Stddev of the zero-mean Gaussian used to fill lora_A/lora_B."""
    seed_base: int = 1000
    """First seed; successive runs use ``seed_base + run_idx``."""
    calibrate_kinds: list[str] = Field(default_factory=list)
    """Which probe kinds to calibrate. Empty = auto-populate from
    ``ctx.downstream_kinds`` (the kinds that appear after this probe
    in the suite). Set explicitly to force calibration of specific
    kinds regardless of suite order."""
    cache: bool = True
    """Read / write the on-disk calibration cache under
    ``~/.dlm-sway/null-stats``. Keyed by backend identity + calibration
    params. Disable to force a fresh calibration (e.g. when you suspect
    the cached stats are stale)."""
76
77
class NullAdapterProbe(Probe):
    """Populate ``ctx.null_stats`` with per-kind null distributions.

    The probe itself reports ``Verdict.PASS`` on success — its job is
    calibration, not judgment. If the backend can't support null-view
    substitution, reports ``Verdict.SKIP`` with a clear message; every
    downstream numeric probe then falls back to fixed thresholds.
    """

    kind = "null_adapter"
    spec_cls = NullAdapterSpec
    category = "baseline"

    # C9: floor applied to the per-kind std so the downstream z-score
    # path doesn't blow up when every seed produces identical raws.
    _STD_FLOOR = 1e-6

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Calibrate every eligible downstream kind against null adapters.

        On success, returns a PASS result whose ``evidence["null_stats"]``
        maps each calibrated kind to ``{"mean", "std", "n"}``. SKIPs when
        the backend cannot substitute a null view.
        """
        assert isinstance(spec, NullAdapterSpec)
        if not isinstance(ctx.backend, NullCalibratedBackend):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    "backend does not implement NullCalibratedBackend — "
                    "numeric probes will fall back to fixed thresholds"
                ),
            )

        registered = registry()
        target_kinds = self._select_target_kinds(spec, ctx, registered)

        # Cache is consulted only when the backend can identify itself;
        # anonymous backends (e.g. test dummies) must not share a key.
        cache_key = self._cache_key(spec, ctx, target_kinds)
        if cache_key is not None:
            cached_result = self._result_from_cache(cache_key, spec)
            if cached_result is not None:
                return cached_result

        per_kind_stats: dict[str, dict[str, float]] = {}
        per_kind_samples: dict[str, list[float]] = {}
        skipped_kinds: list[dict[str, str]] = []

        for kind in target_kinds:
            probe_cls = registered[kind]
            try:
                cal_spec = probe_cls.calibrate_spec(ctx)
            except Exception as exc:  # noqa: BLE001 — defensive
                skipped_kinds.append({"kind": kind, "reason": f"calibrate_spec raised: {exc}"})
                continue
            if cal_spec is None:
                # The probe opted out of calibration (its inputs can't be
                # synthesized — see the module docstring).
                skipped_kinds.append(
                    {
                        "kind": kind,
                        "reason": "probe opted out (calibrate_spec returned None)",
                    }
                )
                continue

            raws, errors = self._collect_null_raws(probe_cls(), cal_spec, spec, ctx)
            if raws:
                per_kind_stats[kind] = self._aggregate(raws)
                per_kind_samples[kind] = raws
            else:
                reason = "no finite raws across all seeds"
                if errors:
                    reason += f" ({errors[0]})"
                skipped_kinds.append({"kind": kind, "reason": reason})

        evidence: dict[str, Any] = {
            "null_stats": per_kind_stats,
            "per_kind_raw_samples": per_kind_samples,
            "skipped_kinds": skipped_kinds,
            "calibrated_kinds": list(per_kind_stats.keys()),
            "runs": spec.runs,
            "init_scale": spec.init_scale,
            "seed_base": spec.seed_base,
            "weight": spec.weight,
            "from_cache": False,
        }

        if cache_key is not None:
            # Persist the stats dict only — the samples list can be
            # large, and downstream consumers only need the aggregates.
            save(
                cache_key,
                {
                    "null_stats": per_kind_stats,
                    "runs": spec.runs,
                    "init_scale": spec.init_scale,
                    "seed_base": spec.seed_base,
                    "calibrated_kinds": list(per_kind_stats.keys()),
                },
            )

        message = f"null calibration: {len(per_kind_stats)} kinds calibrated over {spec.runs} seeds"
        if skipped_kinds:
            message += f" ({len(skipped_kinds)} opted out)"

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=Verdict.PASS,
            score=1.0,
            evidence=evidence,
            message=message,
        )

    @staticmethod
    def _select_target_kinds(
        spec: NullAdapterSpec, ctx: RunContext, registered: dict[str, Any]
    ) -> list[str]:
        """Resolve the kinds to calibrate.

        An explicit ``spec.calibrate_kinds`` wins; otherwise auto-populate
        from ``ctx.downstream_kinds``. De-dupes while preserving order and
        drops this probe's own kind and anything not in the registry.
        """
        candidates = list(spec.calibrate_kinds) or [
            k for k in ctx.downstream_kinds if k and k != spec.kind
        ]
        seen: set[str] = set()
        result: list[str] = []
        for k in candidates:
            if k == spec.kind or k in seen or k not in registered:
                continue
            seen.add(k)
            result.append(k)
        return result

    @staticmethod
    def _cache_key(
        spec: NullAdapterSpec, ctx: RunContext, target_kinds: list[str]
    ) -> str | None:
        """Build the calibration cache key, or ``None`` when caching is off.

        Fix: previously a key was computed even when the backend provided
        no ``cache_identity()``, contradicting :func:`_backend_identity`'s
        documented "caching is skipped" contract and letting distinct
        anonymous backends collide on one key. Now an anonymous backend
        disables caching outright. The key incorporates the backend
        identity plus every calibration parameter that influences output.
        """
        if not spec.cache:
            return None
        backend_identity = _backend_identity(ctx.backend)
        if backend_identity is None:
            return None
        return compute_key(
            backend_identity=backend_identity,
            params={
                "runs": spec.runs,
                "init_scale": spec.init_scale,
                "seed_base": spec.seed_base,
                "top_k": ctx.top_k,
                "kinds": sorted(target_kinds),
            },
        )

    @staticmethod
    def _result_from_cache(cache_key: str, spec: NullAdapterSpec) -> ProbeResult | None:
        """Return a PASS result rebuilt from cached stats, or ``None`` on miss."""
        cached = load(cache_key)
        if cached is None or "null_stats" not in cached:
            return None
        evidence: dict[str, Any] = dict(cached)
        # Older cache entries may predate these keys; default them so the
        # report shape matches a fresh run.
        evidence.setdefault("skipped_kinds", [])
        evidence.setdefault("calibrated_kinds", list(cached["null_stats"].keys()))
        evidence["weight"] = spec.weight
        evidence["from_cache"] = True
        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=Verdict.PASS,
            score=1.0,
            evidence=evidence,
            message=(
                f"null calibration: {len(cached['null_stats'])} kinds (loaded from cache)"
            ),
        )

    @staticmethod
    def _collect_null_raws(
        probe: Probe, cal_spec: Any, spec: NullAdapterSpec, ctx: RunContext
    ) -> tuple[list[float], list[str]]:
        """Run ``probe`` once per seed against a null-adapter proxy.

        Returns ``(finite_raws, errors)`` — a seed that raises or whose
        result is non-finite contributes nothing to the first list; ERROR
        verdicts and exceptions are described in the second.
        """
        raws: list[float] = []
        errors: list[str] = []
        for run_idx in range(spec.runs):
            seed = spec.seed_base + run_idx
            proxy = NullCalibrationBackendProxy(
                ctx.backend, seed=seed, init_scale=spec.init_scale
            )
            cal_ctx = RunContext(
                backend=proxy,
                seed=seed,
                top_k=ctx.top_k,
                sections=ctx.sections,
                doc_text=ctx.doc_text,
                null_stats={},  # calibration uses fixed thresholds — no recursion
                downstream_kinds=(),
            )
            try:
                cal_result = probe.run(cal_spec, cal_ctx)
            except Exception as exc:  # noqa: BLE001
                errors.append(f"seed={seed}: {type(exc).__name__}: {exc}")
                continue
            raw = cal_result.raw
            if raw is not None and math.isfinite(raw):
                raws.append(float(raw))
            elif cal_result.verdict == Verdict.ERROR:
                errors.append(f"seed={seed}: probe ERROR — {cal_result.message}")
        return raws, errors

    @classmethod
    def _aggregate(cls, raws: list[float]) -> dict[str, float]:
        """Aggregate raw samples to ``{"mean", "std", "n"}``.

        Uses the population stdev (0.0 for a single sample) and floors it
        at ``_STD_FLOOR`` so downstream z-scores never divide by zero.
        """
        std = statistics.pstdev(raws) if len(raws) > 1 else 0.0
        return {
            "mean": statistics.fmean(raws),
            "std": max(std, cls._STD_FLOOR),
            "n": float(len(raws)),
        }
261
262
def _backend_identity(backend: Any) -> str | None:
    """Return the backend's stable cache-identity string, when available.

    Duck-typed: a backend advertises its identity via an optional
    ``cache_identity()`` method. Backends that can't uniquely identify
    themselves (the dummy backend in tests, for example) simply don't
    provide it — they get ``None``, and caching is skipped for them.
    """
    identity_fn = getattr(backend, "cache_identity", None)
    if not callable(identity_fn):
        return None
    try:
        raw = identity_fn()
    except Exception:  # noqa: BLE001 — cache is best-effort
        return None
    if not raw:
        return None
    return str(raw)
278
279
def get_null_stats(ctx: RunContext, probe_kind: str) -> dict[str, float] | None:
    """Fetch the calibrated null distribution for ``probe_kind``.

    When the null-adapter probe calibrated this kind, the result is
    ``{"mean": …, "std": …, "n": …}``. Otherwise ``None`` — which probes
    interpret as "fall back to the fixed threshold from your spec" and
    surface as ``(no calibration)`` in the report.
    """
    stats = ctx.null_stats
    return stats.get(probe_kind)