Python · 2854 bytes Raw Blame History
1 """Tests for the on-disk null-calibration cache."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from dlm_sway.probes._null_cache import compute_key, load, save
8
9
@pytest.fixture
def isolated_cache(tmp_path, monkeypatch):
    """Redirect the cache root into a per-test tmp dir.

    Points ``XDG_CACHE_HOME`` at ``tmp_path`` so cache reads and writes land
    in a directory unique to the test. Also clears ``SWAY_DISABLE_NULL_CACHE``
    in case it is inherited from the developer's shell or CI. While that
    variable is set, ``load`` returns ``None``, which would break the
    round-trip tests. Tests that exercise the disable switch set it
    explicitly.

    Returns:
        The per-test ``tmp_path``, now acting as the cache root.
    """
    monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))
    # raising=False: the variable is normally absent; only clear it if inherited.
    monkeypatch.delenv("SWAY_DISABLE_NULL_CACHE", raising=False)
    return tmp_path
15
16
class TestComputeKey:
    """``compute_key`` yields ``None`` without an identity; otherwise the key is
    reproducible and changes when either the identity or the params change."""

    def test_none_identity_returns_none(self) -> None:
        assert compute_key(backend_identity=None, params={"runs": 3}) is None

    def test_empty_identity_returns_none(self) -> None:
        assert compute_key(backend_identity="", params={"runs": 3}) is None

    def test_stable_across_calls(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        # Without this, a compute_key that always returned None would pass
        # the equality check below (None == None).
        assert k1 is not None
        assert k1 == k2

    def test_changes_when_params_change(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 5})
        assert k1 != k2

    def test_changes_when_identity_changes(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/b", params={"runs": 3})
        assert k1 != k2
38
39
class TestLoadSave:
    """``save``/``load`` behaviour against a cache root redirected into a tmp dir."""

    @staticmethod
    def _cache_file(root, key):
        """Return the file the cache is expected to use for *key* under *root*.

        Encodes the ``<XDG_CACHE_HOME>/dlm-sway/null-stats/<key>.json`` layout
        these tests assume.
        """
        return root / "dlm-sway" / "null-stats" / f"{key}.json"

    def test_save_then_load_roundtrip(self, isolated_cache) -> None:
        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        save(key, stats)
        loaded = load(key)
        assert loaded == stats

    def test_load_miss_returns_none(self, isolated_cache) -> None:
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        # load(None) is None on its own; make sure this exercises a real miss.
        assert key is not None
        assert load(key) is None

    def test_none_key_roundtrip_noop(self, isolated_cache) -> None:
        save(None, {"null_stats": {}})
        # "noop" means nothing reaches disk, not merely that load(None) is None.
        assert not any(isolated_cache.rglob("*.json"))
        assert load(None) is None

    def test_malformed_json_is_treated_as_miss(self, isolated_cache) -> None:
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        save(key, {"null_stats": {}})
        cache_file = self._cache_file(isolated_cache, key)
        # Corrupt the file save() actually wrote, rather than planting one at a
        # path the cache might never consult (which would pass as a plain miss).
        assert cache_file.is_file()
        cache_file.write_text("{ not json", encoding="utf-8")
        assert load(key) is None

    def test_env_disable_bypasses_both(self, isolated_cache, monkeypatch) -> None:
        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        cache_file = self._cache_file(isolated_cache, key)

        # save() is bypassed: nothing lands on disk while the cache is disabled.
        monkeypatch.setenv("SWAY_DISABLE_NULL_CACHE", "1")
        save(key, stats)
        assert not cache_file.exists()

        # Populate a real entry with the cache enabled...
        monkeypatch.delenv("SWAY_DISABLE_NULL_CACHE")
        save(key, stats)
        assert cache_file.is_file()
        assert load(key) == stats

        # ...then check that load() ignores it once the cache is disabled.
        monkeypatch.setenv("SWAY_DISABLE_NULL_CACHE", "1")
        assert load(key) is None