1 """Tests for :mod:`dlm_sway.backends._instrumentation` (Sprint 07).
2
3 Covers the three invariants the cache + trace + stats plumbing must
4 hold:
5
6 1. **LRU correctness**: hit / miss / eviction at capacity.
7 2. **View-id isolation**: the same ``(prompt, top_k)`` on a different
8 view_id is a miss — the Sprint 07 cache must *not* cross-pollute
9 base and ft results (that would silently poison every probe).
10 3. **Trace writer**: ``path=None`` is a zero-overhead no-op; with a
11 path, the file is JSONL-parseable and carries the expected fields.
12 """
13
14 from __future__ import annotations
15
16 import json
17 from pathlib import Path
18
19 import pytest
20
21 from dlm_sway.backends._instrumentation import (
22 BackendInstrumentation,
23 BackendStats,
24 ForwardCache,
25 TraceWriter,
26 )
27
28
class TestForwardCache:
    def test_hit_after_put(self) -> None:
        """A stored value is retrievable under the exact same key."""
        fc = ForwardCache(maxsize=4)
        key = ("next_token_dist", "base", "abc", 32)
        fc.put(key, "value")
        assert fc.get(key) == "value"

    def test_miss_returns_sentinel(self) -> None:
        """An absent key yields the module-private ``_MISS`` sentinel."""
        from dlm_sway.backends._instrumentation import _MISS

        fc = ForwardCache(maxsize=4)
        assert fc.get(("missing",)) is _MISS

    def test_lru_eviction(self) -> None:
        """Filling past capacity drops the least-recently-used entry."""
        from dlm_sway.backends._instrumentation import _MISS

        fc = ForwardCache(maxsize=2)
        for name, value in (("a", 1), ("b", 2), ("c", 3)):
            fc.put((name,), value)  # third put pushes out ("a",), the LRU
        assert fc.get(("a",)) is _MISS
        assert fc.get(("b",)) == 2
        assert fc.get(("c",)) == 3

    def test_get_promotes_to_mru(self) -> None:
        """A read refreshes recency, so the *other* entry gets evicted."""
        from dlm_sway.backends._instrumentation import _MISS

        fc = ForwardCache(maxsize=2)
        fc.put(("a",), 1)
        fc.put(("b",), 2)
        _ = fc.get(("a",))  # "a" becomes most-recently-used
        fc.put(("c",), 3)  # at capacity: "b" is now the LRU and goes
        assert fc.get(("a",)) == 1
        assert fc.get(("b",)) is _MISS

    def test_invalid_maxsize_rejected(self) -> None:
        """A non-positive capacity is rejected at construction time."""
        with pytest.raises(ValueError, match="maxsize must be positive"):
            ForwardCache(maxsize=0)

    def test_clear(self) -> None:
        """``clear()`` empties the cache entirely."""
        fc = ForwardCache(maxsize=4)
        fc.put(("a",), 1)
        fc.clear()
        assert len(fc) == 0
74
75
class TestBackendInstrumentationCached:
    def test_miss_then_hit(self) -> None:
        """The second identical lookup is served from cache, not recomputed."""
        inst = BackendInstrumentation()
        invocations = {"n": 0}

        def compute() -> str:
            invocations["n"] += 1
            return "computed"

        first = inst.cached("next_token_dist", "base", "prompt", 32, compute)
        second = inst.cached("next_token_dist", "base", "prompt", 32, compute)
        assert first == second == "computed"
        assert invocations["n"] == 1
        # One miss (first call) + one hit (second call), single forward pass.
        assert inst.stats.cache_hits == 1
        assert inst.stats.cache_misses == 1
        assert inst.stats.forward_passes == 1

    def test_view_id_isolates_cache_entries(self) -> None:
        """Base and ft views on the same prompt must *not* collide."""
        inst = BackendInstrumentation()
        invoked: list[str] = []

        def compute_base() -> str:
            invoked.append("base")
            return "base_value"

        def compute_ft() -> str:
            invoked.append("ft")
            return "ft_value"

        got_base = inst.cached("next_token_dist", "base", "p", 32, compute_base)
        got_ft = inst.cached("next_token_dist", "ft", "p", 32, compute_ft)
        assert got_base == "base_value"
        assert got_ft == "ft_value"
        # Both compute fns ran — neither view short-circuited the other.
        assert invoked == ["base", "ft"]

    def test_top_k_isolates_cache_entries(self) -> None:
        """top_k=8 vs top_k=32 are different cache entries."""
        inst = BackendInstrumentation()
        call_count = 0

        def compute() -> int:
            nonlocal call_count
            call_count += 1
            return call_count

        first = inst.cached("next_token_dist", "base", "p", 8, compute)
        second = inst.cached("next_token_dist", "base", "p", 32, compute)
        assert first != second
        assert call_count == 2

    def test_op_isolates_cache_entries(self) -> None:
        """Same prompt via logprob_of vs next_token_dist → distinct keys."""
        inst = BackendInstrumentation()
        order: list[str] = []

        def compute_logprob() -> float:
            order.append("lp")
            return -3.14

        def compute_dist() -> str:
            order.append("dist")
            return "d"

        inst.cached("logprob_of", "base", "p", 0, compute_logprob)
        inst.cached("next_token_dist", "base", "p", 0, compute_dist)
        assert order == ["lp", "dist"]
143
144
class TestBackendStats:
    def test_hit_rate_zero_when_empty(self) -> None:
        """Fresh stats report a 0.0 hit rate rather than dividing by zero."""
        assert BackendStats().hit_rate == 0.0

    def test_to_dict_shape(self) -> None:
        """All counters and the derived hit_rate survive the to_dict round-trip."""
        stats = BackendStats(
            cache_hits=3,
            cache_misses=7,
            forward_passes=7,
            scoring_wall_s=1.5,
        )
        payload = stats.to_dict()
        assert payload["cache_hits"] == 3
        assert payload["cache_misses"] == 7
        assert payload["forward_passes"] == 7
        assert payload["scoring_wall_s"] == pytest.approx(1.5)
        # 3 hits out of 10 total lookups.
        assert payload["hit_rate"] == pytest.approx(0.3)

    def test_avg_batch_size_zero_when_empty(self) -> None:
        """S23 — no batches fired yet → avg is 0, not a div-by-zero."""
        stats = BackendStats()
        assert stats.avg_batch_size == 0.0
        assert stats.to_dict()["avg_batch_size"] == 0.0

    def test_batch_counters_surface_in_to_dict(self) -> None:
        """S23 — batch counters round-trip through to_dict()."""
        payload = BackendStats(
            batches_sent=2, batched_prompts=12, max_batch_size=8
        ).to_dict()
        assert payload["batches_sent"] == 2
        assert payload["batched_prompts"] == 12
        assert payload["max_batch_size"] == 8
        # 12 prompts over 2 batches.
        assert payload["avg_batch_size"] == pytest.approx(6.0)
173
174
class TestBackendInstrumentationCachedBatch:
    """S23 — cached_batch routing + counter bookkeeping."""

    def test_all_misses_fire_one_batch(self) -> None:
        """A cold cache routes every prompt through a single batch call."""
        inst = BackendInstrumentation()
        recorded: list[list[int]] = []

        def compute(miss_indices: list[int]) -> list[str]:
            recorded.append(list(miss_indices))
            return [f"v{i}" for i in miss_indices]

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2", "p3"], 32, compute
        )
        assert result == ["v0", "v1", "v2"]
        # Exactly one forward call, covering all three positions.
        assert recorded == [[0, 1, 2]]
        assert inst.stats.batches_sent == 1
        assert inst.stats.batched_prompts == 3
        assert inst.stats.max_batch_size == 3
        assert inst.stats.avg_batch_size == pytest.approx(3.0)
        assert inst.stats.cache_misses == 3
        assert inst.stats.cache_hits == 0
        assert inst.stats.forward_passes == 3

    def test_partial_cache_hit_skips_cached_from_batch(self) -> None:
        """Cache-per-prompt: hits skip the batch; only misses enter compute."""
        inst = BackendInstrumentation()

        # Pre-seed p1 so the batch below sees one hit and two misses.
        inst.cached("next_token_dist", "base", "p1", 32, lambda: "cached_v1")

        recorded: list[list[int]] = []

        def compute(miss_indices: list[int]) -> list[str]:
            recorded.append(list(miss_indices))
            # Values are produced only for the miss positions.
            return [f"fresh_{i}" for i in miss_indices]

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2", "p3"], 32, compute
        )
        # p1 came from the cache; p2 and p3 were freshly computed.
        assert result == ["cached_v1", "fresh_1", "fresh_2"]
        assert recorded == [[1, 2]]
        assert inst.stats.batches_sent == 1
        assert inst.stats.batched_prompts == 2  # only the miss count
        # The warmup was a miss; the batch hit p1 once and missed p2/p3.
        assert inst.stats.cache_hits == 1
        assert inst.stats.cache_misses == 3  # warmup + 2 batch misses

    def test_all_cached_skips_forward(self) -> None:
        """No misses → compute is never called, batches_sent stays 0."""
        inst = BackendInstrumentation()
        for prompt in ("p1", "p2"):
            inst.cached(
                "next_token_dist", "base", prompt, 32, lambda p=prompt: f"v_{p}"
            )
        # Zero the batch counters so only cached_batch's effect is visible.
        inst.stats.batches_sent = 0
        inst.stats.batched_prompts = 0
        inst.stats.max_batch_size = 0

        def compute(_idx: list[int]) -> list[str]:
            raise AssertionError("compute should not have been called")

        result = inst.cached_batch(
            "next_token_dist", "base", ["p1", "p2"], 32, compute
        )
        assert result == ["v_p1", "v_p2"]
        assert inst.stats.batches_sent == 0
        assert inst.stats.batched_prompts == 0

    def test_max_batch_size_tracks_largest(self) -> None:
        """max_batch_size is a high-water mark, not the latest batch size."""
        inst = BackendInstrumentation()

        def passthrough(idx: list[int]) -> list[int]:
            return list(idx)

        inst.cached_batch("next_token_dist", "base", ["a", "b", "c"], 32, passthrough)
        inst.cached_batch("next_token_dist", "base", ["d", "e"], 32, passthrough)
        assert inst.stats.max_batch_size == 3

    def test_wrong_return_length_raises(self) -> None:
        """A compute that under-delivers values is a backend bug, loudly."""
        inst = BackendInstrumentation()

        def bad(idx: list[int]) -> list[int]:
            return [0]  # wrong length

        with pytest.raises(RuntimeError, match="backend bug"):
            inst.cached_batch("next_token_dist", "base", ["p1", "p2", "p3"], 32, bad)

    def test_empty_prompts_returns_empty(self) -> None:
        """Sanity: an empty prompt list doesn't fire a batch."""
        inst = BackendInstrumentation()

        def compute(_idx: list[int]) -> list[int]:
            raise AssertionError("compute should not have been called")

        result = inst.cached_batch("next_token_dist", "base", [], 32, compute)
        assert result == []
        assert inst.stats.batches_sent == 0
268
269
class TestTraceWriter:
    def test_disabled_is_noop(self, tmp_path: Path) -> None:
        """``path=None`` creates no file and writes nothing."""
        from dlm_sway.backends._instrumentation import _TraceEvent

        writer = TraceWriter(None)
        event = _TraceEvent(
            ts=0.0,
            probe="p",
            view_id="base",
            prompt_hash="x",
            top_k=0,
            op="o",
            wall_ms=1.0,
            hit=True,
        )
        writer.write(event)
        writer.close()
        # Beyond not crashing, confirm nothing landed in tmp_path.
        assert not any(tmp_path.iterdir())

    def test_enabled_writes_jsonl(self, tmp_path: Path) -> None:
        """With a path, each event becomes one parseable JSONL line."""
        from dlm_sway.backends._instrumentation import _TraceEvent

        trace_file = tmp_path / "trace.jsonl"
        writer = TraceWriter(trace_file)
        miss_event = _TraceEvent(
            ts=1.23,
            probe="dk",
            view_id="base",
            prompt_hash="abc",
            top_k=32,
            op="next_token_dist",
            wall_ms=0.75,
            hit=False,
        )
        hit_event = _TraceEvent(
            ts=1.24,
            probe="dk",
            view_id="base",
            prompt_hash="abc",
            top_k=32,
            op="next_token_dist",
            wall_ms=0.01,
            hit=True,
        )
        writer.write(miss_event)
        writer.write(hit_event)
        writer.close()

        raw_lines = trace_file.read_text(encoding="utf-8").strip().splitlines()
        assert len(raw_lines) == 2
        first = json.loads(raw_lines[0])
        assert first["probe"] == "dk"
        assert first["view_id"] == "base"
        assert first["op"] == "next_token_dist"
        assert first["hit"] is False
        assert json.loads(raw_lines[1])["hit"] is True

    def test_instrumentation_end_to_end_trace(self, tmp_path: Path) -> None:
        """Full path: BackendInstrumentation → cached() → trace file."""
        trace_file = tmp_path / "trace.jsonl"
        inst = BackendInstrumentation()
        inst.trace = TraceWriter(trace_file)
        inst.set_current_probe("my_probe")

        # Identical calls: the first misses, the second hits the cache.
        for _ in range(2):
            inst.cached("next_token_dist", "base", "the capital", 16, lambda: "computed")
        inst.close()

        events = [
            json.loads(line)
            for line in trace_file.read_text(encoding="utf-8").strip().splitlines()
        ]
        assert len(events) == 2
        assert events[0]["probe"] == "my_probe"
        assert events[0]["hit"] is False
        assert events[1]["hit"] is True
349 assert events[1]["hit"] is True