tenseleyflow/sway / e2e0d7d

Browse files

tests/unit: trace_analysis + trace_cmd CLI coverage (22 tests)

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
e2e0d7d3bc28dee9aac058d2bea6cec885baa6c6
Parents
af4493c
Tree
b7b6c57

1 changed file

StatusFile+-
A tests/unit/test_trace_analysis.py 206 0
tests/unit/test_trace_analysis.pyadded
@@ -0,0 +1,206 @@
1
+"""Tests for :mod:`dlm_sway.suite.trace_analysis` (S14 / F12)."""
2
+
3
+from __future__ import annotations
4
+
5
+import json
6
+from pathlib import Path
7
+
8
+from rich.console import Console
9
+from typer.testing import CliRunner
10
+
11
+from dlm_sway.cli.app import app
12
+from dlm_sway.suite.trace_analysis import (
13
+    ProbeSummary,
14
+    TraceEvent,
15
+    build_report,
16
+    load,
17
+    per_probe_summary,
18
+    per_view_summary,
19
+    render_json,
20
+    render_markdown,
21
+    render_terminal,
22
+    slowest_events,
23
+)
24
+
25
+FIXTURE = Path(__file__).parent.parent / "fixtures" / "trace_sample.jsonl"
26
+
27
+
28
+class TestLoad:
29
+    def test_loads_every_event(self) -> None:
30
+        events = load(FIXTURE)
31
+        assert len(events) == 8
32
+        assert all(isinstance(e, TraceEvent) for e in events)
33
+
34
+    def test_skips_malformed_lines(self, tmp_path: Path) -> None:
35
+        path = tmp_path / "trace.jsonl"
36
+        path.write_text(
37
+            '{"ts": 1, "probe": "p", "view_id": "v", "prompt_hash": "a",'
38
+            ' "top_k": 0, "op": "op", "wall_ms": 1.0, "hit": false}\n'
39
+            "not json at all\n"
40
+            "\n"
41
+            '{"ts": 2, "probe": "q", "view_id": "w", "prompt_hash": "b",'
42
+            ' "top_k": 0, "op": "op", "wall_ms": 2.0, "hit": true}\n',
43
+            encoding="utf-8",
44
+        )
45
+        events = load(path)
46
+        assert len(events) == 2
47
+
48
+    def test_missing_optional_fields_fallback(self, tmp_path: Path) -> None:
49
+        path = tmp_path / "trace.jsonl"
50
+        # pre-S07 shape — no probe, no hit.
51
+        path.write_text(
52
+            '{"ts": 0, "view_id": "base", "prompt_hash": "x",'
53
+            ' "top_k": 0, "op": "next_token_dist", "wall_ms": 5.0}\n',
54
+            encoding="utf-8",
55
+        )
56
+        events = load(path)
57
+        assert len(events) == 1
58
+        assert events[0].probe is None
59
+        assert events[0].hit is False
60
+
61
+
62
+class TestPerProbeSummary:
63
+    def test_buckets_by_probe(self) -> None:
64
+        events = load(FIXTURE)
65
+        summaries = per_probe_summary(events)
66
+        by_name = {s.probe: s for s in summaries}
67
+        assert set(by_name) == {"dk", "sis"}
68
+        assert by_name["dk"].n_events == 4
69
+        assert by_name["sis"].n_events == 4
70
+
71
+    def test_wall_ms_accumulates(self) -> None:
72
+        events = load(FIXTURE)
73
+        summaries = per_probe_summary(events)
74
+        by_name = {s.probe: s for s in summaries}
75
+        # dk: 180.5 + 175.2 + 0.1 + 172.8 = 528.6
76
+        assert by_name["dk"].total_ms == round(180.5 + 175.2 + 0.1 + 172.8, 1)
77
+
78
+    def test_cache_hit_tracking(self) -> None:
79
+        events = load(FIXTURE)
80
+        summaries = per_probe_summary(events)
81
+        by_name = {s.probe: s for s in summaries}
82
+        # Each probe had 1 hit, 3 misses in the fixture.
83
+        assert by_name["dk"].cache_hits == 1
84
+        assert by_name["dk"].cache_misses == 3
85
+        assert by_name["dk"].hit_rate == 0.25
86
+
87
+    def test_sorted_by_wall_ms_descending(self) -> None:
88
+        events = load(FIXTURE)
89
+        summaries = per_probe_summary(events)
90
+        # sis has bigger total (two ~500 ms events vs dk's ~180 ms events)
91
+        assert summaries[0].probe == "sis"
92
+
93
+
94
+class TestPerViewSummary:
95
+    def test_buckets_by_view(self) -> None:
96
+        events = load(FIXTURE)
97
+        summaries = per_view_summary(events)
98
+        view_ids = {s.view_id for s in summaries}
99
+        assert view_ids == {"base", "ft"}
100
+
101
+    def test_sorted_by_wall_ms(self) -> None:
102
+        events = load(FIXTURE)
103
+        summaries = per_view_summary(events)
104
+        assert summaries[0].total_ms >= summaries[1].total_ms
105
+
106
+
107
+class TestSlowestEvents:
108
+    def test_returns_top_k(self) -> None:
109
+        events = load(FIXTURE)
110
+        slowest = slowest_events(events, k=3)
111
+        assert len(slowest) == 3
112
+        # Descending by wall_ms
113
+        assert slowest[0].wall_ms >= slowest[1].wall_ms >= slowest[2].wall_ms
114
+        # The slowest is the sis/base first rolling_logprob (520.3 ms).
115
+        assert slowest[0].wall_ms == 520.3
116
+
117
+    def test_k_larger_than_events_returns_all(self) -> None:
118
+        events = load(FIXTURE)
119
+        slowest = slowest_events(events, k=100)
120
+        assert len(slowest) == len(events)
121
+
122
+
123
+class TestRenderers:
124
+    def test_json_shape(self) -> None:
125
+        events = load(FIXTURE)
126
+        report = build_report(events, slowest_k=3)
127
+        payload = json.loads(render_json(report))
128
+        assert payload["total_events"] == 8
129
+        assert payload["overall_hit_rate"] == 0.25
130
+        assert len(payload["per_probe"]) == 2
131
+        assert len(payload["slowest"]) == 3
132
+
133
+    def test_markdown_nonempty(self) -> None:
134
+        events = load(FIXTURE)
135
+        md = render_markdown(build_report(events))
136
+        assert "# sway trace" in md
137
+        assert "## per-probe" in md
138
+        assert "## per-view" in md
139
+        assert "dk" in md
140
+        assert "sis" in md
141
+
142
+    def test_terminal_renders_without_error(self) -> None:
143
+        events = load(FIXTURE)
144
+        console = Console(record=True, width=160)
145
+        render_terminal(build_report(events), console=console)
146
+        text = console.export_text()
147
+        assert "sway trace" in text
148
+        assert "dk" in text
149
+        assert "sis" in text
150
+
151
+
152
+class TestBuildReport:
153
+    def test_empty_events(self) -> None:
154
+        report = build_report([])
155
+        assert report.total_events == 0
156
+        assert report.total_wall_ms == 0.0
157
+        assert report.overall_hit_rate == 0.0
158
+        assert report.per_probe == []
159
+
160
+    def test_probe_summary_matches(self) -> None:
161
+        events = load(FIXTURE)
162
+        report = build_report(events)
163
+        assert isinstance(report.per_probe[0], ProbeSummary)
164
+        assert report.total_events == len(events)
165
+
166
+
167
+class TestCli:
168
+    def test_terminal_default(self) -> None:
169
+        result = CliRunner().invoke(app, ["trace", str(FIXTURE)])
170
+        assert result.exit_code == 0, result.stdout
171
+        assert "sway trace" in result.stdout
172
+        assert "dk" in result.stdout
173
+        assert "per-probe" in result.stdout
174
+
175
+    def test_json_format(self) -> None:
176
+        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "json"])
177
+        assert result.exit_code == 0
178
+        parsed = json.loads(result.stdout)
179
+        assert parsed["total_events"] == 8
180
+
181
+    def test_markdown_format(self) -> None:
182
+        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "md"])
183
+        assert result.exit_code == 0
184
+        assert "# sway trace" in result.stdout
185
+
186
+    def test_missing_file_exits_2(self, tmp_path: Path) -> None:
187
+        missing = tmp_path / "nope.jsonl"
188
+        result = CliRunner().invoke(app, ["trace", str(missing)])
189
+        # Typer's PATH argument rejects non-existent paths with its own
190
+        # validation (exit 2) before our code runs.
191
+        assert result.exit_code == 2
192
+
193
+    def test_empty_file_exits_1(self, tmp_path: Path) -> None:
194
+        empty = tmp_path / "empty.jsonl"
195
+        empty.write_text("", encoding="utf-8")
196
+        result = CliRunner().invoke(app, ["trace", str(empty)])
197
+        assert result.exit_code == 1
198
+        assert "no events" in result.stdout + result.stderr
199
+
200
+    def test_slowest_override(self) -> None:
201
+        result = CliRunner().invoke(
202
+            app, ["trace", str(FIXTURE), "--format", "json", "--slowest", "2"]
203
+        )
204
+        assert result.exit_code == 0
205
+        parsed = json.loads(result.stdout)
206
+        assert len(parsed["slowest"]) == 2