@@ -0,0 +1,206 @@ |
| 1 | +"""Tests for :mod:`dlm_sway.suite.trace_analysis` (S14 / F12).""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import json |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | +from rich.console import Console |
| 9 | +from typer.testing import CliRunner |
| 10 | + |
| 11 | +from dlm_sway.cli.app import app |
| 12 | +from dlm_sway.suite.trace_analysis import ( |
| 13 | + ProbeSummary, |
| 14 | + TraceEvent, |
| 15 | + build_report, |
| 16 | + load, |
| 17 | + per_probe_summary, |
| 18 | + per_view_summary, |
| 19 | + render_json, |
| 20 | + render_markdown, |
| 21 | + render_terminal, |
| 22 | + slowest_events, |
| 23 | +) |
| 24 | + |
# Sample trace used by most tests below: ``fixtures/trace_sample.jsonl`` under
# the directory two levels up from this file.
FIXTURE = Path(__file__).parent.parent.joinpath("fixtures", "trace_sample.jsonl")
| 26 | + |
| 27 | + |
class TestLoad:
    """Checks for ``load``: reading a JSONL trace file into ``TraceEvent`` objects."""

    def test_loads_every_event(self) -> None:
        """The shared fixture yields eight events, all ``TraceEvent`` instances."""
        parsed = load(FIXTURE)
        assert len(parsed) == 8
        for event in parsed:
            assert isinstance(event, TraceEvent)

    def test_skips_malformed_lines(self, tmp_path: Path) -> None:
        """A non-JSON line and a blank line are dropped; the two valid rows remain."""
        first_row = (
            '{"ts": 1, "probe": "p", "view_id": "v", "prompt_hash": "a",'
            ' "top_k": 0, "op": "op", "wall_ms": 1.0, "hit": false}\n'
        )
        second_row = (
            '{"ts": 2, "probe": "q", "view_id": "w", "prompt_hash": "b",'
            ' "top_k": 0, "op": "op", "wall_ms": 2.0, "hit": true}\n'
        )
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(
            first_row + "not json at all\n" + "\n" + second_row,
            encoding="utf-8",
        )
        assert len(load(trace_file)) == 2

    def test_missing_optional_fields_fallback(self, tmp_path: Path) -> None:
        """A row without ``probe`` and ``hit`` loads with ``None`` / ``False``."""
        # pre-S07 shape — no probe, no hit.
        legacy_row = (
            '{"ts": 0, "view_id": "base", "prompt_hash": "x",'
            ' "top_k": 0, "op": "next_token_dist", "wall_ms": 5.0}\n'
        )
        trace_file = tmp_path / "trace.jsonl"
        trace_file.write_text(legacy_row, encoding="utf-8")
        parsed = load(trace_file)
        assert len(parsed) == 1
        legacy_event = parsed[0]
        assert legacy_event.probe is None
        assert legacy_event.hit is False
| 60 | + |
| 61 | + |
class TestPerProbeSummary:
    """Checks for ``per_probe_summary`` against the shared fixture."""

    @staticmethod
    def _fixture_summaries():
        """Load the fixture and return ``(summaries, summaries keyed by probe)``."""
        ordered = per_probe_summary(load(FIXTURE))
        keyed = {}
        for summary in ordered:
            keyed[summary.probe] = summary
        return ordered, keyed

    def test_buckets_by_probe(self) -> None:
        """Events are grouped under ``dk`` and ``sis`` with four events each."""
        _, keyed = self._fixture_summaries()
        assert set(keyed) == {"dk", "sis"}
        assert keyed["dk"].n_events == 4
        assert keyed["sis"].n_events == 4

    def test_wall_ms_accumulates(self) -> None:
        """``total_ms`` for ``dk`` equals the sum of its events, rounded to 0.1 ms."""
        _, keyed = self._fixture_summaries()
        # dk: 180.5 + 175.2 + 0.1 + 172.8 = 528.6
        expected_total = round(180.5 + 175.2 + 0.1 + 172.8, 1)
        assert keyed["dk"].total_ms == expected_total

    def test_cache_hit_tracking(self) -> None:
        """Hit / miss counters and the derived hit rate are reported for ``dk``."""
        _, keyed = self._fixture_summaries()
        dk = keyed["dk"]
        # Each probe had 1 hit, 3 misses in the fixture.
        assert dk.cache_hits == 1
        assert dk.cache_misses == 3
        assert dk.hit_rate == 0.25

    def test_sorted_by_wall_ms_descending(self) -> None:
        """The probe with the larger total wall time is listed first."""
        ordered, _ = self._fixture_summaries()
        # sis has bigger total (two ~500 ms events vs dk's ~180 ms events)
        assert ordered[0].probe == "sis"
| 92 | + |
| 93 | + |
class TestPerViewSummary:
    """Checks for ``per_view_summary`` against the shared fixture."""

    def test_buckets_by_view(self) -> None:
        """The fixture's events fall into exactly the ``base`` and ``ft`` views."""
        seen_views = set()
        for summary in per_view_summary(load(FIXTURE)):
            seen_views.add(summary.view_id)
        assert seen_views == {"base", "ft"}

    def test_sorted_by_wall_ms(self) -> None:
        """The first summary's total wall time is at least the second's."""
        ordered = per_view_summary(load(FIXTURE))
        leading, following = ordered[0], ordered[1]
        assert leading.total_ms >= following.total_ms
| 105 | + |
| 106 | + |
class TestSlowestEvents:
    """Checks for ``slowest_events`` against the shared fixture."""

    def test_returns_top_k(self) -> None:
        """``k=3`` yields three events in non-increasing ``wall_ms`` order."""
        top_three = slowest_events(load(FIXTURE), k=3)
        assert len(top_three) == 3
        first, second, third = top_three[0], top_three[1], top_three[2]
        # Descending by wall_ms
        assert first.wall_ms >= second.wall_ms
        assert second.wall_ms >= third.wall_ms
        # The slowest is the sis/base first rolling_logprob (520.3 ms).
        assert first.wall_ms == 520.3

    def test_k_larger_than_events_returns_all(self) -> None:
        """Asking for more events than exist returns every event."""
        all_events = load(FIXTURE)
        ranked = slowest_events(all_events, k=100)
        assert len(ranked) == len(all_events)
| 121 | + |
| 122 | + |
class TestRenderers:
    """Checks for the JSON, Markdown and terminal renderers of a built report."""

    def test_json_shape(self) -> None:
        """The JSON rendering parses and exposes the expected top-level fields."""
        report = build_report(load(FIXTURE), slowest_k=3)
        decoded = json.loads(render_json(report))
        assert decoded["total_events"] == 8
        assert decoded["overall_hit_rate"] == 0.25
        assert len(decoded["per_probe"]) == 2
        assert len(decoded["slowest"]) == 3

    def test_markdown_nonempty(self) -> None:
        """The Markdown rendering contains its headings and both probe names."""
        markdown = render_markdown(build_report(load(FIXTURE)))
        for fragment in ("# sway trace", "## per-probe", "## per-view", "dk", "sis"):
            assert fragment in markdown

    def test_terminal_renders_without_error(self) -> None:
        """Rendering to a recording console emits the title and both probe names."""
        recorder = Console(record=True, width=160)
        report = build_report(load(FIXTURE))
        render_terminal(report, console=recorder)
        captured = recorder.export_text()
        for fragment in ("sway trace", "dk", "sis"):
            assert fragment in captured
| 150 | + |
| 151 | + |
class TestBuildReport:
    """Checks for ``build_report`` on empty input and on the shared fixture."""

    def test_empty_events(self) -> None:
        """An empty event list produces zeroed totals and no per-probe rows."""
        empty_report = build_report([])
        assert empty_report.total_events == 0
        assert empty_report.total_wall_ms == 0.0
        assert empty_report.overall_hit_rate == 0.0
        assert empty_report.per_probe == []

    def test_probe_summary_matches(self) -> None:
        """Per-probe rows are ``ProbeSummary`` objects and the event count agrees."""
        loaded = load(FIXTURE)
        full_report = build_report(loaded)
        leading_row = full_report.per_probe[0]
        assert isinstance(leading_row, ProbeSummary)
        assert full_report.total_events == len(loaded)
| 165 | + |
| 166 | + |
class TestCli:
    """End-to-end checks of the ``trace`` subcommand, driven through ``CliRunner``."""

    def test_terminal_default(self) -> None:
        """Without ``--format`` the command exits 0 and prints the report text."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE)])
        # Show the captured output if the exit code is wrong, to ease debugging.
        assert result.exit_code == 0, result.stdout
        assert "sway trace" in result.stdout
        assert "dk" in result.stdout
        assert "per-probe" in result.stdout

    def test_json_format(self) -> None:
        """``--format json`` prints parseable JSON with the fixture's event count."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "json"])
        assert result.exit_code == 0
        parsed = json.loads(result.stdout)
        assert parsed["total_events"] == 8

    def test_markdown_format(self) -> None:
        """``--format md`` prints the Markdown report heading."""
        result = CliRunner().invoke(app, ["trace", str(FIXTURE), "--format", "md"])
        assert result.exit_code == 0
        assert "# sway trace" in result.stdout

    def test_missing_file_exits_2(self, tmp_path: Path) -> None:
        """A path that does not exist makes the command exit with status 2."""
        missing = tmp_path / "nope.jsonl"
        result = CliRunner().invoke(app, ["trace", str(missing)])
        # Typer's PATH argument rejects non-existent paths with its own
        # validation (exit 2) before our code runs.
        assert result.exit_code == 2

    def test_empty_file_exits_1(self, tmp_path: Path) -> None:
        """An empty trace file exits with status 1 and reports "no events"."""
        empty = tmp_path / "empty.jsonl"
        empty.write_text("", encoding="utf-8")
        result = CliRunner().invoke(app, ["trace", str(empty)])
        assert result.exit_code == 1
        # Use the combined ``output`` rather than ``stdout + stderr``: on Click
        # releases before 8.2 a runner with the default ``mix_stderr=True``
        # raises ``ValueError`` when ``result.stderr`` is read, while ``output``
        # carries both streams on old and new releases alike.
        assert "no events" in result.output

    def test_slowest_override(self) -> None:
        """``--slowest 2`` limits the JSON ``slowest`` list to two entries."""
        result = CliRunner().invoke(
            app, ["trace", str(FIXTURE), "--format", "json", "--slowest", "2"]
        )
        assert result.exit_code == 0
        parsed = json.loads(result.stdout)
        assert len(parsed["slowest"]) == 2