1 """S23 — batched backend execution regression tests.
2
3 Pin the three invariants the sprint depends on:
4
5 1. A ``batch_score=True`` probe routes its scoring through
6 ``next_token_dist_batch`` (not the single-prompt path), and the
7 instrumentation counters reflect that.
8 2. The dummy backend's batched path produces results identical to the
9 single-prompt path — protocol default-loop correctness.
10 3. The report footer surfaces the batch counters alongside cache stats
11 when any batched forward fires.
12 """

from __future__ import annotations

from contextlib import contextmanager

import numpy as np

from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
from dlm_sway.core.scoring import TokenDist
from dlm_sway.probes.base import RunContext, build_probe
from dlm_sway.probes.delta_kl import DeltaKLProbe
from dlm_sway.suite.report import _cache_line


def _planted_backend() -> DummyDifferentialBackend:
    """Two prompts with distinguishable base vs ft distributions."""
    base = DummyResponses(
        token_dists={
            "q1": TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(np.array([0.9, 0.05, 0.05], dtype=np.float32)),
                vocab_size=100,
            ),
            "q2": TokenDist(
                token_ids=np.array([5, 6], dtype=np.int64),
                logprobs=np.log(np.array([0.8, 0.2], dtype=np.float32)),
                vocab_size=100,
            ),
        }
    )
    ft = DummyResponses(
        token_dists={
            "q1": TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(np.array([0.3, 0.35, 0.35], dtype=np.float32)),
                vocab_size=100,
            ),
            "q2": TokenDist(
                token_ids=np.array([5, 6], dtype=np.int64),
                logprobs=np.log(np.array([0.4, 0.6], dtype=np.float32)),
                vocab_size=100,
            ),
        }
    )
    return DummyDifferentialBackend(base=base, ft=ft)


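# Illustrative sanity check on the planted divergence (an aside, not an
# assertion: the exact direction and weighting DeltaKLProbe uses is not
# pinned by this file). If it scores KL(base || ft) over the planted
# support, q1 yields
#     0.9*ln(0.9/0.3) + 0.05*ln(0.05/0.35) + 0.05*ln(0.05/0.35) ≈ 0.79,
# comfortably above the 0.01 threshold asserted in the tests below.

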
def test_delta_kl_opt_in_flag_is_set() -> None:
    """Guard against a future refactor accidentally unsetting the flag."""
    assert DeltaKLProbe.batch_score is True


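# ``batch_score`` is the per-probe opt-in described in the module
# docstring: a probe that sets it routes scoring through
# ``next_token_dist_batch`` instead of looping ``next_token_dist``
# (assumed dispatch shape; the routing itself is exercised indirectly
# via ``probe.run`` in the next test).

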
def test_batched_probe_routes_through_next_token_dist_batch() -> None:
    """Running a ``batch_score=True`` probe must call the batched method
    on the view — not fall back to the per-prompt path.

    The dummy backend has no real forward to amortize, so we spy on the
    batched method directly rather than assert on ``batches_sent``
    counters (those fire only when HF's real batched compute hits
    ``cached_batch``).
    """
    backend = _planted_backend()
    calls: list[tuple[str, tuple[str, ...]]] = []

    original = backend.__class__.as_base

    @contextmanager
    def tracking_as_base(self):  # type: ignore[no-untyped-def]
        with original(self) as view:
            orig_batch = view.next_token_dist_batch

            def tracked(prompts, **kwargs):  # type: ignore[no-untyped-def]
                calls.append(("base", tuple(prompts)))
                return orig_batch(prompts, **kwargs)

            view.next_token_dist_batch = tracked  # type: ignore[method-assign]
            yield view

    backend.__class__.as_base = tracking_as_base  # type: ignore[method-assign]
    try:
        probe, spec = build_probe(
            {
                "name": "dk",
                "kind": "delta_kl",
                "prompts": ["q1", "q2"],
                "assert_mean_gte": 0.01,
            }
        )
        ctx = RunContext(backend=backend, seed=0, top_k=256)
        probe.run(spec, ctx)
    finally:
        backend.__class__.as_base = original  # type: ignore[method-assign]

    assert calls == [("base", ("q1", "q2"))], (
        f"expected one batched base call covering both prompts, got {calls!r}"
    )


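# The class-level patch above could equally be written with pytest's
# monkeypatch fixture (a sketch, assuming the same attribute names):
#
#     def test_...(monkeypatch) -> None:
#         monkeypatch.setattr(backend.__class__, "as_base", tracking_as_base)
#
# The explicit try/finally is kept so the test does not depend on the
# fixture and the restore point stays visible in the test body.

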
def test_batched_results_equal_serial_results() -> None:
    """Dummy default-loop: the batched path iterates serially under the
    hood, so its results must exactly match explicit per-prompt calls."""
    backend = _planted_backend()
    with backend.as_base() as base_view:
        batched = base_view.next_token_dist_batch(["q1", "q2"], top_k=10)
        # The serial calls reuse the same view, so the cache may serve
        # the second pass — either way the returned TokenDists must be
        # byte-identical to the batched results.
        serial_q1 = base_view.next_token_dist("q1", top_k=10)
        serial_q2 = base_view.next_token_dist("q2", top_k=10)
        np.testing.assert_array_equal(batched[0].token_ids, serial_q1.token_ids)
        np.testing.assert_array_equal(batched[0].logprobs, serial_q1.logprobs)
        np.testing.assert_array_equal(batched[1].token_ids, serial_q2.token_ids)
        np.testing.assert_array_equal(batched[1].logprobs, serial_q2.logprobs)


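# For reference, the protocol default-loop the test above pins is
# conceptually this (a sketch; the real default lives on the backend
# view protocol and its signature may differ):
#
#     def next_token_dist_batch(self, prompts, **kwargs):
#         return [self.next_token_dist(p, **kwargs) for p in prompts]
#
# A backend that overrides the loop with a fused forward must preserve
# this element-wise equivalence.

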
def test_report_footer_surfaces_batches_when_nonzero() -> None:
    """The ``_cache_line`` footer includes the batches segment iff
    ``batches_sent > 0``; runs without batching show the cache line alone."""
    from datetime import UTC, datetime

    from dlm_sway.core.result import SuiteResult

    now = datetime.now(tz=UTC)

    def _suite(stats: dict[str, float | int]) -> SuiteResult:
        return SuiteResult(
            spec_path="x.yaml",
            started_at=now,
            finished_at=now,
            base_model_id="stub",
            adapter_id="stub",
            sway_version="0.1.0",
            backend_stats=stats,
        )

    # With batching.
    line = _cache_line(
        _suite(
            {
                "cache_hits": 5,
                "cache_misses": 10,
                "batches_sent": 3,
                "batched_prompts": 18,
                "avg_batch_size": 6.0,
                "max_batch_size": 8,
            }
        )
    )
    assert line is not None
    assert "cache: 5/15" in line
    assert "batches: 3" in line
    assert "avg=6.0" in line

    # Without batching — pre-S23 footer shape preserved.
    line_no_batch = _cache_line(_suite({"cache_hits": 5, "cache_misses": 10, "batches_sent": 0}))
    assert line_no_batch is not None
    assert "batches" not in line_no_batch


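# Illustrative footer shapes (assumed formatting; only the substrings
# asserted above are actually pinned by these tests):
#   with batching:    "cache: 5/15 | batches: 3 (18 prompts, avg=6.0, max=8)"
#   without batching: "cache: 5/15"

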
def test_empty_prompts_short_circuit() -> None:
    """An empty prompt list on the batched path returns an empty list
    without doing any forward work."""
    backend = _planted_backend()
    with backend.as_base() as base_view:
        out = base_view.next_token_dist_batch([], top_k=10)
        assert out == []