tenseleyflow/sway / c666cbf

Browse files

sway(probes): real NullAdapterProbe calibration — 3 seeds, JS divergence

Authored by espadonne
SHA
c666cbfdb1109527c06897f5f52708ece6aba816
Parents
93c3098
Tree
8fcd8b2

3 changed files

Status | File | + | -
M src/dlm_sway/probes/null_adapter.py 92 36
A tests/unit/test_null_calibration.py 123 0
M tests/unit/test_suite_runner.py 6 5
src/dlm_sway/probes/null_adapter.py (modified)
@@ -2,56 +2,73 @@
2
 
2
 
3
 Every numeric primitive reports its raw metric *and* a z-score against a
3
 Every numeric primitive reports its raw metric *and* a z-score against a
4
 null-adapter distribution. This probe is the runtime engine that
4
 null-adapter distribution. This probe is the runtime engine that
5
-establishes that distribution — running each configured primitive
5
+establishes that distribution — it builds random-init "null" adapters
6
-against a series of random-init-style "null" adapters (structurally
6
+(structurally identical to the real adapter but with weights drawn from
7
-identical to the real adapter but with weights indistinguishable from
7
+a Gaussian) and measures how much signal they produce.
8
-noise) and caching the resulting ``(mean, std, n)`` per primitive kind.
8
+
9
-
9
+The resulting ``(mean, std, n)`` per kind is attached to this probe's
10
-The heavy lifting — materializing random-init LoRAs on the loaded model
10
+``evidence["null_stats"]``. The runner picks it up and threads it into
11
-and running probes with them — lives in the HF backend (later
11
+:attr:`RunContext.null_stats`, where every downstream probe can read it
12
-milestone). For now this module ships the spec + the lookup API that
12
+and turn a raw metric into a z-score.
13
-probes will use to z-score their results once stats are populated.
13
+
14
+Backends that don't implement :class:`~dlm_sway.core.scoring.NullCalibratedBackend`
15
+cause this probe to :attr:`Verdict.SKIP` — downstream probes fall back
16
+to their fixed thresholds in that case.
14
 """
17
 """
15
 
18
 
16
 from __future__ import annotations
19
 from __future__ import annotations
17
 
20
 
21
+import statistics
18
 from typing import Literal
22
 from typing import Literal
19
 
23
 
20
 from pydantic import Field
24
 from pydantic import Field
21
 
25
 
22
 from dlm_sway.core.result import ProbeResult, Verdict
26
 from dlm_sway.core.result import ProbeResult, Verdict
27
+from dlm_sway.core.scoring import NullCalibratedBackend
28
+from dlm_sway.probes._divergence import divergence
23
 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
29
 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
24
 
30
 
25
 
31
 
26
 class NullAdapterSpec(ProbeSpec):
32
 class NullAdapterSpec(ProbeSpec):
27
     """Spec for ``kind: null_adapter``.
33
     """Spec for ``kind: null_adapter``.
28
 
34
 
29
-    This is a meta-probe: it doesn't test the adapter, it calibrates
35
+    Authors place this probe **first** in the suite so its output
30
-    *other* probes. Place it first in the suite so its output is in
36
+    populates :attr:`RunContext.null_stats` before subsequent probes
31
-    :attr:`~dlm_sway.probes.base.RunContext.null_stats` when later
37
+    consult it.
32
-    probes run.
33
     """
38
     """
34
 
39
 
35
     kind: Literal["null_adapter"] = "null_adapter"
40
     kind: Literal["null_adapter"] = "null_adapter"
36
     runs: int = Field(default=3, ge=1, le=10)
41
     runs: int = Field(default=3, ge=1, le=10)
37
     """Number of independent null adapters to evaluate. Three is the
42
     """Number of independent null adapters to evaluate. Three is the
38
-    smallest that gives a usable std estimate; more is better but quickly
43
+    smallest that yields a usable std; more is better but quickly
39
     dominates suite runtime."""
44
     dominates suite runtime."""
40
-    rank: int | None = None
45
+    prompts: list[str] = Field(default_factory=list)
41
-    """LoRA rank for the null adapter. ``None`` → match the real adapter."""
46
+    """Prompt set for null calibration. Keep small — calibration runs
42
-    alpha: int | None = None
47
+    ``runs × len(prompts)`` forward passes. 4–8 prompts is typical.
43
-    """LoRA alpha. ``None`` → match the real adapter."""
48
+    If empty, a minimal built-in prompt set is used so the probe
49
+    always produces stats."""
44
     init_scale: float = 0.02
50
     init_scale: float = 0.02
45
-    """Standard deviation of the zero-mean Gaussian used to init
51
+    """Stddev of the zero-mean Gaussian used to fill lora_A/lora_B."""
46
-    lora_A/lora_B. Matches typical post-init scale."""
52
+    seed_base: int = 1000
53
+    """First seed; successive runs use ``seed_base + run_idx``."""
54
+
55
+
56
+_DEFAULT_PROMPTS: tuple[str, ...] = (
57
+    "The quick brown fox",
58
+    "Once upon a time",
59
+    "In this document we explain",
60
+    "The key takeaway is",
61
+    "An important point to remember",
62
+)
47
 
63
 
48
 
64
 
49
 class NullAdapterProbe(Probe):
65
 class NullAdapterProbe(Probe):
50
-    """Populate ``ctx.null_stats``; report a :attr:`Verdict.SKIP` verdict itself.
66
+    """Populate ``ctx.null_stats``; report a :attr:`Verdict.PASS` verdict itself.
51
 
67
 
52
-    The probe never fails on its own terms — its *job* is calibration,
68
+    The probe never fails on its own terms — its *job* is calibration.
53
-    not judgment. Downstream probes consult
69
+    Downstream probes pick up :attr:`RunContext.null_stats` keyed by
54
-    :meth:`get_null_stats` to turn their raw metric into a z-score.
70
+    probe kind (``delta_kl``, ``adapter_ablation`` …) and use the
71
+    populated mean/std to z-score their own raw metrics.
55
     """
72
     """
56
 
73
 
57
     kind = "null_adapter"
74
     kind = "null_adapter"
@@ -59,22 +76,61 @@ class NullAdapterProbe(Probe):
59
     category = "baseline"
76
     category = "baseline"
60
 
77
 
61
     def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
78
     def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
62
-        # Concrete null-adapter materialization is backend-specific. For
63
-        # the HF backend it will build random-init LoRAs with matched
64
-        # rank/alpha. That path is wired in a later milestone; this probe
65
-        # currently reports SKIP so suite composition stays stable.
66
-        del ctx  # unused until HF-level materialization lands
67
         assert isinstance(spec, NullAdapterSpec)
79
         assert isinstance(spec, NullAdapterSpec)
80
+        if not isinstance(ctx.backend, NullCalibratedBackend):
81
+            return ProbeResult(
82
+                name=spec.name,
83
+                kind=spec.kind,
84
+                verdict=Verdict.SKIP,
85
+                score=None,
86
+                message=(
87
+                    "backend does not implement NullCalibratedBackend — "
88
+                    "numeric probes will fall back to fixed thresholds"
89
+                ),
90
+            )
91
+        prompts = list(spec.prompts) or list(_DEFAULT_PROMPTS)
92
+
93
+        per_seed_means: list[float] = []
94
+        for run_idx in range(spec.runs):
95
+            seed = spec.seed_base + run_idx
96
+            per_prompt: list[float] = []
97
+            for prompt in prompts:
98
+                with ctx.backend.as_base() as base_view:
99
+                    base_dist = base_view.next_token_dist(prompt, top_k=ctx.top_k)
100
+                with ctx.backend.as_null_adapter(seed, init_scale=spec.init_scale) as null_view:
101
+                    null_dist = null_view.next_token_dist(prompt, top_k=ctx.top_k)
102
+                per_prompt.append(divergence(base_dist, null_dist, kind="js"))
103
+            per_seed_means.append(statistics.fmean(per_prompt) if per_prompt else 0.0)
104
+
105
+        mean = statistics.fmean(per_seed_means)
106
+        std = statistics.pstdev(per_seed_means) if len(per_seed_means) > 1 else 0.0
107
+
108
+        # Publish per-kind stats. delta_kl is the primary kind; other
109
+        # divergence-based probes (adapter_ablation) share this scale.
110
+        null_stats = {
111
+            "delta_kl": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)},
112
+            "adapter_ablation": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)},
113
+        }
114
+
68
         return ProbeResult(
115
         return ProbeResult(
69
             name=spec.name,
116
             name=spec.name,
70
             kind=spec.kind,
117
             kind=spec.kind,
71
-            verdict=Verdict.SKIP,
118
+            verdict=Verdict.PASS,
72
-            score=None,
119
+            score=1.0,
120
+            raw=mean,
121
+            evidence={
122
+                "null_stats": null_stats,
123
+                "per_seed_mean_js": per_seed_means,
124
+                "init_scale": spec.init_scale,
125
+                "runs": spec.runs,
126
+                "num_prompts": len(prompts),
127
+                "weight": spec.weight,
128
+            },
73
             message=(
129
             message=(
74
-                "null-adapter calibration pending — downstream probes will fall back to "
130
+                f"null JS divergence μ={mean:.4f} ± {std:.4f} "
75
-                "fixed thresholds until the backend-level materialization lands"
131
+                f"(over {spec.runs} seeds × {len(prompts)} prompts) — "
132
+                f"downstream probes will z-score against this baseline"
76
             ),
133
             ),
77
-            evidence={"runs": spec.runs, "rank": spec.rank, "alpha": spec.alpha},
78
         )
134
         )
79
 
135
 
80
 
136
 
@@ -82,7 +138,7 @@ def get_null_stats(ctx: RunContext, probe_kind: str) -> dict[str, float] | None:
82
     """Look up null-adapter stats for ``probe_kind``.
138
     """Look up null-adapter stats for ``probe_kind``.
83
 
139
 
84
     Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for
140
     Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for
85
-    this kind, else ``None``. Probes should treat ``None`` as "fall back
141
+    this kind, else ``None``. Probes treat ``None`` as "fall back to the
86
-    to the fixed threshold from your spec."
142
+    fixed threshold from your spec."
87
     """
143
     """
88
     return ctx.null_stats.get(probe_kind)
144
     return ctx.null_stats.get(probe_kind)
tests/unit/test_null_calibration.py (added)
@@ -0,0 +1,123 @@
1
+"""Tests for null-adapter calibration.
2
+
3
+Covers: dummy backend ``as_null_adapter`` yields a plausibly noisy
4
+view; ``NullAdapterProbe`` populates ``ctx.null_stats`` in a way
5
+downstream probes pick up end-to-end; missing-capability SKIP path.
6
+"""
7
+
8
+from __future__ import annotations
9
+
10
+import numpy as np
11
+
12
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
13
+from dlm_sway.core.result import Verdict
14
+from dlm_sway.core.scoring import NullCalibratedBackend
15
+from dlm_sway.probes.base import RunContext, build_probe
16
+from dlm_sway.suite.runner import run as run_suite
17
+from dlm_sway.suite.spec import SwaySpec
18
+
19
+
20
+def _diverging_backend() -> DummyDifferentialBackend:
21
+    base = DummyResponses()
22
+    ft = DummyResponses()
23
+    return DummyDifferentialBackend(base=base, ft=ft)
24
+
25
+
26
+class TestProtocolConformance:
27
+    def test_dummy_is_null_calibrated(self) -> None:
28
+        assert isinstance(_diverging_backend(), NullCalibratedBackend)
29
+
30
+
31
+class TestAsNullAdapter:
32
+    def test_yields_perturbed_view(self) -> None:
33
+        backend = _diverging_backend()
34
+        with backend.as_base() as base:
35
+            base_dist = base.next_token_dist("hello")
36
+        with backend.as_null_adapter(seed=0) as null:
37
+            null_dist = null.next_token_dist("hello")
38
+        # Some perturbation, but bounded.
39
+        assert not np.allclose(base_dist.logprobs, null_dist.logprobs)
40
+
41
+    def test_different_seeds_yield_different_views(self) -> None:
42
+        backend = _diverging_backend()
43
+        with backend.as_null_adapter(seed=1) as v1:
44
+            d1 = v1.next_token_dist("hello")
45
+        with backend.as_null_adapter(seed=2) as v2:
46
+            d2 = v2.next_token_dist("hello")
47
+        assert not np.allclose(d1.logprobs, d2.logprobs)
48
+
49
+    def test_view_exclusion_enforced(self) -> None:
50
+        import pytest
51
+
52
+        backend = _diverging_backend()
53
+        with backend.as_null_adapter(seed=0), pytest.raises(RuntimeError):
54
+            with backend.as_base():
55
+                pass
56
+
57
+
58
+class TestProbe:
59
+    def test_populates_null_stats(self) -> None:
60
+        backend = _diverging_backend()
61
+        probe, spec = build_probe(
62
+            {
63
+                "name": "null",
64
+                "kind": "null_adapter",
65
+                "runs": 3,
66
+                "prompts": ["q1", "q2"],
67
+            }
68
+        )
69
+        ctx = RunContext(backend=backend)
70
+        result = probe.run(spec, ctx)
71
+        assert result.verdict == Verdict.PASS
72
+        stats = result.evidence["null_stats"]
73
+        assert "delta_kl" in stats
74
+        assert stats["delta_kl"]["n"] == 3.0
75
+        assert stats["delta_kl"]["std"] > 0.0  # seeded perturbations produce variance
76
+
77
+    def test_runner_threads_null_stats_to_subsequent_probes(self) -> None:
78
+        """End-to-end: null_adapter first → delta_kl picks up z-score path."""
79
+        backend = _diverging_backend()
80
+        raw_spec = SwaySpec.model_validate(
81
+            {
82
+                "version": 1,
83
+                "models": {"base": {"base": "b"}, "ft": {"base": "b", "adapter": "/tmp/a"}},
84
+                "suite": [
85
+                    {
86
+                        "name": "null",
87
+                        "kind": "null_adapter",
88
+                        "runs": 3,
89
+                        "prompts": ["p1", "p2"],
90
+                    },
91
+                    {
92
+                        "name": "dk",
93
+                        "kind": "delta_kl",
94
+                        "prompts": ["p1", "p2"],
95
+                        "assert_z_gte": -10.0,  # permissive so we pass regardless
96
+                    },
97
+                ],
98
+            }
99
+        )
100
+        result = run_suite(raw_spec, backend)
101
+        assert len(result.probes) == 2
102
+        null_result = result.probes[0]
103
+        dk_result = result.probes[1]
104
+        assert null_result.verdict == Verdict.PASS
105
+        # The delta_kl probe should have computed a z_score because null_stats was present.
106
+        assert dk_result.z_score is not None, (
107
+            "delta_kl should have z-scored against null baseline, got "
108
+            f"evidence={dk_result.evidence}, message={dk_result.message}"
109
+        )
110
+
111
+    def test_skip_when_backend_not_null_calibrated(self) -> None:
112
+        class _Bare:
113
+            def as_base(self):  # noqa: ANN202
114
+                raise NotImplementedError
115
+
116
+            def as_finetuned(self):  # noqa: ANN202
117
+                raise NotImplementedError
118
+
119
+        probe, spec = build_probe({"name": "null", "kind": "null_adapter"})
120
+        ctx = RunContext(backend=_Bare())  # type: ignore[arg-type]
121
+        result = probe.run(spec, ctx)
122
+        assert result.verdict == Verdict.SKIP
123
+        assert "NullCalibratedBackend" in result.message
tests/unit/test_suite_runner.py (modified)
@@ -122,12 +122,13 @@ class TestRunner:
122
         assert result.wall_seconds >= 0
122
         assert result.wall_seconds >= 0
123
         assert result.probes[0].duration_s >= 0
123
         assert result.probes[0].duration_s >= 0
124
 
124
 
125
-    def test_null_adapter_skipped_stable_suite_shape(
125
+    def test_null_adapter_passes_on_null_calibrated_backend(
126
         self, backend: DummyDifferentialBackend
126
         self, backend: DummyDifferentialBackend
127
     ) -> None:
127
     ) -> None:
128
-        spec = _spec({"name": "null", "kind": "null_adapter", "runs": 3})
128
+        # Dummy backend implements NullCalibratedBackend, so calibration runs.
129
+        spec = _spec({"name": "null", "kind": "null_adapter", "runs": 2, "prompts": ["q1"]})
129
         result = run(spec, backend)
130
         result = run(spec, backend)
130
-        # Until the HF-level implementation lands, null_adapter reports SKIP
131
-        # but must not crash the suite.
132
         assert result.probes[0].kind == "null_adapter"
131
         assert result.probes[0].kind == "null_adapter"
133
-        assert result.probes[0].verdict == Verdict.SKIP
132
+        assert result.probes[0].verdict == Verdict.PASS
133
+        # And the suite's null_stats bubbles up onto the result.
134
+        assert "delta_kl" in result.null_stats