| 1 | """Tests for the on-disk null-calibration cache.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import pytest |
| 6 | |
| 7 | from dlm_sway.probes._null_cache import compute_key, load, save |
| 8 | |
| 9 | |
@pytest.fixture
def isolated_cache(tmp_path, monkeypatch):
    """Redirect the cache root into a per-test tmp dir and force the cache on.

    ``XDG_CACHE_HOME`` is pointed at ``tmp_path`` so every test starts with an
    empty cache and never touches the developer's real cache directory.

    ``SWAY_DISABLE_NULL_CACHE`` is removed so a value exported in the caller's
    shell/CI environment cannot silently turn ``save``/``load`` into no-ops and
    make the round-trip tests fail. Tests that want the cache disabled set the
    variable explicitly.

    Returns:
        The per-test temporary directory used as the cache root.
    """
    monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path))
    # raising=False: absence of the variable is the normal case.
    monkeypatch.delenv("SWAY_DISABLE_NULL_CACHE", raising=False)
    return tmp_path
| 15 | |
| 16 | |
class TestComputeKey:
    """Behaviour of ``compute_key``: ``None`` without an identity, otherwise a
    deterministic key that varies with both the identity and the params."""

    def test_none_identity_returns_none(self) -> None:
        assert compute_key(backend_identity=None, params={"runs": 3}) is None

    def test_empty_identity_returns_none(self) -> None:
        assert compute_key(backend_identity="", params={"runs": 3}) is None

    def test_stable_across_calls(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        # Without this guard, a compute_key that always returned None would
        # pass, since None == None.
        assert k1 is not None
        assert k1 == k2

    def test_changes_when_params_change(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 5})
        assert k1 != k2

    def test_changes_when_identity_changes(self) -> None:
        k1 = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        k2 = compute_key(backend_identity="hf:foo:/tmp/b", params={"runs": 3})
        assert k1 != k2
| 38 | |
| 39 | |
class TestLoadSave:
    """Round-trip and failure-mode behaviour of ``save`` / ``load``.

    Every test uses the ``isolated_cache`` fixture so the cache root is an
    empty per-test directory.
    """

    def test_save_then_load_roundtrip(self, isolated_cache) -> None:
        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        save(key, stats)
        loaded = load(key)
        assert loaded == stats

    def test_load_miss_returns_none(self, isolated_cache) -> None:
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        # load(None) also returns None, so make sure we test a real key.
        assert key is not None
        assert load(key) is None

    def test_none_key_roundtrip_noop(self, isolated_cache) -> None:
        save(None, {"null_stats": {}})
        assert load(None) is None

    def test_malformed_json_is_treated_as_miss(self, isolated_cache) -> None:
        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        cache_file = isolated_cache / "dlm-sway" / "null-stats" / f"{key}.json"
        # Write a real entry first and confirm it lands where we expect. If
        # the on-disk layout ever changes, this fails loudly. Otherwise the
        # corrupted file would be ignored and the test would pass as a plain
        # cache miss.
        save(key, stats)
        assert cache_file.is_file()
        assert load(key) == stats
        cache_file.write_text("{ not json")
        assert load(key) is None

    def test_env_disable_bypasses_both(self, isolated_cache, monkeypatch) -> None:
        stats = {"null_stats": {"delta_kl": {"mean": 0.01, "std": 0.002, "n": 3}}}
        # Compute the key before disabling, so the test exercises save/load
        # with a real key rather than the trivial None-key path.
        key = compute_key(backend_identity="hf:foo:/tmp/a", params={"runs": 3})
        assert key is not None
        monkeypatch.setenv("SWAY_DISABLE_NULL_CACHE", "1")
        save(key, stats)
        # load is bypassed while disabled.
        assert load(key) is None
        # Re-enable: if save had not been bypassed too, the entry would now
        # be readable.
        monkeypatch.delenv("SWAY_DISABLE_NULL_CACHE")
        assert load(key) is None