| 1 | """Tests for :mod:`dlm_sway.suite.trace_analysis` (S14 / F12).""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import json |
| 6 | from pathlib import Path |
| 7 | |
| 8 | from rich.console import Console |
| 9 | from typer.testing import CliRunner |
| 10 | |
| 11 | from dlm_sway.cli.app import app |
| 12 | from dlm_sway.suite.trace_analysis import ( |
| 13 | ProbeSummary, |
| 14 | TraceEvent, |
| 15 | build_report, |
| 16 | load, |
| 17 | per_probe_summary, |
| 18 | per_view_summary, |
| 19 | render_json, |
| 20 | render_markdown, |
| 21 | render_terminal, |
| 22 | slowest_events, |
| 23 | ) |
| 24 | |
# Shared JSONL trace fixture (8 events, two probes, two views) used by most tests.
FIXTURE = Path(__file__).parents[1] / "fixtures" / "trace_sample.jsonl"
| 26 | |
| 27 | |
class TestLoad:
    """Behavior of :func:`load` on well-formed, malformed, and legacy input."""

    def test_loads_every_event(self) -> None:
        loaded = load(FIXTURE)
        assert len(loaded) == 8
        for event in loaded:
            assert isinstance(event, TraceEvent)

    def test_skips_malformed_lines(self, tmp_path: Path) -> None:
        # Two valid records surround a non-JSON line and a blank line;
        # only the valid ones should survive loading.
        contents = (
            '{"ts": 1, "probe": "p", "view_id": "v", "prompt_hash": "a",'
            ' "top_k": 0, "op": "op", "wall_ms": 1.0, "hit": false}\n'
            "not json at all\n"
            "\n"
            '{"ts": 2, "probe": "q", "view_id": "w", "prompt_hash": "b",'
            ' "top_k": 0, "op": "op", "wall_ms": 2.0, "hit": true}\n'
        )
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(contents, encoding="utf-8")
        assert len(load(trace_file)) == 2

    def test_missing_optional_fields_fallback(self, tmp_path: Path) -> None:
        # pre-S07 shape — no probe, no hit.
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(
            '{"ts": 0, "view_id": "base", "prompt_hash": "x",'
            ' "top_k": 0, "op": "next_token_dist", "wall_ms": 5.0}\n',
            encoding="utf-8",
        )
        loaded = load(trace_file)
        assert len(loaded) == 1
        first = loaded[0]
        assert first.probe is None
        assert first.hit is False
| 60 | |
| 61 | |
class TestPerProbeSummary:
    """Aggregation of trace events into per-probe summaries."""

    def test_buckets_by_probe(self) -> None:
        indexed = {s.probe: s for s in per_probe_summary(load(FIXTURE))}
        assert set(indexed) == {"dk", "sis"}
        assert indexed["dk"].n_events == 4
        assert indexed["sis"].n_events == 4

    def test_wall_ms_accumulates(self) -> None:
        indexed = {s.probe: s for s in per_probe_summary(load(FIXTURE))}
        # dk: 180.5 + 175.2 + 0.1 + 172.8 = 528.6
        assert indexed["dk"].total_ms == round(180.5 + 175.2 + 0.1 + 172.8, 1)

    def test_cache_hit_tracking(self) -> None:
        indexed = {s.probe: s for s in per_probe_summary(load(FIXTURE))}
        # Each probe had 1 hit, 3 misses in the fixture.
        dk = indexed["dk"]
        assert dk.cache_hits == 1
        assert dk.cache_misses == 3
        assert dk.hit_rate == 0.25

    def test_sorted_by_wall_ms_descending(self) -> None:
        summaries = per_probe_summary(load(FIXTURE))
        # sis has bigger total (two ~500 ms events vs dk's ~180 ms events)
        assert summaries[0].probe == "sis"
| 92 | |
| 93 | |
class TestPerViewSummary:
    """Aggregation of trace events into per-view summaries."""

    def test_buckets_by_view(self) -> None:
        views = per_view_summary(load(FIXTURE))
        assert {summary.view_id for summary in views} == {"base", "ft"}

    def test_sorted_by_wall_ms(self) -> None:
        views = per_view_summary(load(FIXTURE))
        # Summaries come back ordered from slowest view to fastest.
        assert views[0].total_ms >= views[1].total_ms
| 105 | |
| 106 | |
class TestSlowestEvents:
    """Selection of the top-k slowest events from a trace."""

    def test_returns_top_k(self) -> None:
        top = slowest_events(load(FIXTURE), k=3)
        assert len(top) == 3
        # Each entry is at least as slow as the one after it (descending wall_ms).
        for slower, faster in zip(top, top[1:]):
            assert slower.wall_ms >= faster.wall_ms
        # The slowest is the sis/base first rolling_logprob (520.3 ms).
        assert top[0].wall_ms == 520.3

    def test_k_larger_than_events_returns_all(self) -> None:
        all_events = load(FIXTURE)
        assert len(slowest_events(all_events, k=100)) == len(all_events)
| 121 | |
| 122 | |
class TestRenderers:
    """JSON, markdown, and terminal renderers over a built report."""

    def test_json_shape(self) -> None:
        report = build_report(load(FIXTURE), slowest_k=3)
        data = json.loads(render_json(report))
        assert data["total_events"] == 8
        assert data["overall_hit_rate"] == 0.25
        assert len(data["per_probe"]) == 2
        assert len(data["slowest"]) == 3

    def test_markdown_nonempty(self) -> None:
        rendered = render_markdown(build_report(load(FIXTURE)))
        for needle in ("# sway trace", "## per-probe", "## per-view", "dk", "sis"):
            assert needle in rendered

    def test_terminal_renders_without_error(self) -> None:
        recorder = Console(record=True, width=160)
        render_terminal(build_report(load(FIXTURE)), console=recorder)
        exported = recorder.export_text()
        for needle in ("sway trace", "dk", "sis"):
            assert needle in exported
| 149 | assert "sis" in text |
| 150 | |
| 151 | |
class TestBuildReport:
    """Report construction, including the empty-input edge case."""

    def test_empty_events(self) -> None:
        empty = build_report([])
        assert empty.total_events == 0
        assert empty.total_wall_ms == 0.0
        assert empty.overall_hit_rate == 0.0
        assert empty.per_probe == []

    def test_probe_summary_matches(self) -> None:
        all_events = load(FIXTURE)
        report = build_report(all_events)
        assert isinstance(report.per_probe[0], ProbeSummary)
        assert report.total_events == len(all_events)
| 165 | |
| 166 | |
class TestCli:
    """End-to-end invocations of the `trace` subcommand via Typer's runner."""

    def test_terminal_default(self) -> None:
        outcome = CliRunner().invoke(app, ["trace", str(FIXTURE)])
        assert outcome.exit_code == 0, outcome.stdout
        for needle in ("sway trace", "dk", "per-probe"):
            assert needle in outcome.stdout

    def test_json_format(self) -> None:
        outcome = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "json"])
        assert outcome.exit_code == 0
        assert json.loads(outcome.stdout)["total_events"] == 8

    def test_markdown_format(self) -> None:
        outcome = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "md"])
        assert outcome.exit_code == 0
        assert "# sway trace" in outcome.stdout

    def test_missing_file_exits_2(self, tmp_path: Path) -> None:
        # Typer's PATH argument rejects non-existent paths with its own
        # validation (exit 2) before our code runs.
        absent = tmp_path / "nope.jsonl"
        outcome = CliRunner().invoke(app, ["trace", str(absent)])
        assert outcome.exit_code == 2

    def test_empty_file_exits_1(self, tmp_path: Path) -> None:
        blank = tmp_path / "empty.jsonl"
        blank.write_text("", encoding="utf-8")
        outcome = CliRunner().invoke(app, ["trace", str(blank)])
        assert outcome.exit_code == 1
        assert "no events" in outcome.stdout + outcome.stderr

    def test_slowest_override(self) -> None:
        args = ["trace", str(FIXTURE), "--format", "json", "--slowest", "2"]
        outcome = CliRunner().invoke(app, args)
        assert outcome.exit_code == 0
        assert len(json.loads(outcome.stdout)["slowest"]) == 2