Python · 7298 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.suite.trace_analysis` (S14 / F12)."""
2
3 from __future__ import annotations
4
5 import json
6 from pathlib import Path
7
8 from rich.console import Console
9 from typer.testing import CliRunner
10
11 from dlm_sway.cli.app import app
12 from dlm_sway.suite.trace_analysis import (
13 ProbeSummary,
14 TraceEvent,
15 build_report,
16 load,
17 per_probe_summary,
18 per_view_summary,
19 render_json,
20 render_markdown,
21 render_terminal,
22 slowest_events,
23 )
24
# Sample trace shipped with the test suite: <tests root>/fixtures/trace_sample.jsonl,
# located two directory levels above this file.
FIXTURE = Path(__file__).parent.parent.joinpath("fixtures", "trace_sample.jsonl")
26
27
class TestLoad:
    """Checks for ``load``: fixture parsing, bad-line handling, field defaults."""

    def test_loads_every_event(self) -> None:
        """The fixture yields eight items, each a ``TraceEvent``."""
        loaded = load(FIXTURE)
        assert len(loaded) == 8
        for item in loaded:
            assert isinstance(item, TraceEvent)

    def test_skips_malformed_lines(self, tmp_path: Path) -> None:
        """Non-JSON and blank lines are dropped; the two valid records remain."""
        first_record = (
            '{"ts": 1, "probe": "p", "view_id": "v", "prompt_hash": "a",'
            ' "top_k": 0, "op": "op", "wall_ms": 1.0, "hit": false}\n'
        )
        garbage_line = "not json at all\n"
        blank_line = "\n"
        second_record = (
            '{"ts": 2, "probe": "q", "view_id": "w", "prompt_hash": "b",'
            ' "top_k": 0, "op": "op", "wall_ms": 2.0, "hit": true}\n'
        )
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(
            first_record + garbage_line + blank_line + second_record,
            encoding="utf-8",
        )

        loaded = load(trace_file)

        assert len(loaded) == 2

    def test_missing_optional_fields_fallback(self, tmp_path: Path) -> None:
        """A record lacking ``probe`` and ``hit`` loads with ``None`` / ``False``."""
        # pre-S07 shape — no probe, no hit.
        legacy_record = (
            '{"ts": 0, "view_id": "base", "prompt_hash": "x",'
            ' "top_k": 0, "op": "next_token_dist", "wall_ms": 5.0}\n'
        )
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(legacy_record, encoding="utf-8")

        loaded = load(trace_file)

        assert len(loaded) == 1
        only_event = loaded[0]
        assert only_event.probe is None
        assert only_event.hit is False
60
61
class TestPerProbeSummary:
    """Checks for ``per_probe_summary`` run over the fixture trace."""

    @staticmethod
    def _fixture_summaries_by_probe():
        """Summarise the fixture per probe and index the result by probe name."""
        return {entry.probe: entry for entry in per_probe_summary(load(FIXTURE))}

    def test_buckets_by_probe(self) -> None:
        """Exactly the ``dk`` and ``sis`` probes appear, four events each."""
        indexed = self._fixture_summaries_by_probe()
        assert set(indexed) == {"dk", "sis"}
        assert indexed["dk"].n_events == 4
        assert indexed["sis"].n_events == 4

    def test_wall_ms_accumulates(self) -> None:
        """``total_ms`` for ``dk`` equals the rounded sum of its wall times."""
        indexed = self._fixture_summaries_by_probe()
        # dk: 180.5 + 175.2 + 0.1 + 172.8 = 528.6
        expected_total = round(180.5 + 175.2 + 0.1 + 172.8, 1)
        assert indexed["dk"].total_ms == expected_total

    def test_cache_hit_tracking(self) -> None:
        """Hit/miss counters and the derived hit rate are reported for ``dk``."""
        dk_summary = self._fixture_summaries_by_probe()["dk"]
        # Each probe had 1 hit, 3 misses in the fixture.
        assert dk_summary.cache_hits == 1
        assert dk_summary.cache_misses == 3
        assert dk_summary.hit_rate == 0.25

    def test_sorted_by_wall_ms_descending(self) -> None:
        """The probe with the largest total wall time is listed first."""
        ordered = per_probe_summary(load(FIXTURE))
        # sis has bigger total (two ~500 ms events vs dk's ~180 ms events)
        leader = ordered[0]
        assert leader.probe == "sis"
92
93
class TestPerViewSummary:
    """Checks for ``per_view_summary`` run over the fixture trace."""

    def test_buckets_by_view(self) -> None:
        """The fixture's events fall into the ``base`` and ``ft`` views only."""
        per_view = per_view_summary(load(FIXTURE))
        seen_views = set()
        for entry in per_view:
            seen_views.add(entry.view_id)
        assert seen_views == {"base", "ft"}

    def test_sorted_by_wall_ms(self) -> None:
        """The first summary's total wall time is at least the second's."""
        per_view = per_view_summary(load(FIXTURE))
        leading_total = per_view[0].total_ms
        following_total = per_view[1].total_ms
        assert leading_total >= following_total
105
106
class TestSlowestEvents:
    """Checks for ``slowest_events`` run over the fixture trace."""

    def test_returns_top_k(self) -> None:
        """``k=3`` yields three events in descending ``wall_ms`` order."""
        top_three = slowest_events(load(FIXTURE), k=3)
        assert len(top_three) == 3
        # Descending by wall_ms
        first_ms = top_three[0].wall_ms
        second_ms = top_three[1].wall_ms
        third_ms = top_three[2].wall_ms
        assert first_ms >= second_ms
        assert second_ms >= third_ms
        # The slowest is the sis/base first rolling_logprob (520.3 ms).
        assert first_ms == 520.3

    def test_k_larger_than_events_returns_all(self) -> None:
        """Asking for more events than exist returns every event."""
        everything = load(FIXTURE)
        top = slowest_events(everything, k=100)
        assert len(top) == len(everything)
121
122
class TestRenderers:
    """Checks for the JSON, markdown and terminal renderers of a report."""

    def test_json_shape(self) -> None:
        """The JSON rendering parses back with the expected top-level values."""
        fixture_report = build_report(load(FIXTURE), slowest_k=3)
        serialized = render_json(fixture_report)
        decoded = json.loads(serialized)
        assert decoded["total_events"] == 8
        assert decoded["overall_hit_rate"] == 0.25
        assert len(decoded["per_probe"]) == 2
        assert len(decoded["slowest"]) == 3

    def test_markdown_nonempty(self) -> None:
        """The markdown rendering contains its headings and both probe names."""
        fixture_report = build_report(load(FIXTURE))
        markdown = render_markdown(fixture_report)
        for fragment in ("# sway trace", "## per-probe", "## per-view", "dk", "sis"):
            assert fragment in markdown

    def test_terminal_renders_without_error(self) -> None:
        """Rendering to a recording Rich console emits the title and probe names."""
        fixture_report = build_report(load(FIXTURE))
        # record=True keeps the rendered output so export_text() can return it.
        recorder = Console(record=True, width=160)
        render_terminal(fixture_report, console=recorder)
        captured = recorder.export_text()
        for fragment in ("sway trace", "dk", "sis"):
            assert fragment in captured
150
151
class TestBuildReport:
    """Checks for ``build_report`` on empty and fixture input."""

    def test_empty_events(self) -> None:
        """An empty event list produces an all-zero report with no probe rows."""
        empty_report = build_report([])
        assert empty_report.total_events == 0
        assert empty_report.total_wall_ms == 0.0
        assert empty_report.overall_hit_rate == 0.0
        assert empty_report.per_probe == []

    def test_probe_summary_matches(self) -> None:
        """Probe rows are ``ProbeSummary`` objects and the event count matches."""
        loaded = load(FIXTURE)
        fixture_report = build_report(loaded)
        first_row = fixture_report.per_probe[0]
        assert isinstance(first_row, ProbeSummary)
        assert fixture_report.total_events == len(loaded)
165
166
class TestCli:
    """End-to-end checks of the ``trace`` CLI command via Typer's ``CliRunner``."""

    def test_terminal_default(self) -> None:
        """With no ``--format`` the command prints the terminal report."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE)])
        # Show the combined output on failure so stderr diagnostics are visible.
        assert result.exit_code == 0, result.output
        assert "sway trace" in result.stdout
        assert "dk" in result.stdout
        assert "per-probe" in result.stdout

    def test_json_format(self) -> None:
        """``--format json`` writes parseable JSON to stdout."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "json"])
        assert result.exit_code == 0
        parsed = json.loads(result.stdout)
        assert parsed["total_events"] == 8

    def test_markdown_format(self) -> None:
        """``--format md`` writes the markdown report to stdout."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "md"])
        assert result.exit_code == 0
        assert "# sway trace" in result.stdout

    def test_missing_file_exits_2(self, tmp_path: Path) -> None:
        """A path that does not exist is rejected with exit code 2."""
        missing = tmp_path / "nope.jsonl"
        result = CliRunner().invoke(app, ["trace", str(missing)])
        # Typer's PATH argument rejects non-existent paths with its own
        # validation (exit 2) before our code runs.
        assert result.exit_code == 2

    def test_empty_file_exits_1(self, tmp_path: Path) -> None:
        """An empty trace file exits 1 and reports that there are no events."""
        empty = tmp_path / "empty.jsonl"
        empty.write_text("", encoding="utf-8")
        result = CliRunner().invoke(app, ["trace", str(empty)])
        assert result.exit_code == 1
        # ``result.output`` holds stdout and stderr together on every Click
        # version; ``result.stderr`` raises ValueError on Click < 8.2 when the
        # runner mixes stderr into stdout (the default there).
        assert "no events" in result.output

    def test_slowest_override(self) -> None:
        """``--slowest 2`` limits the ``slowest`` list in the JSON output to two."""
        result = CliRunner().invoke(
            app, ["trace", str(FIXTURE), "--format", "json", "--slowest", "2"]
        )
        assert result.exit_code == 0
        parsed = json.loads(result.stdout)
        assert len(parsed["slowest"]) == 2