Python · 11547 bytes Raw Blame History
1 """Analyze forward-pass trace JSONLs (S14 / F12).
2
3 The S07 trace infrastructure writes one line per backend scoring call
4 into a JSONL file when ``sway run --trace path.jsonl`` is set. Each
5 event carries ``(probe, view_id, prompt_hash, top_k, op, wall_ms, hit)``.
6 Raw events are easy for backend debugging; they're not easy for
7 answering the user-level questions the audit's F12 pitched:
8
9 - *"Which probe × view pair dominated suite wall time?"*
10 - *"Did the S07 cache actually help this run — overall and per probe?"*
11 - *"Which individual prompts are the slowest? Can I shrink the suite
12 by cutting the expensive tail?"*
13
14 This module takes the raw event stream and produces three summary
15 views answering those three questions, plus three renderers
16 (terminal, markdown, JSON) that match the conventions of
17 :mod:`dlm_sway.suite.report` and :mod:`dlm_sway.suite.compare`.
18 """
19
from __future__ import annotations

import heapq
import json
from dataclasses import dataclass, field
from io import StringIO
from pathlib import Path
from typing import Any

from rich.console import Console
from rich.table import Table
from rich.text import Text
31
32
33 @dataclass(frozen=True, slots=True)
34 class TraceEvent:
35 """One forward-pass event from the ``sway run --trace`` JSONL."""
36
37 ts: float
38 probe: str | None
39 view_id: str
40 prompt_hash: str
41 top_k: int
42 op: str
43 wall_ms: float
44 hit: bool
45
46
47 @dataclass(frozen=True, slots=True)
48 class ProbeSummary:
49 """Per-probe aggregation across a trace file."""
50
51 probe: str
52 n_events: int
53 total_ms: float
54 cache_hits: int
55 cache_misses: int
56
57 @property
58 def hit_rate(self) -> float:
59 total = self.cache_hits + self.cache_misses
60 return float(self.cache_hits) / total if total else 0.0
61
62
63 @dataclass(frozen=True, slots=True)
64 class ViewSummary:
65 """Per-view (``base``, ``ft``, ``null_42``, …) aggregation."""
66
67 view_id: str
68 n_events: int
69 total_ms: float
70
71
72 @dataclass(frozen=True, slots=True)
73 class TraceReport:
74 """Full analysis of a trace file — the shape every renderer consumes."""
75
76 total_events: int
77 total_wall_ms: float
78 overall_hit_rate: float
79 per_probe: list[ProbeSummary] = field(default_factory=list)
80 per_view: list[ViewSummary] = field(default_factory=list)
81 slowest: list[TraceEvent] = field(default_factory=list)
82
83
84 # ----------------------------------------------------------------------
85 # Load + analyze
86 # ----------------------------------------------------------------------
87
88
def _get_or_default(payload: dict[str, Any], key: str, default: Any) -> Any:
    """Return ``payload[key]``, or ``default`` when the key is absent or JSON ``null``."""
    value = payload.get(key)
    return default if value is None else value


def load(path: Path) -> list[TraceEvent]:
    """Parse the JSONL trace at ``path`` into :class:`TraceEvent` objects.

    Malformed lines are skipped with no error — old traces from a
    schema-bumped writer shouldn't crash the analyzer. A line counts as
    malformed when it is not valid JSON, when it is valid JSON but not an
    object (a bare list, number, string or ``null``), or when a field
    cannot be coerced to its expected type.

    Missing optional fields, and fields explicitly set to JSON ``null``,
    fall back to sensible defaults: ``probe=None`` for pre-S07 traces,
    ``top_k=0`` for ``logprob_of`` which doesn't carry one, and
    ``""``/``0.0``/``False`` for the rest. A non-string ``probe`` is
    converted with ``str``.

    Args:
        path: Location of the ``.jsonl`` trace file.

    Returns:
        The parsed events, in file order.

    Raises:
        OSError: If ``path`` cannot be opened.
    """
    events: list[TraceEvent] = []
    # ``errors="replace"`` stops a single corrupted byte sequence from
    # raising UnicodeDecodeError mid-iteration and aborting the whole read.
    # The damaged line is then either rejected as invalid JSON or parsed
    # with U+FFFD in place of the bad bytes.
    with path.open(encoding="utf-8", errors="replace") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON that isn't an object has no ``.get``. Treat it as
            # malformed rather than letting AttributeError escape.
            if not isinstance(payload, dict):
                continue
            probe = payload.get("probe")
            try:
                event = TraceEvent(
                    ts=float(_get_or_default(payload, "ts", 0.0)),
                    probe=None if probe is None else str(probe),
                    view_id=str(_get_or_default(payload, "view_id", "")),
                    prompt_hash=str(_get_or_default(payload, "prompt_hash", "")),
                    top_k=int(_get_or_default(payload, "top_k", 0)),
                    op=str(_get_or_default(payload, "op", "")),
                    wall_ms=float(_get_or_default(payload, "wall_ms", 0.0)),
                    hit=bool(_get_or_default(payload, "hit", False)),
                )
            except (TypeError, ValueError, OverflowError):
                # Partial / corrupted event — skip without halting.
                # OverflowError covers e.g. ``"top_k": Infinity`` → int(inf).
                continue
            events.append(event)
    return events
125
126
def per_probe_summary(events: list[TraceEvent]) -> list[ProbeSummary]:
    """Roll ``events`` up into one :class:`ProbeSummary` per probe.

    Events whose ``probe`` is falsy (``None`` or ``""``) are pooled under
    the ``"<unassigned>"`` label. The result is ordered by ``total_ms``,
    largest first. Probes with equal totals keep the order in which they
    first appeared in ``events``.
    """
    # probe name -> [event count, summed wall_ms, cache hits]
    tallies: dict[str, list[Any]] = {}
    for event in events:
        name = event.probe if event.probe else "<unassigned>"
        tally = tallies.get(name)
        if tally is None:
            tally = tallies[name] = [0, 0.0, 0]
        tally[0] += 1
        tally[1] += event.wall_ms
        if event.hit:
            tally[2] += 1
    summaries = [
        ProbeSummary(
            probe=name,
            n_events=count,
            total_ms=wall,
            cache_hits=hits,
            # Every event is exactly one of hit / miss.
            cache_misses=count - hits,
        )
        for name, (count, wall, hits) in tallies.items()
    ]
    summaries.sort(key=lambda s: -s.total_ms)
    return summaries
154
155
def per_view_summary(events: list[TraceEvent]) -> list[ViewSummary]:
    """Roll ``events`` up into one :class:`ViewSummary` per ``view_id``.

    The result is ordered by ``total_ms``, largest first. Views with equal
    totals keep the order in which they first appeared.
    """
    # view_id -> [event count, summed wall_ms]
    tallies: dict[str, list[Any]] = {}
    for event in events:
        tally = tallies.get(event.view_id)
        if tally is None:
            tally = tallies[event.view_id] = [0, 0.0]
        tally[0] += 1
        tally[1] += event.wall_ms
    summaries = [
        ViewSummary(view_id=view, n_events=count, total_ms=wall)
        for view, (count, wall) in tallies.items()
    ]
    summaries.sort(key=lambda s: -s.total_ms)
    return summaries
168
169
def slowest_events(events: list[TraceEvent], *, k: int = 10) -> list[TraceEvent]:
    """Return the ``k`` events with the largest ``wall_ms``, slowest first.

    Cache hits dominate the event count but rarely the wall-time tail —
    a miss that took 2 seconds is far more actionable than ten hits of
    0.1 ms each. We include both hits and misses in the ranking and let
    the caller read ``hit`` per event to interpret.

    Events with equal ``wall_ms`` keep their relative input order. A
    non-positive ``k`` yields an empty list. A ``k`` larger than
    ``len(events)`` returns every event.

    Args:
        events: Events to rank; not modified.
        k: Maximum number of events to return.

    Returns:
        A new list of at most ``k`` events, ordered by ``wall_ms`` descending.
    """
    # Guard explicitly: slicing with a negative k would silently return
    # "all but the last |k|" events instead of none.
    if k <= 0:
        return []
    # ``nlargest`` is documented as equivalent to
    # ``sorted(..., key=key, reverse=True)[:k]`` (stable on ties) but
    # avoids sorting the whole trace when k is small.
    return heapq.nlargest(k, events, key=lambda e: e.wall_ms)
179
180
def build_report(events: list[TraceEvent], *, slowest_k: int = 10) -> TraceReport:
    """Run every analysis over ``events`` and bundle the results.

    Renderers consume the returned :class:`TraceReport` directly. An empty
    ``events`` list produces an ``overall_hit_rate`` of ``0.0`` rather than
    dividing by zero. ``slowest_k`` is forwarded to :func:`slowest_events`.
    """
    event_count = len(events)
    hit_count = sum(1 for e in events if e.hit)
    if event_count:
        hit_rate = hit_count / event_count
    else:
        hit_rate = 0.0
    return TraceReport(
        total_events=event_count,
        total_wall_ms=sum(e.wall_ms for e in events),
        overall_hit_rate=hit_rate,
        per_probe=per_probe_summary(events),
        per_view=per_view_summary(events),
        slowest=slowest_events(events, k=slowest_k),
    )
194
195
196 # ----------------------------------------------------------------------
197 # Renderers
198 # ----------------------------------------------------------------------
199
200
def render_terminal(report: TraceReport, *, console: Console | None = None) -> None:
    """Print a Rich-formatted three-table breakdown of the trace.

    Output is a bold one-line headline (event count, total seconds,
    overall cache hit rate), followed by a per-probe table and a per-view
    table. A table of the slowest events is added only when
    ``report.slowest`` is non-empty.

    Args:
        report: The analysis to render.
        console: Destination console. A new default :class:`Console` is
            created when omitted.
    """
    c = console or Console()
    c.print(
        Text(
            f"sway trace — {report.total_events} events, "
            f"total {report.total_wall_ms / 1000.0:.2f}s, "
            f"cache hit rate {report.overall_hit_rate:.1%}",
            style="bold",
        )
    )
    c.print()

    # Probe names, view ids and op names are copied verbatim from the
    # trace file. Rich interprets plain ``str`` cells as console markup,
    # so a name containing ``[...]`` could be restyled or make rendering
    # fail with a MarkupError. Wrapping those cells in ``Text`` renders
    # them literally. Numeric cells are formatted here and stay ``str``.
    probe_table = Table(
        show_header=True, header_style="bold", box=None, padding=(0, 1), title="per-probe"
    )
    probe_table.add_column("probe", style="cyan")
    probe_table.add_column("events", justify="right")
    probe_table.add_column("wall_ms", justify="right")
    probe_table.add_column("hits", justify="right")
    probe_table.add_column("hit_rate", justify="right")
    for p in report.per_probe:
        probe_table.add_row(
            Text(p.probe),
            str(p.n_events),
            f"{p.total_ms:,.1f}",
            str(p.cache_hits),
            f"{p.hit_rate:.1%}",
        )
    c.print(probe_table)
    c.print()

    view_table = Table(
        show_header=True, header_style="bold", box=None, padding=(0, 1), title="per-view"
    )
    view_table.add_column("view_id", style="magenta")
    view_table.add_column("events", justify="right")
    view_table.add_column("wall_ms", justify="right")
    for v in report.per_view:
        view_table.add_row(Text(v.view_id), str(v.n_events), f"{v.total_ms:,.1f}")
    c.print(view_table)
    c.print()

    if report.slowest:
        slow_table = Table(
            show_header=True,
            header_style="bold",
            box=None,
            padding=(0, 1),
            title=f"slowest {len(report.slowest)} events",
        )
        slow_table.add_column("probe")
        slow_table.add_column("view")
        slow_table.add_column("op")
        slow_table.add_column("wall_ms", justify="right")
        slow_table.add_column("hit")
        for e in report.slowest:
            slow_table.add_row(
                Text(e.probe or "—"),
                Text(e.view_id),
                Text(e.op),
                f"{e.wall_ms:,.1f}",
                "yes" if e.hit else "no",
            )
        c.print(slow_table)
266
267
def _md_cell(value: object) -> str:
    """Make ``value`` safe to embed as a single Markdown table cell.

    A literal ``|`` would split the cell and a line break would end the
    row, so pipes are backslash-escaped and line breaks become spaces.
    """
    text = str(value)
    for newline in ("\r\n", "\r", "\n"):
        text = text.replace(newline, " ")
    return text.replace("|", "\\|")


def render_markdown(report: TraceReport) -> str:
    """Markdown equivalent of :func:`render_terminal`.

    Produces a headline, a per-probe table, a per-view table and, when
    ``report.slowest`` is non-empty, a slowest-events table. Free-text
    cells taken from the trace (probe, view id, op) are escaped so that
    pipes or newlines in names cannot break the table layout.

    Args:
        report: The analysis to render.

    Returns:
        The Markdown document as a single string.
    """
    buf = StringIO()
    buf.write(
        f"# sway trace — {report.total_events} events, "
        f"total {report.total_wall_ms / 1000.0:.2f}s, "
        f"hit rate {report.overall_hit_rate:.1%}\n\n"
    )
    buf.write("## per-probe\n\n")
    buf.write("| probe | events | wall_ms | hits | hit_rate |\n|---|---:|---:|---:|---:|\n")
    for p in report.per_probe:
        buf.write(
            f"| {_md_cell(p.probe)} | {p.n_events} | {p.total_ms:,.1f} | "
            f"{p.cache_hits} | {p.hit_rate:.1%} |\n"
        )
    buf.write("\n## per-view\n\n")
    buf.write("| view_id | events | wall_ms |\n|---|---:|---:|\n")
    for v in report.per_view:
        buf.write(f"| {_md_cell(v.view_id)} | {v.n_events} | {v.total_ms:,.1f} |\n")
    if report.slowest:
        buf.write(f"\n## slowest {len(report.slowest)} events\n\n")
        buf.write("| probe | view | op | wall_ms | hit |\n|---|---|---|---:|---|\n")
        for e in report.slowest:
            buf.write(
                f"| {_md_cell(e.probe or '—')} | {_md_cell(e.view_id)} | {_md_cell(e.op)} | "
                f"{e.wall_ms:,.1f} | {'yes' if e.hit else 'no'} |\n"
            )
    return buf.getvalue()
296
297
def render_json(report: TraceReport) -> str:
    """Serialise ``report`` as indented, machine-readable JSON.

    Object keys mirror the dataclass attribute names. Each per-probe entry
    additionally carries its derived ``hit_rate``.
    """
    probe_keys = ("probe", "n_events", "total_ms", "cache_hits", "cache_misses", "hit_rate")
    view_keys = ("view_id", "n_events", "total_ms")
    event_keys = ("ts", "probe", "view_id", "prompt_hash", "top_k", "op", "wall_ms", "hit")

    def project(items: list[Any], keys: tuple[str, ...]) -> list[dict[str, Any]]:
        # One dict per item, keys emitted in the order listed above.
        return [{key: getattr(item, key) for key in keys} for item in items]

    document: dict[str, Any] = {
        "total_events": report.total_events,
        "total_wall_ms": report.total_wall_ms,
        "overall_hit_rate": report.overall_hit_rate,
        "per_probe": project(report.per_probe, probe_keys),
        "per_view": project(report.per_view, view_keys),
        "slowest": project(report.slowest, event_keys),
    }
    return json.dumps(document, indent=2)
338
339
# Public API of this module: the data model first, then the analysis
# helpers and renderers, each group alphabetised.
__all__ = [
    # data model
    "ProbeSummary", "TraceEvent", "TraceReport", "ViewSummary",
    # load / analyze / render
    "build_report", "load", "per_probe_summary", "per_view_summary",
    "render_json", "render_markdown", "render_terminal", "slowest_events",
]