tenseleyflow/sway / 905b890

Browse files

probes/null_adapter: on-disk cache keyed by backend identity + calibration params

Authored by espadonne
SHA
905b890fcf1791c6ffde046d951b17008e6a00b4
Parents
8e52af6
Tree
f287be3

5 changed files

StatusFile+-
M src/dlm_sway/backends/hf.py 9 0
A src/dlm_sway/probes/_null_cache.py 91 0
M src/dlm_sway/probes/null_adapter.py 76 10
A tests/unit/test_null_cache.py 70 0
M tests/unit/test_null_calibration.py 59 0
src/dlm_sway/backends/hf.pymodified
@@ -355,6 +355,15 @@ class HuggingFaceDifferentialBackend:
355355
     _PREFLIGHT_PROMPT = "hello"
356356
     _PREFLIGHT_TOP_K = 8
357357
 
358
+    def cache_identity(self) -> str:
359
+        """Stable string identifying this backend for on-disk caching.
360
+
361
+        The base model id + the adapter's resolved absolute path is
362
+        enough to key a null-calibration cache: swapping either
363
+        invalidates the previously-computed stats.
364
+        """
365
+        return f"hf:{self._spec.base}:{self._adapter_path}"
366
+
358367
     def preflight_finite_check(self) -> tuple[bool, str]:
359368
         """One forward pass per view; assert both produce finite logits.
360369
 
src/dlm_sway/probes/_null_cache.pyadded
@@ -0,0 +1,91 @@
1
+"""On-disk cache for null-adapter calibration stats.
2
+
3
+Null calibration runs a miniature version of every downstream numeric
4
+probe across N seeds before the suite proper. For a 10-probe suite at
5
+``runs=3`` that's ~120 forward passes; on an HF backend against a real
6
+model this can dominate wall time. Results are deterministic in the
7
+calibration inputs — so we cache them at
8
+``~/.dlm-sway/null-stats/<key>.json`` keyed by the tuple that actually
9
+influences the output.
10
+
11
+Scope here is intentionally minimal. Sprint 07 adds a shared
12
+forward-pass cache that cuts into a lower level; this module only
13
+amortizes the per-suite calibration pass.
14
+"""
15
+
16
+from __future__ import annotations
17
+
18
+import hashlib
19
+import json
20
+import os
21
+from pathlib import Path
22
+from typing import Any
23
+
24
+#: Environment knob — set to ``"1"`` to bypass load + save (development
25
+#: / CI tests that want to prove calibration actually runs).
26
+_ENV_DISABLE = "SWAY_DISABLE_NULL_CACHE"
27
+
28
+
29
+def _cache_root() -> Path:
30
+    """Root directory for cached null stats. Honors ``$XDG_CACHE_HOME``
31
+    when set; otherwise falls back to ``~/.dlm-sway/null-stats``."""
32
+    xdg = os.environ.get("XDG_CACHE_HOME")
33
+    if xdg:
34
+        return Path(xdg).expanduser() / "dlm-sway" / "null-stats"
35
+    return Path.home() / ".dlm-sway" / "null-stats"
36
+
37
+
38
+def compute_key(*, backend_identity: str | None, params: dict[str, Any]) -> str | None:
39
+    """Hash backend identity + calibration params into a stable filename.
40
+
41
+    Returns ``None`` when ``backend_identity`` is ``None`` — backends that
42
+    can't uniquely identify themselves (e.g., the dummy backend used in
43
+    tests) skip caching entirely.
44
+    """
45
+    if not backend_identity:
46
+        return None
47
+    payload = {
48
+        "backend": backend_identity,
49
+        "params": params,
50
+    }
51
+    blob = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
52
+    return hashlib.sha256(blob).hexdigest()[:32]
53
+
54
+
55
+def load(key: str | None) -> dict[str, Any] | None:
56
+    """Return the cached null-stats dict for ``key``, or ``None`` on miss.
57
+
58
+    Malformed / unreadable cache files are treated as a miss — we'd
59
+    rather recompute than crash the suite. A stale / schema-mismatched
60
+    cache can be wiped with ``rm -rf ~/.dlm-sway/null-stats``.
61
+    """
62
+    if key is None or os.environ.get(_ENV_DISABLE) == "1":
63
+        return None
64
+    path = _cache_root() / f"{key}.json"
65
+    if not path.exists():
66
+        return None
67
+    try:
68
+        with path.open("r", encoding="utf-8") as f:
69
+            data = json.load(f)
70
+    except (OSError, json.JSONDecodeError):
71
+        return None
72
+    if not isinstance(data, dict):
73
+        return None
74
+    return data
75
+
76
+
77
+def save(key: str | None, stats: dict[str, Any]) -> None:
78
+    """Persist ``stats`` under ``key``. Silently no-ops on I/O errors —
79
+    the cache is a speed-up, not a correctness contract."""
80
+    if key is None or os.environ.get(_ENV_DISABLE) == "1":
81
+        return
82
+    root = _cache_root()
83
+    try:
84
+        root.mkdir(parents=True, exist_ok=True)
85
+        path = root / f"{key}.json"
86
+        tmp = path.with_suffix(".json.tmp")
87
+        with tmp.open("w", encoding="utf-8") as f:
88
+            json.dump(stats, f, indent=2, sort_keys=True)
89
+        tmp.replace(path)
90
+    except OSError:
91
+        return
src/dlm_sway/probes/null_adapter.pymodified
@@ -42,6 +42,7 @@ from pydantic import Field
4242
 
4343
 from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
4444
 from dlm_sway.core.scoring import NullCalibratedBackend
45
+from dlm_sway.probes._null_cache import compute_key, load, save
4546
 from dlm_sway.probes._null_proxy import NullCalibrationBackendProxy
4647
 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext, registry
4748
 
@@ -67,6 +68,11 @@ class NullAdapterSpec(ProbeSpec):
6768
     ``ctx.downstream_kinds`` (the kinds that appear after this probe
6869
     in the suite). Set explicitly to force calibration of specific
6970
     kinds regardless of suite order."""
71
+    cache: bool = True
72
+    """Read / write the on-disk calibration cache under
73
+    ``~/.dlm-sway/null-stats``. Keyed by backend identity + calibration
74
+    params. Disable to force a fresh calibration (e.g. when you suspect
75
+    the cached stats are stale)."""
7076
 
7177
 
7278
 class NullAdapterProbe(Probe):
@@ -113,6 +119,41 @@ class NullAdapterProbe(Probe):
113119
             filtered.append(k)
114120
         target_kinds = filtered
115121
 
122
+        # Cache lookup: backends can opt in by providing a
123
+        # ``cache_identity()`` method returning a stable string. The
124
+        # key incorporates both that identity and the calibration
125
+        # parameters that actually influence the output.
126
+        cache_key: str | None = None
127
+        if spec.cache:
128
+            backend_identity = _backend_identity(ctx.backend)
129
+            cache_key = compute_key(
130
+                backend_identity=backend_identity,
131
+                params={
132
+                    "runs": spec.runs,
133
+                    "init_scale": spec.init_scale,
134
+                    "seed_base": spec.seed_base,
135
+                    "top_k": ctx.top_k,
136
+                    "kinds": sorted(target_kinds),
137
+                },
138
+            )
139
+            cached = load(cache_key)
140
+            if cached is not None and "null_stats" in cached:
141
+                cached_evidence: dict[str, Any] = dict(cached)
142
+                cached_evidence.setdefault("skipped_kinds", [])
143
+                cached_evidence.setdefault("calibrated_kinds", list(cached["null_stats"].keys()))
144
+                cached_evidence["weight"] = spec.weight
145
+                cached_evidence["from_cache"] = True
146
+                return safe_finalize(
147
+                    name=spec.name,
148
+                    kind=spec.kind,
149
+                    verdict=Verdict.PASS,
150
+                    score=1.0,
151
+                    evidence=cached_evidence,
152
+                    message=(
153
+                        f"null calibration: {len(cached['null_stats'])} kinds (loaded from cache)"
154
+                    ),
155
+                )
156
+
116157
         per_kind_stats: dict[str, dict[str, float]] = {}
117158
         per_kind_samples: dict[str, list[float]] = {}
118159
         skipped_kinds: list[dict[str, str]] = []
@@ -122,9 +163,7 @@ class NullAdapterProbe(Probe):
122163
             try:
123164
                 cal_spec = probe_cls.calibrate_spec(ctx)
124165
             except Exception as exc:  # noqa: BLE001 — defensive
125
-                skipped_kinds.append(
126
-                    {"kind": kind, "reason": f"calibrate_spec raised: {exc}"}
127
-                )
166
+                skipped_kinds.append({"kind": kind, "reason": f"calibrate_spec raised: {exc}"})
128167
                 continue
129168
             if cal_spec is None:
130169
                 skipped_kinds.append(
@@ -161,9 +200,7 @@ class NullAdapterProbe(Probe):
161200
                 if raw is not None and math.isfinite(raw):
162201
                     raws.append(float(raw))
163202
                 elif cal_result.verdict == Verdict.ERROR:
164
-                    errors.append(
165
-                        f"seed={seed}: probe ERROR — {cal_result.message}"
166
-                    )
203
+                    errors.append(f"seed={seed}: probe ERROR — {cal_result.message}")
167204
 
168205
             if raws:
169206
                 mean = statistics.fmean(raws)
@@ -192,12 +229,24 @@ class NullAdapterProbe(Probe):
192229
             "init_scale": spec.init_scale,
193230
             "seed_base": spec.seed_base,
194231
             "weight": spec.weight,
232
+            "from_cache": False,
195233
         }
196234
 
197
-        message = (
198
-            f"null calibration: {len(per_kind_stats)} kinds calibrated "
199
-            f"over {spec.runs} seeds"
200
-        )
235
+        if cache_key is not None:
236
+            # Persist the stats dict only — the samples list can be
237
+            # large, and downstream consumers only need the aggregates.
238
+            save(
239
+                cache_key,
240
+                {
241
+                    "null_stats": per_kind_stats,
242
+                    "runs": spec.runs,
243
+                    "init_scale": spec.init_scale,
244
+                    "seed_base": spec.seed_base,
245
+                    "calibrated_kinds": list(per_kind_stats.keys()),
246
+                },
247
+            )
248
+
249
+        message = f"null calibration: {len(per_kind_stats)} kinds calibrated over {spec.runs} seeds"
201250
         if skipped_kinds:
202251
             message += f" ({len(skipped_kinds)} opted out)"
203252
 
@@ -211,6 +260,23 @@ class NullAdapterProbe(Probe):
211260
         )
212261
 
213262
 
263
+def _backend_identity(backend: Any) -> str | None:
264
+    """Ask the backend for a stable cache identity string, if it has one.
265
+
266
+    Duck-typed: backends that can't uniquely identify themselves (the
267
+    dummy backend in tests, for example) simply don't provide this
268
+    method, and caching is skipped for them.
269
+    """
270
+    fn = getattr(backend, "cache_identity", None)
271
+    if not callable(fn):
272
+        return None
273
+    try:
274
+        value = fn()
275
+    except Exception:  # noqa: BLE001 — cache is best-effort
276
+        return None
277
+    return str(value) if value else None
278
+
279
+
214280
 def get_null_stats(ctx: RunContext, probe_kind: str) -> dict[str, float] | None:
215281
     """Look up null-adapter stats for ``probe_kind`` in the run context.
216282
 
tests/unit/test_null_cache.pyadded
@@ -0,0 +1,70 @@
1
+"""Tests for the on-disk null-calibration cache."""
2
+
3
+from __future__ import annotations
4
+
5
+import pytest
6
+
7
+from dlm_sway.probes._null_cache import compute_key, load, save
8
+
9
+
10
+@pytest.fixture
11
+def isolated_cache(tmp_path, monkeypatch):
12
+    """Redirect the cache root into a per-test tmp dir."""
13
+    monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))
14
+    return tmp_path
15
+
16
+
17
+class TestComputeKey:
18
+    def test_none_identity_returns_none(self) -> None:
19
+        assert compute_key(backend_identity=None, params={"runs": 3}) is None
20
+
21
+    def test_empty_identity_returns_none(self) -> None:
22
+        assert compute_key(backend_identity="", params={"runs": 3}) is None
23
+
24
+    def test_stable_across_calls(self) -> None:
25
+        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
26
+        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
27
+        assert k1 == k2
28
+
29
+    def test_changes_when_params_change(self) -> None:
30
+        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
31
+        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 5})
32
+        assert k1 != k2
33
+
34
+    def test_changes_when_identity_changes(self) -> None:
35
+        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
36
+        k2 = compute_key(backend_identity="hf:foo:/tmp/b", params={"runs": 3})
37
+        assert k1 != k2
38
+
39
+
40
+class TestLoadSave:
41
+    def test_save_then_load_roundtrip(self, isolated_cache) -> None:
42
+        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
43
+        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
44
+        assert key is not None
45
+        save(key, stats)
46
+        loaded = load(key)
47
+        assert loaded == stats
48
+
49
+    def test_load_miss_returns_none(self, isolated_cache) -> None:
50
+        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
51
+        assert load(key) is None
52
+
53
+    def test_none_key_roundtrip_noop(self, isolated_cache) -> None:
54
+        save(None, {"null_stats": {}})
55
+        assert load(None) is None
56
+
57
+    def test_malformed_json_is_treated_as_miss(self, isolated_cache, tmp_path) -> None:
58
+        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
59
+        assert key is not None
60
+        # Manually write malformed content at the expected path.
61
+        cache_root = tmp_path / "dlm-sway" / "null-stats"
62
+        cache_root.mkdir(parents=True)
63
+        (cache_root / f"{key}.json").write_text("{ not json")
64
+        assert load(key) is None
65
+
66
+    def test_env_disable_bypasses_both(self, isolated_cache, monkeypatch) -> None:
67
+        monkeypatch.setenv("SWAY_DISABLE_NULL_CACHE", "1")
68
+        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
69
+        save(key, {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}})
70
+        assert load(key) is None
tests/unit/test_null_calibration.pymodified
@@ -173,6 +173,65 @@ class TestProbe:
173173
             f"evidence={dk_result.evidence}, message={dk_result.message}"
174174
         )
175175
 
176
+    def test_cache_hit_short_circuits_calibration(self, tmp_path, monkeypatch) -> None:
177
+        """A cached stats blob is loaded without re-running any probes."""
178
+        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))
179
+
180
+        class _IdBackend(DummyDifferentialBackend):
181
+            def cache_identity(self) -> str:
182
+                return "test:id-backend"
183
+
184
+        backend = _IdBackend(base=DummyResponses(), ft=DummyResponses())
185
+
186
+        # First call: populates the cache.
187
+        probe, spec = build_probe(
188
+            {
189
+                "name": "null",
190
+                "kind": "null_adapter",
191
+                "runs": 2,
192
+                "calibrate_kinds": ["delta_kl"],
193
+            }
194
+        )
195
+        ctx = RunContext(backend=backend)
196
+        r1 = probe.run(spec, ctx)
197
+        assert r1.evidence["from_cache"] is False
198
+
199
+        # Second call: same params, same identity → cache hit.
200
+        r2 = probe.run(spec, ctx)
201
+        assert r2.evidence["from_cache"] is True
202
+        assert "delta_kl" in r2.evidence["null_stats"]
203
+
204
+    def test_cache_disabled_forces_recompute(self, tmp_path, monkeypatch) -> None:
205
+        """``cache=false`` bypasses the cache even if a prior run populated it."""
206
+        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))
207
+
208
+        class _IdBackend(DummyDifferentialBackend):
209
+            def cache_identity(self) -> str:
210
+                return "test:id-backend-2"
211
+
212
+        backend = _IdBackend(base=DummyResponses(), ft=DummyResponses())
213
+        probe, populating_spec = build_probe(
214
+            {
215
+                "name": "null",
216
+                "kind": "null_adapter",
217
+                "runs": 2,
218
+                "calibrate_kinds": ["delta_kl"],
219
+            }
220
+        )
221
+        probe.run(populating_spec, RunContext(backend=backend))
222
+
223
+        _, fresh_spec = build_probe(
224
+            {
225
+                "name": "null",
226
+                "kind": "null_adapter",
227
+                "runs": 2,
228
+                "calibrate_kinds": ["delta_kl"],
229
+                "cache": False,
230
+            }
231
+        )
232
+        r = probe.run(fresh_spec, RunContext(backend=backend))
233
+        assert r.evidence["from_cache"] is False
234
+
176235
     def test_skip_when_backend_not_null_calibrated(self) -> None:
177236
         class _Bare:
178237
             def as_base(self):  # noqa: ANN202