tenseleyflow/sway / 4aaa584

Browse files

tests/scoring+instrumentation: new FakeScoring.next_token_dist_batch + cached_batch coverage

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
4aaa584f0fea515c9eee47bbe0be669524cb3978
Parents
198cd55
Tree
9cb055b

2 changed files

Status | File | Additions | Deletions
M tests/unit/test_backend_instrumentation.py 110 0
M tests/unit/test_scoring.py 11 0
tests/unit/test_backend_instrumentation.py (modified)
@@ -156,6 +156,116 @@ class TestBackendStats:
156156
         assert d["scoring_wall_s"] == pytest.approx(1.5)
157157
         assert d["hit_rate"] == pytest.approx(0.3)
158158
 
159
+    def test_avg_batch_size_zero_when_empty(self) -> None:
160
+        """S23 — no batches fired yet → avg is 0, not a div-by-zero."""
161
+        s = BackendStats()
162
+        assert s.avg_batch_size == 0.0
163
+        assert s.to_dict()["avg_batch_size"] == 0.0
164
+
165
+    def test_batch_counters_surface_in_to_dict(self) -> None:
166
+        """S23 — batch counters round-trip through to_dict()."""
167
+        s = BackendStats(batches_sent=2, batched_prompts=12, max_batch_size=8)
168
+        d = s.to_dict()
169
+        assert d["batches_sent"] == 2
170
+        assert d["batched_prompts"] == 12
171
+        assert d["max_batch_size"] == 8
172
+        assert d["avg_batch_size"] == pytest.approx(6.0)
173
+
174
+
175
+class TestBackendInstrumentationCachedBatch:
176
+    """S23 — cached_batch routing + counter bookkeeping."""
177
+
178
+    def test_all_misses_fire_one_batch(self) -> None:
179
+        inst = BackendInstrumentation()
180
+        calls: list[list[int]] = []
181
+
182
+        def compute(miss_indices: list[int]) -> list[str]:
183
+            calls.append(list(miss_indices))
184
+            return [f"v{i}" for i in miss_indices]
185
+
186
+        out = inst.cached_batch("next_token_dist", "base", ["p1", "p2", "p3"], 32, compute)
187
+        assert out == ["v0", "v1", "v2"]
188
+        # One forward call covering all 3.
189
+        assert calls == [[0, 1, 2]]
190
+        assert inst.stats.batches_sent == 1
191
+        assert inst.stats.batched_prompts == 3
192
+        assert inst.stats.max_batch_size == 3
193
+        assert inst.stats.avg_batch_size == pytest.approx(3.0)
194
+        assert inst.stats.cache_misses == 3
195
+        assert inst.stats.cache_hits == 0
196
+        assert inst.stats.forward_passes == 3
197
+
198
+    def test_partial_cache_hit_skips_cached_from_batch(self) -> None:
199
+        """Cache-per-prompt: hits skip the batch; only misses enter compute."""
200
+        inst = BackendInstrumentation()
201
+
202
+        # Warm one entry.
203
+        inst.cached("next_token_dist", "base", "p1", 32, lambda: "cached_v1")
204
+
205
+        misses: list[list[int]] = []
206
+
207
+        def compute(miss_indices: list[int]) -> list[str]:
208
+            misses.append(list(miss_indices))
209
+            # Only produces values for miss positions.
210
+            return [f"fresh_{i}" for i in miss_indices]
211
+
212
+        out = inst.cached_batch("next_token_dist", "base", ["p1", "p2", "p3"], 32, compute)
213
+        # p1 served from cache; p2, p3 computed.
214
+        assert out == ["cached_v1", "fresh_1", "fresh_2"]
215
+        assert misses == [[1, 2]]
216
+        assert inst.stats.batches_sent == 1
217
+        assert inst.stats.batched_prompts == 2  # only the miss count
218
+        # Warmup was a miss; cached_batch hit p1 once + missed p2/p3.
219
+        assert inst.stats.cache_hits == 1
220
+        assert inst.stats.cache_misses == 3  # warmup + 2 batch misses
221
+
222
+    def test_all_cached_skips_forward(self) -> None:
223
+        """No misses → compute is never called, batches_sent stays 0."""
224
+        inst = BackendInstrumentation()
225
+        for p in ("p1", "p2"):
226
+            inst.cached("next_token_dist", "base", p, 32, lambda p=p: f"v_{p}")
227
+        inst.stats.batches_sent = 0  # reset from warmups
228
+        inst.stats.batched_prompts = 0
229
+        inst.stats.max_batch_size = 0
230
+
231
+        def compute(_idx: list[int]) -> list[str]:
232
+            raise AssertionError("compute should not have been called")
233
+
234
+        out = inst.cached_batch("next_token_dist", "base", ["p1", "p2"], 32, compute)
235
+        assert out == ["v_p1", "v_p2"]
236
+        assert inst.stats.batches_sent == 0
237
+        assert inst.stats.batched_prompts == 0
238
+
239
+    def test_max_batch_size_tracks_largest(self) -> None:
240
+        inst = BackendInstrumentation()
241
+
242
+        def c1(idx: list[int]) -> list[int]:
243
+            return list(idx)
244
+
245
+        inst.cached_batch("next_token_dist", "base", ["a", "b", "c"], 32, c1)
246
+        inst.cached_batch("next_token_dist", "base", ["d", "e"], 32, c1)
247
+        assert inst.stats.max_batch_size == 3
248
+
249
+    def test_wrong_return_length_raises(self) -> None:
250
+        inst = BackendInstrumentation()
251
+
252
+        def bad(idx: list[int]) -> list[int]:
253
+            return [0]  # wrong length
254
+
255
+        with pytest.raises(RuntimeError, match="backend bug"):
256
+            inst.cached_batch("next_token_dist", "base", ["p1", "p2", "p3"], 32, bad)
257
+
258
+    def test_empty_prompts_returns_empty(self) -> None:
259
+        """Sanity: an empty prompt list doesn't fire a batch."""
260
+        inst = BackendInstrumentation()
261
+
262
+        def compute(_idx: list[int]) -> list[int]:
263
+            raise AssertionError("compute should not have been called")
264
+
265
+        out = inst.cached_batch("next_token_dist", "base", [], 32, compute)
266
+        assert out == []
267
+        assert inst.stats.batches_sent == 0
268
+
159269
 
160270
 class TestTraceWriter:
161271
     def test_disabled_is_noop(self, tmp_path: Path) -> None:
tests/unit/test_scoring.py (modified)
@@ -88,6 +88,17 @@ class TestProtocols:
8888
                     vocab_size=1,
8989
                 )
9090
 
91
+            def next_token_dist_batch(
92
+                self,
93
+                prompts,  # type: ignore[no-untyped-def]
94
+                *,
95
+                top_k: int = 256,
96
+            ) -> list[TokenDist]:
97
+                # S23 — Protocol requires the batched method at
98
+                # runtime. Defer to the single-prompt path; enough to
99
+                # satisfy the runtime_checkable isinstance check.
100
+                return [self.next_token_dist(p, top_k=top_k) for p in prompts]
101
+
91102
         assert isinstance(FakeScoring(), ScoringBackend)
92103
 
93104
     def test_differential_backend_runtime_checkable(self) -> None: