Python · 9258 bytes Raw Blame History
1 """End-to-end Sprint 07 checks via the suite runner.
2
3 Asserts:
4
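- ``SuiteResult.backend_stats`` is populated after a run
- duplicate (view, prompt, top_k) lookups hit the cache
- ``--trace`` writes JSONL with the expected per-probe labels
- the loaded trace agrees with ``backend_stats`` hit/miss counts
- a probe's bootstrap ``ci_95`` survives the runner's result rebuild
- the Markdown report shows the cache hit-rate line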
8 """
9
10 from __future__ import annotations
11
12 import json
13 from pathlib import Path
14
15 import numpy as np
16
17 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
18 from dlm_sway.core.scoring import TokenDist
19 from dlm_sway.suite.runner import run as run_suite
20 from dlm_sway.suite.spec import SwaySpec
21
22
def _programmable_backend() -> DummyDifferentialBackend:
    """Build a dummy differential backend with fixed next-token dists.

    Both prompts ("hello" and "world") map to one shared base distribution
    on the base view and one shared fine-tuned distribution on the ft view.
    """

    def make_dist(probs: list[float]) -> TokenDist:
        # Three-token support over a vocab of 100; fresh arrays on each call.
        return TokenDist(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.log(np.array(probs, dtype=np.float32)),
            vocab_size=100,
        )

    prompts = ("hello", "world")
    base_view = make_dist([0.7, 0.2, 0.1])
    ft_view = make_dist([0.2, 0.4, 0.4])
    return DummyDifferentialBackend(
        base=DummyResponses(token_dists=dict.fromkeys(prompts, base_view)),
        ft=DummyResponses(token_dists=dict.fromkeys(prompts, ft_view)),
    )
38
39
def _spec_with_repeated_prompts() -> SwaySpec:
    """Spec with two ``delta_kl`` probes over the same prompt list.

    ``dk2`` repeats ``dk1``'s prompts verbatim, so the cache should serve
    the second probe's base/ft dists from entries the first probe filled.
    """
    probes = [
        {
            "name": probe_name,
            "kind": "delta_kl",
            "prompts": ["hello", "world"],
            "assert_mean_gte": 0.0,
        }
        for probe_name in ("dk1", "dk2")
    ]
    raw_spec = {
        "version": 1,
        "models": {
            "base": {"base": "b"},
            "ft": {"base": "b", "adapter": "/tmp/a"},
        },
        "suite": probes,
    }
    return SwaySpec.model_validate(raw_spec)
66
67
def test_backend_stats_populated_after_run() -> None:
    """The runner attaches the instrumentation counters to the suite result."""
    expected_keys = ("cache_hits", "cache_misses", "forward_passes", "hit_rate")
    backend = _programmable_backend()
    counters = run_suite(_spec_with_repeated_prompts(), backend).backend_stats
    assert counters, "backend_stats should be non-empty when the backend is instrumented"
    for name in expected_keys:
        assert name in counters
75
76
def test_second_probe_hits_cache_for_duplicate_prompts() -> None:
    """dk2 runs the same 2 prompts × 2 views as dk1 → 4 hits.

    dk1 misses on all four (prompt, view) pairs and dk2 should hit them
    all. The bound is ``>=`` rather than ``==``, presumably to tolerate
    other cache traffic in the run (e.g. the preflight call) — TODO confirm.
    """
    backend = _programmable_backend()
    counters = run_suite(_spec_with_repeated_prompts(), backend).backend_stats
    hits = counters["cache_hits"]
    assert hits >= 4, f"expected ≥ 4 hits, got {counters}"
    # Forward passes are pinned to the miss count: hits must not add passes.
    assert counters["forward_passes"] == counters["cache_misses"]
85
86
def test_trace_writer_produces_jsonl(tmp_path: Path) -> None:
    """``trace_path=`` makes the runner emit JSONL events labelled per probe."""
    trace_path = tmp_path / "trace.jsonl"
    backend = _programmable_backend()
    run_suite(_spec_with_repeated_prompts(), backend, trace_path=trace_path)

    # One JSON object per non-blank line.
    raw_text = trace_path.read_text(encoding="utf-8")
    events = [json.loads(row) for row in raw_text.splitlines() if row.strip()]
    assert events

    # The runner's probe labels must reach the trace events.
    seen_probes = {event["probe"] for event in events}
    assert "dk1" in seen_probes
    assert "dk2" in seen_probes

    # This fixture exists to produce a mix of cache hits and misses.
    hit_flags = [event["hit"] for event in events]
    assert any(hit_flags)
    assert not all(hit_flags)
105
106
def test_trace_roundtrip_matches_backend_stats(tmp_path: Path) -> None:
    """F09 regression — trace writer and analyzer share the same schema.

    The writer lives in ``backends/_instrumentation`` and the analyzer
    lives in ``suite/trace_analysis``. They aren't linked by a shared
    type. A writer-side field removal would parse cleanly in the
    analyzer (everything is ``Optional``) and silently roll up every
    event under ``<unassigned>``. This test runs a real suite with
    ``trace_path=`` set, loads the file back, and asserts the loaded
    events are consistent with the backend_stats counters produced by
    the same run.
    """
    from dlm_sway.suite.trace_analysis import load as load_trace

    trace_path = tmp_path / "trace.jsonl"
    backend = _programmable_backend()
    result = run_suite(_spec_with_repeated_prompts(), backend, trace_path=trace_path)

    events = load_trace(trace_path)
    assert events, "trace writer produced no events"
    # Probe-scoped events must tag with their probe name. Pre-probe
    # preflight events legitimately carry ``probe=None``; we split the
    # stream and hold the probe-tagged subset to the stricter invariant.
    probe_tagged = [e for e in events if e.probe is not None]
    assert probe_tagged, "no probe-tagged events — writer dropped the probe field"
    # dk1 + dk2 are the only two probes; both should surface.
    probe_labels = {e.probe for e in probe_tagged}
    assert probe_labels == {"dk1", "dk2"}
    # base + ft are the only two views touched by delta_kl.
    view_ids = {e.view_id for e in probe_tagged}
    assert view_ids == {"base", "ft"}

    # Trace hit/miss counts must match backend_stats exactly.
    # ``backend_stats`` aggregates *every* cache access the instrumented
    # backend saw during the run, including the pre-probe preflight call
    # (which carries ``probe=None`` in the trace), so compare against the
    # full event list, not just the probe-tagged subset. An event with a
    # falsy/missing ``hit`` is counted as a miss here, so a dropped ``hit``
    # field shows up as a count mismatch.
    stats = result.backend_stats
    traced_hits = sum(1 for e in events if e.hit)
    traced_misses = sum(1 for e in events if not e.hit)
    assert traced_hits == stats["cache_hits"], (
        f"trace says {traced_hits} hits; backend_stats says {stats['cache_hits']}. "
        "Writer/analyzer schemas have drifted."
    )
    assert traced_misses == stats["cache_misses"], (
        f"trace says {traced_misses} misses; backend_stats says {stats['cache_misses']}. "
        "Writer/analyzer schemas have drifted."
    )
    # ``probe`` and ``view_id`` were checked on the probe-tagged subset
    # above; ``op`` and ``wall_ms`` must hold on every event. A writer
    # that dropped either would still parse cleanly in the analyzer and
    # roll up under empty buckets — these assertions catch that mode.
    assert all(e.op == "next_token_dist" for e in events)
    assert all(e.wall_ms >= 0.0 for e in events)
159
160
def test_ci_95_survives_runner_roundtrip() -> None:
    """F01 regression — ``_with_duration`` must forward every field.

    The bug: ``_with_duration`` rebuilt the dataclass by hand and
    silently dropped ``ci_95`` on its way out of the runner. Every
    bootstrap CI was stripped before reaching the ``SuiteResult``.
    """

    def make_dist(probs: tuple[float, float, float]) -> TokenDist:
        # Three-token support over a vocab of 100; fresh arrays per call.
        return TokenDist(
            token_ids=np.array([1, 2, 3], dtype=np.int64),
            logprobs=np.log(np.array(probs, dtype=np.float32)),
            vocab_size=100,
        )

    # delta_kl only emits a bootstrap CI when it has ≥ 4 samples to
    # resample — tailor a fixture that clears that floor.
    base_dist = make_dist((0.7, 0.2, 0.1))
    ft_dists = {
        "p1": make_dist((0.2, 0.4, 0.4)),
        "p2": make_dist((0.3, 0.3, 0.4)),
        "p3": make_dist((0.1, 0.5, 0.4)),
        "p4": make_dist((0.25, 0.35, 0.4)),
    }
    backend = DummyDifferentialBackend(
        base=DummyResponses(token_dists=dict.fromkeys(ft_dists, base_dist)),
        ft=DummyResponses(token_dists=ft_dists),
    )
    spec = SwaySpec.model_validate(
        {
            "version": 1,
            "models": {
                "base": {"base": "b"},
                "ft": {"base": "b", "adapter": "/tmp/a"},
            },
            "suite": [
                {
                    "name": "dk",
                    "kind": "delta_kl",
                    "prompts": list(ft_dists.keys()),
                    "assert_mean_gte": 0.0,
                }
            ],
        }
    )
    result = run_suite(spec, backend)
    probe = result.probes[0]
    assert probe.kind == "delta_kl"
    assert probe.ci_95 is not None, (
        "delta_kl emits a bootstrap CI at N=4; the runner dropped it. "
        "Check suite/runner.py:_with_duration."
    )
    # ``raw`` is another field ``_with_duration`` must forward. Assert it
    # explicitly instead of falling back to 0.0, which would mask a
    # dropped ``raw`` — the same regression class this test guards.
    assert probe.raw is not None, (
        "delta_kl should report a raw value; the runner dropped it. "
        "Check suite/runner.py:_with_duration."
    )
    lo, hi = probe.ci_95
    assert lo <= probe.raw <= hi
227
228
def test_report_footer_includes_cache_hit_rate() -> None:
    """Report surface shows the ``cache: N/M = X%`` line when stats exist.

    The percent sign is required on a line that mentions ``cache:``. A bare
    ``"%" in md`` would also be satisfied by a percentage elsewhere in the
    report, so it would not prove the hit rate itself is rendered.
    """
    from dlm_sway.suite import report
    from dlm_sway.suite.score import compute as compute_score

    backend = _programmable_backend()
    result = run_suite(_spec_with_repeated_prompts(), backend)
    score = compute_score(result)
    md = report.to_markdown(result, score)
    cache_lines = [line for line in md.splitlines() if "cache:" in line]
    assert cache_lines, f"no 'cache:' line in report:\n{md}"
    assert any("%" in line for line in cache_lines), (
        f"cache line(s) lack a hit-rate percentage: {cache_lines}"
    )
239 assert "%" in md