tenseleyflow/sway / 354c460

Browse files

tests/batched_backend_s23: probe-level batched-path + results-equivalence + footer coverage

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
354c46067360957af0f8e62fc1d6002e7c0ff18b
Parents
4aaa584
Tree
ba61989

1 changed file

StatusFile+-
A tests/unit/test_batched_backend_s23.py 176 0
tests/unit/test_batched_backend_s23.pyadded
@@ -0,0 +1,176 @@
1
+"""S23 — batched backend execution regression tests.
2
+
3
+Pin the three invariants the sprint depends on:
4
+
5
+1. A ``batch_score=True`` probe routes its scoring through
6
+   ``next_token_dist_batch`` (not the single-prompt path), and the
7
+   instrumentation counters reflect that.
8
+2. The dummy backend's batched path produces results identical to the
9
+   single-prompt path — protocol default-loop correctness.
10
+3. The report footer surfaces the batch counters alongside cache stats
11
+   when any batched forward fires.
12
+"""
13
+
14
+from __future__ import annotations
15
+
16
+import numpy as np
17
+
18
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
19
+from dlm_sway.core.scoring import TokenDist
20
+from dlm_sway.probes.base import RunContext, build_probe
21
+from dlm_sway.probes.delta_kl import DeltaKLProbe
22
+from dlm_sway.suite.report import _cache_line
23
+
24
+
25
def _planted_backend() -> DummyDifferentialBackend:
    """Construct a dummy backend whose base and ft answers disagree.

    Two prompts ("q1", "q2") carry planted next-token distributions with
    clearly separated probability mass, giving divergence probes signal.
    """

    def _dist(ids: list[int], probs: list[float]) -> TokenDist:
        # Small planted distribution over a 100-token vocab slice.
        return TokenDist(
            token_ids=np.array(ids, dtype=np.int64),
            logprobs=np.log(np.array(probs, dtype=np.float32)),
            vocab_size=100,
        )

    base_side = DummyResponses(
        token_dists={
            "q1": _dist([1, 2, 3], [0.9, 0.05, 0.05]),
            "q2": _dist([5, 6], [0.8, 0.2]),
        }
    )
    ft_side = DummyResponses(
        token_dists={
            "q1": _dist([1, 2, 3], [0.3, 0.35, 0.35]),
            "q2": _dist([5, 6], [0.4, 0.6]),
        }
    )
    return DummyDifferentialBackend(base=base_side, ft=ft_side)
56
+
57
+
58
def test_delta_kl_opt_in_flag_is_set() -> None:
    """Pin the class-level opt-in so a refactor cannot silently drop it."""
    flag = DeltaKLProbe.batch_score
    assert flag is True
61
+
62
+
63
def test_batched_probe_routes_through_next_token_dist_batch() -> None:
    """Running a batch_score=True probe must call the batched method
    on the view — not fall back to the per-prompt path.

    The dummy backend has no real forward to amortize, so we spy on
    the batched method directly rather than assert on
    ``batches_sent`` counters (those fire only when HF's real
    batched compute hits ``cached_batch``)."""
    backend = _planted_backend()
    # Each spy hit records ("base", <prompt tuple>) so both the grouping and
    # the ordering of prompts can be asserted exactly at the end.
    calls: list[tuple[str, tuple[str, ...]]] = []

    # Captured as the plain function object off the class: used to delegate
    # to the real context manager and to restore the attribute afterwards.
    original = backend.__class__.as_base

    from contextlib import contextmanager

    @contextmanager
    def tracking_as_base(self):  # type: ignore[no-untyped-def]
        # Wrap the genuine view and intercept only next_token_dist_batch;
        # the single-prompt path is deliberately left untouched so wrong
        # routing shows up as an empty ``calls`` list, not a crash.
        with original(self) as view:
            orig_batch = view.next_token_dist_batch

            def tracked(prompts, **kwargs):  # type: ignore[no-untyped-def]
                calls.append(("base", tuple(prompts)))
                return orig_batch(prompts, **kwargs)

            view.next_token_dist_batch = tracked  # type: ignore[method-assign]
            yield view

    # NOTE: patching the class (not the instance) affects every instance of
    # the backend class; the try/finally guarantees restoration even when
    # the probe run raises.
    backend.__class__.as_base = tracking_as_base  # type: ignore[method-assign]
    try:
        probe, spec = build_probe(
            {
                "name": "dk",
                "kind": "delta_kl",
                "prompts": ["q1", "q2"],
                "assert_mean_gte": 0.01,
            }
        )
        ctx = RunContext(backend=backend, seed=0, top_k=256)
        probe.run(spec, ctx)
    finally:
        backend.__class__.as_base = original  # type: ignore[method-assign]

    # Exactly one batched call covering both prompts in spec order — two
    # entries here would mean a serial per-prompt fallback fired instead.
    assert calls == [("base", ("q1", "q2"))], (
        f"expected one batched base call covering both prompts, got {calls!r}"
    )
108
+
109
+
110
def test_batched_results_equal_serial_results() -> None:
    """The dummy's batched path loops serially under the hood, so its
    output must match per-prompt calls element for element."""
    backend = _planted_backend()
    with backend.as_base() as view:
        from_batch = view.next_token_dist_batch(["q1", "q2"], top_k=10)
        # Re-score through the same view on purpose: the second pass may be
        # served from cache, yet the TokenDists must be byte-identical.
        one_by_one = [
            view.next_token_dist("q1", top_k=10),
            view.next_token_dist("q2", top_k=10),
        ]
    for got, want in zip(from_batch, one_by_one):
        np.testing.assert_array_equal(got.token_ids, want.token_ids)
        np.testing.assert_array_equal(got.logprobs, want.logprobs)
124
+
125
+
126
def test_report_footer_surfaces_batches_when_nonzero() -> None:
    """The cache_line footer includes the batches segment iff
    batches_sent > 0. Runs without batching show cache line alone."""
    from datetime import UTC, datetime

    from dlm_sway.core.result import SuiteResult

    timestamp = datetime.now(tz=UTC)

    def _make_suite(stats: dict[str, float | int]) -> SuiteResult:
        # Minimal SuiteResult: only backend_stats matters to _cache_line.
        return SuiteResult(
            spec_path="x.yaml",
            started_at=timestamp,
            finished_at=timestamp,
            base_model_id="stub",
            adapter_id="stub",
            sway_version="0.1.0",
            backend_stats=stats,
        )

    # With batching: all batch counters populated.
    batched_stats = {
        "cache_hits": 5,
        "cache_misses": 10,
        "batches_sent": 3,
        "batched_prompts": 18,
        "avg_batch_size": 6.0,
        "max_batch_size": 8,
    }
    with_batching = _cache_line(_make_suite(batched_stats))
    assert with_batching is not None
    for fragment in ("cache: 5/15", "batches: 3", "avg=6.0"):
        assert fragment in with_batching

    # Without batching — pre-S23 footer shape preserved.
    quiet_stats = {"cache_hits": 5, "cache_misses": 10, "batches_sent": 0}
    without_batching = _cache_line(_make_suite(quiet_stats))
    assert without_batching is not None
    assert "batches" not in without_batching
168
+
169
+
170
def test_empty_prompts_short_circuit() -> None:
    """An empty prompt list on the batched path yields an empty result
    list without any forward work."""
    backend = _planted_backend()
    with backend.as_base() as view:
        result = view.next_token_dist_batch([], top_k=10)
    assert result == []