tenseleyflow/sway / 1e28540

Browse files

probes/null_adapter: per-kind calibration matrix (fixes P02, B2, C9)

Authored by espadonne
SHA
1e28540187e96e9aae50146a8d6faf2f8306d475
Parents
fc49f7e
Tree
640aab0

3 changed files

Status | File | + | -
M src/dlm_sway/probes/null_adapter.py 160 82
M tests/unit/test_null_calibration.py 61 2
M tests/unit/test_suite_runner.py 9 1
src/dlm_sway/probes/null_adapter.py — modified
@@ -1,40 +1,56 @@
1
-"""Null-adapter baseline probe.
2
-
3
-Every numeric primitive reports its raw metric *and* a z-score against a
4
-null-adapter distribution. This probe is the runtime engine that
5
-establishes that distribution — it builds random-init "null" adapters
6
-(structurally identical to the real adapter but with weights drawn from
7
-a Gaussian) and measures how much signal they produce.
8
-
9
-The resulting ``(mean, std, n)`` per kind is attached to this probe's
10
-``evidence["null_stats"]``. The runner picks it up and threads it into
11
-:attr:`RunContext.null_stats`, where every downstream probe can read it
12
-and turn a raw metric into a z-score.
13
-
14
-Backends that don't implement :class:`~dlm_sway.core.scoring.NullCalibratedBackend`
15
-cause this probe to :attr:`Verdict.SKIP` — downstream probes fall back
16
-to their fixed thresholds in that case.
1
+"""Null-adapter baseline probe — per-kind calibration matrix (S02).
2
+
3
+Every numeric primitive reports its raw metric *and* a z-score against
4
+a null-adapter distribution. This probe is the runtime engine that
5
+establishes that distribution — for **every** numeric probe kind the
6
+user has downstream in the suite, not just one.
7
+
8
+How it works:
9
+
10
+1. The runner populates ``ctx.downstream_kinds`` with every probe kind
11
+   that appears after this one in the suite.
12
+2. For each target kind, we ask its probe class for a
13
+   :meth:`~dlm_sway.probes.base.Probe.calibrate_spec` — a small spec
14
+   suitable for null calibration. A probe that returns ``None`` opts
15
+   out (typically because its inputs can't be synthesized, e.g.
16
+   ``adapter_revert`` without an embedder, or ``adapter_ablation``
17
+   which needs ``as_scaled_adapter`` that the proxy doesn't expose).
18
+3. For each calibrating kind × seed, we run the probe through a
19
+   :class:`~dlm_sway.probes._null_proxy.NullCalibrationBackendProxy`
20
+   which makes ``as_finetuned()`` yield ``as_null_adapter(seed)`` —
21
+   so the probe's own math is computing "what does my metric look
22
+   like when the fine-tune is structural noise?".
23
+4. We harvest each run's ``raw`` value, aggregate to ``(mean, std, n)``
24
+   per kind, and publish under ``evidence["null_stats"]``.
25
+5. The runner threads ``null_stats`` into ``RunContext`` for every
26
+   subsequent probe, which then prefers the z-score path over the
27
+   fixed-threshold path (see :mod:`dlm_sway.probes._zscore`).
28
+
29
+Backends that don't implement
30
+:class:`~dlm_sway.core.scoring.NullCalibratedBackend` cause this probe
31
+to ``Verdict.SKIP``; every downstream probe falls back to fixed
32
+thresholds and surfaces ``(no calibration)`` in the report.
1733
 """
1834
 
1935
 from __future__ import annotations
2036
 
37
+import math
2138
 import statistics
22
-from typing import Literal
39
+from typing import Any, Literal
2340
 
2441
 from pydantic import Field
2542
 
26
-from dlm_sway.core.result import ProbeResult, Verdict
43
+from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
2744
 from dlm_sway.core.scoring import NullCalibratedBackend
28
-from dlm_sway.probes._divergence import divergence
29
-from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
45
+from dlm_sway.probes._null_proxy import NullCalibrationBackendProxy
46
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext, registry
3047
 
3148
 
3249
 class NullAdapterSpec(ProbeSpec):
3350
     """Spec for ``kind: null_adapter``.
3451
 
35
-    Authors place this probe **first** in the suite so its output
36
-    populates :attr:`RunContext.null_stats` before subsequent probes
37
-    consult it.
52
+    Place this probe **first** in the suite so its output populates
53
+    :attr:`RunContext.null_stats` before subsequent probes consult it.
3854
     """
3955
 
4056
     kind: Literal["null_adapter"] = "null_adapter"
@@ -42,33 +58,24 @@ class NullAdapterSpec(ProbeSpec):
4258
     """Number of independent null adapters to evaluate. Three is the
4359
     smallest that yields a usable std; more is better but quickly
4460
     dominates suite runtime."""
45
-    prompts: list[str] = Field(default_factory=list)
46
-    """Prompt set for null calibration. Keep small — calibration runs
47
-    ``runs × len(prompts)`` forward passes. 4–8 prompts is typical.
48
-    If empty, a minimal built-in prompt set is used so the probe
49
-    always produces stats."""
5061
     init_scale: float = 0.02
5162
     """Stddev of the zero-mean Gaussian used to fill lora_A/lora_B."""
5263
     seed_base: int = 1000
5364
     """First seed; successive runs use ``seed_base + run_idx``."""
54
-
55
-
56
-_DEFAULT_PROMPTS: tuple[str, ...] = (
57
-    "The quick brown fox",
58
-    "Once upon a time",
59
-    "In this document we explain",
60
-    "The key takeaway is",
61
-    "An important point to remember",
62
-)
65
+    calibrate_kinds: list[str] = Field(default_factory=list)
66
+    """Which probe kinds to calibrate. Empty = auto-populate from
67
+    ``ctx.downstream_kinds`` (the kinds that appear after this probe
68
+    in the suite). Set explicitly to force calibration of specific
69
+    kinds regardless of suite order."""
6370
 
6471
 
6572
 class NullAdapterProbe(Probe):
66
-    """Populate ``ctx.null_stats``; report a :attr:`Verdict.PASS` verdict itself.
73
+    """Populate ``ctx.null_stats`` with per-kind null distributions.
6774
 
68
-    The probe never fails on its own terms — its *job* is calibration.
69
-    Downstream probes pick up :attr:`RunContext.null_stats` keyed by
70
-    probe kind (``delta_kl``, ``adapter_ablation`` …) and use the
71
-    populated mean/std to z-score their own raw metrics.
75
+    The probe itself reports ``Verdict.PASS`` on success — its job is
76
+    calibration, not judgment. If the backend can't support null-view
77
+    substitution, reports ``Verdict.SKIP`` with a clear message; every
78
+    downstream numeric probe then falls back to fixed thresholds.
7279
     """
7380
 
7481
     kind = "null_adapter"
@@ -88,57 +95,128 @@ class NullAdapterProbe(Probe):
8895
                     "numeric probes will fall back to fixed thresholds"
8996
                 ),
9097
             )
91
-        prompts = list(spec.prompts) or list(_DEFAULT_PROMPTS)
92
-
93
-        per_seed_means: list[float] = []
94
-        for run_idx in range(spec.runs):
95
-            seed = spec.seed_base + run_idx
96
-            per_prompt: list[float] = []
97
-            for prompt in prompts:
98
-                with ctx.backend.as_base() as base_view:
99
-                    base_dist = base_view.next_token_dist(prompt, top_k=ctx.top_k)
100
-                with ctx.backend.as_null_adapter(seed, init_scale=spec.init_scale) as null_view:
101
-                    null_dist = null_view.next_token_dist(prompt, top_k=ctx.top_k)
102
-                per_prompt.append(divergence(base_dist, null_dist, kind="js"))
103
-            per_seed_means.append(statistics.fmean(per_prompt) if per_prompt else 0.0)
104
-
105
-        mean = statistics.fmean(per_seed_means)
106
-        std = statistics.pstdev(per_seed_means) if len(per_seed_means) > 1 else 0.0
107
-
108
-        # Publish per-kind stats. delta_kl is the primary kind; other
109
-        # divergence-based probes (adapter_ablation) share this scale.
110
-        null_stats = {
111
-            "delta_kl": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)},
112
-            "adapter_ablation": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)},
98
+
99
+        registered = registry()
100
+
101
+        # Decide which kinds to calibrate. Explicit spec field wins;
102
+        # otherwise auto-populate from downstream_kinds.
103
+        target_kinds: list[str] = list(spec.calibrate_kinds)
104
+        if not target_kinds:
105
+            target_kinds = [k for k in ctx.downstream_kinds if k and k != spec.kind]
106
+        # De-dupe while preserving order; drop self and unregistered.
107
+        seen: set[str] = set()
108
+        filtered: list[str] = []
109
+        for k in target_kinds:
110
+            if k == spec.kind or k in seen or k not in registered:
111
+                continue
112
+            seen.add(k)
113
+            filtered.append(k)
114
+        target_kinds = filtered
115
+
116
+        per_kind_stats: dict[str, dict[str, float]] = {}
117
+        per_kind_samples: dict[str, list[float]] = {}
118
+        skipped_kinds: list[dict[str, str]] = []
119
+
120
+        for kind in target_kinds:
121
+            probe_cls = registered[kind]
122
+            try:
123
+                cal_spec = probe_cls.calibrate_spec(ctx)
124
+            except Exception as exc:  # noqa: BLE001 — defensive
125
+                skipped_kinds.append(
126
+                    {"kind": kind, "reason": f"calibrate_spec raised: {exc}"}
127
+                )
128
+                continue
129
+            if cal_spec is None:
130
+                skipped_kinds.append(
131
+                    {
132
+                        "kind": kind,
133
+                        "reason": "probe opted out (calibrate_spec returned None)",
134
+                    }
135
+                )
136
+                continue
137
+
138
+            probe = probe_cls()
139
+            raws: list[float] = []
140
+            errors: list[str] = []
141
+            for run_idx in range(spec.runs):
142
+                seed = spec.seed_base + run_idx
143
+                proxy = NullCalibrationBackendProxy(
144
+                    ctx.backend, seed=seed, init_scale=spec.init_scale
145
+                )
146
+                cal_ctx = RunContext(
147
+                    backend=proxy,
148
+                    seed=seed,
149
+                    top_k=ctx.top_k,
150
+                    sections=ctx.sections,
151
+                    doc_text=ctx.doc_text,
152
+                    null_stats={},  # calibration uses fixed thresholds — no recursion
153
+                    downstream_kinds=(),
154
+                )
155
+                try:
156
+                    cal_result = probe.run(cal_spec, cal_ctx)
157
+                except Exception as exc:  # noqa: BLE001
158
+                    errors.append(f"seed={seed}: {type(exc).__name__}: {exc}")
159
+                    continue
160
+                raw = cal_result.raw
161
+                if raw is not None and math.isfinite(raw):
162
+                    raws.append(float(raw))
163
+                elif cal_result.verdict == Verdict.ERROR:
164
+                    errors.append(
165
+                        f"seed={seed}: probe ERROR — {cal_result.message}"
166
+                    )
167
+
168
+            if raws:
169
+                mean = statistics.fmean(raws)
170
+                std = statistics.pstdev(raws) if len(raws) > 1 else 0.0
171
+                per_kind_stats[kind] = {
172
+                    "mean": mean,
173
+                    # C9: clamp the std floor so the downstream z-score
174
+                    # path doesn't blow up when every seed produces
175
+                    # identical raws.
176
+                    "std": max(std, 1e-6),
177
+                    "n": float(len(raws)),
178
+                }
179
+                per_kind_samples[kind] = raws
180
+            else:
181
+                reason = "no finite raws across all seeds"
182
+                if errors:
183
+                    reason += f" ({errors[0]})"
184
+                skipped_kinds.append({"kind": kind, "reason": reason})
185
+
186
+        evidence: dict[str, Any] = {
187
+            "null_stats": per_kind_stats,
188
+            "per_kind_raw_samples": per_kind_samples,
189
+            "skipped_kinds": skipped_kinds,
190
+            "calibrated_kinds": list(per_kind_stats.keys()),
191
+            "runs": spec.runs,
192
+            "init_scale": spec.init_scale,
193
+            "seed_base": spec.seed_base,
194
+            "weight": spec.weight,
113195
         }
114196
 
115
-        return ProbeResult(
197
+        message = (
198
+            f"null calibration: {len(per_kind_stats)} kinds calibrated "
199
+            f"over {spec.runs} seeds"
200
+        )
201
+        if skipped_kinds:
202
+            message += f" ({len(skipped_kinds)} opted out)"
203
+
204
+        return safe_finalize(
116205
             name=spec.name,
117206
             kind=spec.kind,
118207
             verdict=Verdict.PASS,
119208
             score=1.0,
120
-            raw=mean,
121
-            evidence={
122
-                "null_stats": null_stats,
123
-                "per_seed_mean_js": per_seed_means,
124
-                "init_scale": spec.init_scale,
125
-                "runs": spec.runs,
126
-                "num_prompts": len(prompts),
127
-                "weight": spec.weight,
128
-            },
129
-            message=(
130
-                f"null JS divergence μ={mean:.4f} ± {std:.4f} "
131
-                f"(over {spec.runs} seeds × {len(prompts)} prompts) — "
132
-                f"downstream probes will z-score against this baseline"
133
-            ),
209
+            evidence=evidence,
210
+            message=message,
134211
         )
135212
 
136213
 
137214
 def get_null_stats(ctx: RunContext, probe_kind: str) -> dict[str, float] | None:
138
-    """Look up null-adapter stats for ``probe_kind``.
215
+    """Look up null-adapter stats for ``probe_kind`` in the run context.
139216
 
140217
     Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for
141
-    this kind, else ``None``. Probes treat ``None`` as "fall back to the
142
-    fixed threshold from your spec."
218
+    this kind, else ``None``. Probes treat ``None`` as "fall back to
219
+    the fixed threshold from your spec" and surface ``(no calibration)``
220
+    in the report.
143221
     """
144222
     return ctx.null_stats.get(probe_kind)
tests/unit/test_null_calibration.py — modified
@@ -57,13 +57,14 @@ class TestAsNullAdapter:
5757
 
5858
 class TestProbe:
5959
     def test_populates_null_stats(self) -> None:
60
+        """Explicit `calibrate_kinds` calibrates regardless of suite order."""
6061
         backend = _diverging_backend()
6162
         probe, spec = build_probe(
6263
             {
6364
                 "name": "null",
6465
                 "kind": "null_adapter",
6566
                 "runs": 3,
66
-                "prompts": ["q1", "q2"],
67
+                "calibrate_kinds": ["delta_kl"],
6768
             }
6869
         )
6970
         ctx = RunContext(backend=backend)
@@ -74,6 +75,65 @@ class TestProbe:
7475
         assert stats["delta_kl"]["n"] == 3.0
7576
         assert stats["delta_kl"]["std"] > 0.0  # seeded perturbations produce variance
7677
 
78
+    def test_auto_populates_from_downstream_kinds(self) -> None:
79
+        """When `calibrate_kinds` is empty, falls back to `ctx.downstream_kinds`."""
80
+        backend = _diverging_backend()
81
+        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
82
+        ctx = RunContext(
83
+            backend=backend,
84
+            downstream_kinds=("delta_kl", "prompt_collapse"),
85
+        )
86
+        result = probe.run(spec, ctx)
87
+        assert result.verdict == Verdict.PASS
88
+        stats = result.evidence["null_stats"]
89
+        # Every downstream numeric kind that opts in gets stats.
90
+        assert "delta_kl" in stats
91
+        assert "prompt_collapse" in stats
92
+
93
+    def test_empty_calibrate_kinds_with_no_downstream_is_noop(self) -> None:
94
+        """No kinds, no calibration — probe still PASSes with empty stats."""
95
+        backend = _diverging_backend()
96
+        probe, spec = build_probe({"name": "null", "kind": "null_adapter", "runs": 2})
97
+        ctx = RunContext(backend=backend)  # no downstream_kinds
98
+        result = probe.run(spec, ctx)
99
+        assert result.verdict == Verdict.PASS
100
+        assert result.evidence["null_stats"] == {}
101
+        assert result.evidence["calibrated_kinds"] == []
102
+
103
+    def test_unregistered_kind_is_silently_skipped(self) -> None:
104
+        backend = _diverging_backend()
105
+        probe, spec = build_probe(
106
+            {
107
+                "name": "null",
108
+                "kind": "null_adapter",
109
+                "runs": 2,
110
+                "calibrate_kinds": ["delta_kl", "nonexistent_kind"],
111
+            }
112
+        )
113
+        ctx = RunContext(backend=backend)
114
+        result = probe.run(spec, ctx)
115
+        assert "delta_kl" in result.evidence["null_stats"]
116
+        assert "nonexistent_kind" not in result.evidence["null_stats"]
117
+
118
+    def test_opt_out_probe_is_reported_as_skipped(self) -> None:
119
+        """A kind whose calibrate_spec returns None surfaces in skipped_kinds."""
120
+        backend = _diverging_backend()
121
+        probe, spec = build_probe(
122
+            {
123
+                "name": "null",
124
+                "kind": "null_adapter",
125
+                "runs": 2,
126
+                # adapter_revert.calibrate_spec returns None by default
127
+                # (inherits from base), so we expect it to opt out.
128
+                "calibrate_kinds": ["adapter_revert", "delta_kl"],
129
+            }
130
+        )
131
+        ctx = RunContext(backend=backend)
132
+        result = probe.run(spec, ctx)
133
+        assert "delta_kl" in result.evidence["null_stats"]
134
+        skipped = [s["kind"] for s in result.evidence["skipped_kinds"]]
135
+        assert "adapter_revert" in skipped
136
+
77137
     def test_runner_threads_null_stats_to_subsequent_probes(self) -> None:
78138
         """End-to-end: null_adapter first → delta_kl picks up z-score path."""
79139
         backend = _diverging_backend()
@@ -86,7 +146,6 @@ class TestProbe:
86146
                         "name": "null",
87147
                         "kind": "null_adapter",
88148
                         "runs": 3,
89
-                        "prompts": ["p1", "p2"],
90149
                     },
91150
                     {
92151
                         "name": "dk",
tests/unit/test_suite_runner.py — modified
@@ -127,7 +127,15 @@ class TestRunner:
127127
         self, backend: DummyDifferentialBackend
128128
     ) -> None:
129129
         # Dummy backend implements NullCalibratedBackend, so calibration runs.
130
-        spec = _spec({"name": "null", "kind": "null_adapter", "runs": 2, "prompts": ["q1"]})
130
+        # Explicit calibrate_kinds so it runs even without downstream probes.
131
+        spec = _spec(
132
+            {
133
+                "name": "null",
134
+                "kind": "null_adapter",
135
+                "runs": 2,
136
+                "calibrate_kinds": ["delta_kl"],
137
+            }
138
+        )
131139
         result = run(spec, backend)
132140
         assert result.probes[0].kind == "null_adapter"
133141
         assert result.probes[0].verdict == Verdict.PASS