| 1 | """Tests for :mod:`dlm_sway.backends._instrumentation` (Sprint 07). |
| 2 | |
| 3 | Covers the three invariants the cache + trace + stats plumbing must |
| 4 | hold: |
| 5 | |
| 6 | 1. **LRU correctness**: hit / miss / eviction at capacity. |
| 7 | 2. **View-id isolation**: the same ``(prompt, top_k)`` on a different |
| 8 | view_id is a miss — the Sprint 07 cache must *not* cross-pollute |
| 9 | base and ft results (that would silently poison every probe). |
| 10 | 3. **Trace writer**: ``path=None`` is a zero-overhead no-op; with a |
| 11 | path, the file is JSONL-parseable and carries the expected fields. |
| 12 | """ |
| 13 | |
| 14 | from __future__ import annotations |
| 15 | |
| 16 | import json |
| 17 | from pathlib import Path |
| 18 | |
| 19 | import pytest |
| 20 | |
| 21 | from dlm_sway.backends._instrumentation import ( |
| 22 | BackendInstrumentation, |
| 23 | BackendStats, |
| 24 | ForwardCache, |
| 25 | TraceWriter, |
| 26 | ) |
| 27 | |
| 28 | |
class TestForwardCache:
    def test_hit_after_put(self) -> None:
        """A freshly stored entry is retrievable by its exact key."""
        lru = ForwardCache(maxsize=4)
        key = ("next_token_dist", "base", "abc", 32)
        lru.put(key, "value")
        assert lru.get(key) == "value"

    def test_miss_returns_sentinel(self) -> None:
        """Looking up an absent key yields the private _MISS sentinel."""
        from dlm_sway.backends._instrumentation import _MISS

        assert ForwardCache(maxsize=4).get(("missing",)) is _MISS

    def test_lru_eviction(self) -> None:
        """Inserting past capacity drops the least-recently-used entry."""
        from dlm_sway.backends._instrumentation import _MISS

        lru = ForwardCache(maxsize=2)
        for key, value in ((("a",), 1), (("b",), 2), (("c",), 3)):
            lru.put(key, value)  # the third put evicts ("a",), the LRU entry
        assert lru.get(("a",)) is _MISS
        assert lru.get(("b",)) == 2
        assert lru.get(("c",)) == 3

    def test_get_promotes_to_mru(self) -> None:
        """Re-accessing an entry bumps it to the front of the LRU."""
        from dlm_sway.backends._instrumentation import _MISS

        lru = ForwardCache(maxsize=2)
        lru.put(("a",), 1)
        lru.put(("b",), 2)
        _ = lru.get(("a",))  # "a" becomes most-recently-used
        lru.put(("c",), 3)  # so "b", not "a", is the eviction victim
        assert lru.get(("a",)) == 1
        assert lru.get(("b",)) is _MISS

    def test_invalid_maxsize_rejected(self) -> None:
        """A non-positive capacity is rejected at construction time."""
        with pytest.raises(ValueError, match="maxsize must be positive"):
            ForwardCache(maxsize=0)

    def test_clear(self) -> None:
        """clear() empties the cache entirely."""
        lru = ForwardCache(maxsize=4)
        lru.put(("a",), 1)
        lru.clear()
        assert len(lru) == 0
| 74 | |
| 75 | |
class TestBackendInstrumentationCached:
    def test_miss_then_hit(self) -> None:
        """First call computes; the identical repeat is a pure cache hit."""
        inst = BackendInstrumentation()
        invocations = {"n": 0}

        def compute() -> str:
            invocations["n"] += 1
            return "computed"

        first = inst.cached("next_token_dist", "base", "prompt", 32, compute)
        repeat = inst.cached("next_token_dist", "base", "prompt", 32, compute)
        assert first == "computed"
        assert repeat == "computed"
        assert invocations["n"] == 1
        assert inst.stats.cache_hits == 1
        assert inst.stats.cache_misses == 1
        assert inst.stats.forward_passes == 1

    def test_view_id_isolates_cache_entries(self) -> None:
        """Base and ft views on the same prompt must *not* collide."""
        inst = BackendInstrumentation()
        order: list[str] = []

        def make_compute(tag: str, value: str):
            # Record which side actually ran, then return its value.
            def compute() -> str:
                order.append(tag)
                return value

            return compute

        base_val = inst.cached(
            "next_token_dist", "base", "p", 32, make_compute("base", "base_value")
        )
        ft_val = inst.cached(
            "next_token_dist", "ft", "p", 32, make_compute("ft", "ft_value")
        )
        assert base_val == "base_value"
        assert ft_val == "ft_value"
        assert order == ["base", "ft"]  # neither side short-circuited

    def test_top_k_isolates_cache_entries(self) -> None:
        """top_k=8 vs top_k=32 are different cache entries."""
        inst = BackendInstrumentation()
        calls = 0

        def compute() -> int:
            nonlocal calls
            calls += 1
            return calls

        small = inst.cached("next_token_dist", "base", "p", 8, compute)
        large = inst.cached("next_token_dist", "base", "p", 32, compute)
        assert small != large
        assert calls == 2

    def test_op_isolates_cache_entries(self) -> None:
        """Same prompt via logprob_of vs next_token_dist → distinct keys."""
        inst = BackendInstrumentation()
        trail: list[str] = []

        def logprob_compute() -> float:
            trail.append("lp")
            return -3.14

        def dist_compute() -> str:
            trail.append("dist")
            return "d"

        inst.cached("logprob_of", "base", "p", 0, logprob_compute)
        inst.cached("next_token_dist", "base", "p", 0, dist_compute)
        assert trail == ["lp", "dist"]
| 143 | |
| 144 | |
class TestBackendStats:
    def test_hit_rate_zero_when_empty(self) -> None:
        """No traffic at all → hit_rate is 0.0, not a 0/0 error."""
        assert BackendStats().hit_rate == 0.0

    def test_to_dict_shape(self) -> None:
        """to_dict() surfaces the raw counters plus the derived hit_rate."""
        stats = BackendStats(
            cache_hits=3, cache_misses=7, forward_passes=7, scoring_wall_s=1.5
        )
        snapshot = stats.to_dict()
        for field, expected in (
            ("cache_hits", 3),
            ("cache_misses", 7),
            ("forward_passes", 7),
        ):
            assert snapshot[field] == expected
        assert snapshot["scoring_wall_s"] == pytest.approx(1.5)
        assert snapshot["hit_rate"] == pytest.approx(0.3)  # 3 hits / 10 total

    def test_avg_batch_size_zero_when_empty(self) -> None:
        """S23 — no batches fired yet → avg is 0, not a div-by-zero."""
        stats = BackendStats()
        assert stats.avg_batch_size == 0.0
        assert stats.to_dict()["avg_batch_size"] == 0.0

    def test_batch_counters_surface_in_to_dict(self) -> None:
        """S23 — batch counters round-trip through to_dict()."""
        snapshot = BackendStats(
            batches_sent=2, batched_prompts=12, max_batch_size=8
        ).to_dict()
        assert snapshot["batches_sent"] == 2
        assert snapshot["batched_prompts"] == 12
        assert snapshot["max_batch_size"] == 8
        assert snapshot["avg_batch_size"] == pytest.approx(6.0)  # 12 / 2
| 173 | |
| 174 | |
class TestBackendInstrumentationCachedBatch:
    """S23 — cached_batch routing + counter bookkeeping."""

    def test_all_misses_fire_one_batch(self) -> None:
        """A cold cache routes every prompt into exactly one forward batch."""
        inst = BackendInstrumentation()
        seen: list[list[int]] = []

        def forward(miss_indices: list[int]) -> list[str]:
            seen.append(list(miss_indices))
            return [f"v{i}" for i in miss_indices]

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2", "p3"], 32, forward
        )
        assert result == ["v0", "v1", "v2"]
        assert seen == [[0, 1, 2]]  # one forward call covering all three
        stats = inst.stats
        assert stats.batches_sent == 1
        assert stats.batched_prompts == 3
        assert stats.max_batch_size == 3
        assert stats.avg_batch_size == pytest.approx(3.0)
        assert stats.cache_misses == 3
        assert stats.cache_hits == 0
        assert stats.forward_passes == 3

    def test_partial_cache_hit_skips_cached_from_batch(self) -> None:
        """Cache-per-prompt: hits skip the batch; only misses enter compute."""
        inst = BackendInstrumentation()
        # Warm one entry so p1 is a guaranteed hit.
        inst.cached("next_token_dist", "base", "p1", 32, lambda: "cached_v1")

        seen: list[list[int]] = []

        def forward(miss_indices: list[int]) -> list[str]:
            seen.append(list(miss_indices))
            # Produces values only for the miss positions.
            return [f"fresh_{i}" for i in miss_indices]

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2", "p3"], 32, forward
        )
        # p1 comes from the warm cache; p2 and p3 are computed.
        assert result == ["cached_v1", "fresh_1", "fresh_2"]
        assert seen == [[1, 2]]
        assert inst.stats.batches_sent == 1
        assert inst.stats.batched_prompts == 2  # misses only
        # Warmup was a miss; the batch then hit p1 and missed p2/p3.
        assert inst.stats.cache_hits == 1
        assert inst.stats.cache_misses == 3  # warmup + 2 batch misses

    def test_all_cached_skips_forward(self) -> None:
        """No misses → compute is never called, batches_sent stays 0."""
        inst = BackendInstrumentation()
        for prompt in ("p1", "p2"):
            inst.cached(
                "next_token_dist", "base", prompt, 32, lambda p=prompt: f"v_{p}"
            )
        # Zero the batch counters so only cached_batch's effect is observed.
        inst.stats.batches_sent = 0
        inst.stats.batched_prompts = 0
        inst.stats.max_batch_size = 0

        def forward(_indices: list[int]) -> list[str]:
            raise AssertionError("compute should not have been called")

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2"], 32, forward
        )
        assert result == ["v_p1", "v_p2"]
        assert inst.stats.batches_sent == 0
        assert inst.stats.batched_prompts == 0

    def test_max_batch_size_tracks_largest(self) -> None:
        """max_batch_size records the biggest batch sent, not the latest."""
        inst = BackendInstrumentation()

        def identity(indices: list[int]) -> list[int]:
            return list(indices)

        inst.cached_batch("next_token_dist", "base", ["a", "b", "c"], 32, identity)
        inst.cached_batch("next_token_dist", "base", ["d", "e"], 32, identity)
        assert inst.stats.max_batch_size == 3

    def test_wrong_return_length_raises(self) -> None:
        """A compute fn returning too few values is flagged as a backend bug."""
        inst = BackendInstrumentation()

        def short(_indices: list[int]) -> list[int]:
            return [0]  # one value for three misses — wrong length

        with pytest.raises(RuntimeError, match="backend bug"):
            inst.cached_batch(
                "next_token_dist", "base", ["p1", "p2", "p3"], 32, short
            )

    def test_empty_prompts_returns_empty(self) -> None:
        """Sanity: an empty prompt list doesn't fire a batch."""
        inst = BackendInstrumentation()

        def forward(_indices: list[int]) -> list[int]:
            raise AssertionError("compute should not have been called")

        assert inst.cached_batch("next_token_dist", "base", [], 32, forward) == []
        assert inst.stats.batches_sent == 0
| 268 | |
| 269 | |
class TestTraceWriter:
    def test_disabled_is_noop(self, tmp_path: Path) -> None:
        """``path=None`` creates no file and writes nothing."""
        from dlm_sway.backends._instrumentation import _TraceEvent

        writer = TraceWriter(None)
        event = _TraceEvent(
            ts=0.0,
            probe="p",
            view_id="base",
            prompt_hash="x",
            top_k=0,
            op="o",
            wall_ms=1.0,
            hit=True,
        )
        writer.write(event)
        writer.close()
        # The observable contract is "no output": write/close must not crash,
        # and nothing may appear in the tmp directory.
        assert not any(tmp_path.iterdir())

    def test_enabled_writes_jsonl(self, tmp_path: Path) -> None:
        """With a real path the writer emits one parseable JSON object per line."""
        from dlm_sway.backends._instrumentation import _TraceEvent

        out = tmp_path / "trace.jsonl"
        writer = TraceWriter(out)
        shared = dict(
            probe="dk",
            view_id="base",
            prompt_hash="abc",
            top_k=32,
            op="next_token_dist",
        )
        writer.write(_TraceEvent(ts=1.23, wall_ms=0.75, hit=False, **shared))
        writer.write(_TraceEvent(ts=1.24, wall_ms=0.01, hit=True, **shared))
        writer.close()

        rows = [
            json.loads(line)
            for line in out.read_text(encoding="utf-8").strip().splitlines()
        ]
        assert len(rows) == 2
        miss_row, hit_row = rows
        assert miss_row["probe"] == "dk"
        assert miss_row["view_id"] == "base"
        assert miss_row["op"] == "next_token_dist"
        assert miss_row["hit"] is False
        assert hit_row["hit"] is True

    def test_instrumentation_end_to_end_trace(self, tmp_path: Path) -> None:
        """Full path: BackendInstrumentation → cached() → trace file."""
        out = tmp_path / "trace.jsonl"
        inst = BackendInstrumentation()
        inst.trace = TraceWriter(out)
        inst.set_current_probe("my_probe")

        # Identical call twice: the first misses, the second hits the cache.
        for _ in range(2):
            inst.cached(
                "next_token_dist", "base", "the capital", 16, lambda: "computed"
            )
        inst.close()

        events = [
            json.loads(line)
            for line in out.read_text(encoding="utf-8").strip().splitlines()
        ]
        assert len(events) == 2
        miss_event, hit_event = events
        assert miss_event["probe"] == "my_probe"
        assert miss_event["hit"] is False
        assert hit_event["hit"] is True