| 1 | """Analyze forward-pass trace JSONLs (S14 / F12). |
| 2 | |
| 3 | The S07 trace infrastructure writes one line per backend scoring call |
| 4 | into a JSONL file when ``sway run --trace path.jsonl`` is set. Each |
| 5 | event carries ``(probe, view_id, prompt_hash, top_k, op, wall_ms, hit)``. |
| 6 | Raw events are easy for backend debugging; they're not easy for |
| 7 | answering the user-level questions the audit's F12 pitched: |
| 8 | |
| 9 | - *"Which probe × view pair dominated suite wall time?"* |
| 10 | - *"Did the S07 cache actually help this run — overall and per probe?"* |
| 11 | - *"Which individual prompts are the slowest? Can I shrink the suite |
| 12 | by cutting the expensive tail?"* |
| 13 | |
| 14 | This module takes the raw event stream and produces three summary |
| 15 | views answering those three questions, plus three renderers |
| 16 | (terminal, markdown, JSON) that match the conventions of |
| 17 | :mod:`dlm_sway.suite.report` and :mod:`dlm_sway.suite.compare`. |
| 18 | """ |
| 19 | |
| 20 | from __future__ import annotations |
| 21 | |
| 22 | import json |
| 23 | from dataclasses import dataclass, field |
| 24 | from io import StringIO |
| 25 | from pathlib import Path |
| 26 | from typing import Any |
| 27 | |
| 28 | from rich.console import Console |
| 29 | from rich.table import Table |
| 30 | from rich.text import Text |
| 31 | |
| 32 | |
@dataclass(frozen=True, slots=True)
class TraceEvent:
    """One forward-pass event from the ``sway run --trace`` JSONL.

    Immutable record of a single backend scoring call; one instance
    corresponds to one line of the trace file.
    """

    # Event timestamp as written by the tracer (unit / epoch are not
    # visible in this module — TODO confirm against the writer).
    ts: float
    # Probe that issued the call; ``None`` when the trace line carries no
    # probe (pre-S07 traces).
    probe: str | None
    # Identifier of the view the call ran against (e.g. ``base``, ``ft``,
    # ``null_42``).
    view_id: str
    # Hash of the prompt; presumably identifies identical prompts across
    # events — hashing scheme is defined by the trace writer.
    prompt_hash: str
    # ``top_k`` recorded for the call; 0 when the event carries none
    # (e.g. ``logprob_of``).
    top_k: int
    # Name of the backend operation that was traced.
    op: str
    # Wall-clock duration of the call, in milliseconds.
    wall_ms: float
    # True when the call is counted as a cache hit, False for a miss.
    hit: bool
| 45 | |
| 46 | |
| 47 | @dataclass(frozen=True, slots=True) |
| 48 | class ProbeSummary: |
| 49 | """Per-probe aggregation across a trace file.""" |
| 50 | |
| 51 | probe: str |
| 52 | n_events: int |
| 53 | total_ms: float |
| 54 | cache_hits: int |
| 55 | cache_misses: int |
| 56 | |
| 57 | @property |
| 58 | def hit_rate(self) -> float: |
| 59 | total = self.cache_hits + self.cache_misses |
| 60 | return float(self.cache_hits) / total if total else 0.0 |
| 61 | |
| 62 | |
@dataclass(frozen=True, slots=True)
class ViewSummary:
    """Per-view (``base``, ``ft``, ``null_42``, …) aggregation."""

    # The ``view_id`` shared by every event folded into this summary.
    view_id: str
    # Number of trace events recorded for this view.
    n_events: int
    # Sum of ``wall_ms`` over those events, in milliseconds.
    total_ms: float
| 70 | |
| 71 | |
@dataclass(frozen=True, slots=True)
class TraceReport:
    """Full analysis of a trace file — the shape every renderer consumes."""

    # Number of events the report was built from.
    total_events: int
    # Sum of ``wall_ms`` across all events, in milliseconds.
    total_wall_ms: float
    # Fraction of events flagged as cache hits (0.0–1.0; 0.0 for an empty trace).
    overall_hit_rate: float
    # Per-probe aggregates; ``build_report`` supplies them sorted by total
    # wall time, largest first.
    per_probe: list[ProbeSummary] = field(default_factory=list)
    # Per-view aggregates; ``build_report`` supplies them sorted by total
    # wall time, largest first.
    per_view: list[ViewSummary] = field(default_factory=list)
    # The individual events with the largest ``wall_ms``, slowest first.
    slowest: list[TraceEvent] = field(default_factory=list)
| 82 | |
| 83 | |
| 84 | # ---------------------------------------------------------------------- |
| 85 | # Load + analyze |
| 86 | # ---------------------------------------------------------------------- |
| 87 | |
| 88 | |
def load(path: Path) -> list[TraceEvent]:
    """Parse the JSONL trace at ``path`` into :class:`TraceEvent` objects.

    Malformed lines are skipped with no error — old traces from a
    schema-bumped writer shouldn't crash the analyzer. "Malformed"
    covers blank lines, lines that are not valid JSON, lines whose JSON
    value is not an object (a bare number, string, list or ``null``),
    and events whose fields cannot be coerced to the expected types.
    Missing optional fields fall back to sensible defaults
    (``probe=None`` for pre-S07 traces; ``top_k=0`` for ``logprob_of``
    which doesn't carry one).

    Returns the parsed events in file order. Errors opening or decoding
    the file itself (``OSError``, ``UnicodeDecodeError``) still
    propagate to the caller.
    """
    events: list[TraceEvent] = []
    with path.open(encoding="utf-8") as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload: Any = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(payload, dict):
                # Valid JSON, but not an event object (e.g. ``42`` or
                # ``[1, 2]``); ``.get`` below would raise AttributeError.
                continue
            # Coerce a present probe to ``str`` like the other string
            # fields, while keeping ``None`` as the "no probe" marker.
            raw_probe = payload.get("probe")
            try:
                event = TraceEvent(
                    ts=float(payload.get("ts", 0.0)),
                    probe=None if raw_probe is None else str(raw_probe),
                    view_id=str(payload.get("view_id", "")),
                    prompt_hash=str(payload.get("prompt_hash", "")),
                    top_k=int(payload.get("top_k", 0)),
                    op=str(payload.get("op", "")),
                    wall_ms=float(payload.get("wall_ms", 0.0)),
                    hit=bool(payload.get("hit", False)),
                )
            except (TypeError, ValueError, OverflowError):
                # Partial / corrupted event — skip without halting.
                # OverflowError: ``int()`` of an ``Infinity`` literal,
                # which ``json.loads`` accepts.
                continue
            events.append(event)
    return events
| 125 | |
| 126 | |
def per_probe_summary(events: list[TraceEvent]) -> list[ProbeSummary]:
    """Aggregate events per probe, heaviest total wall time first.

    Events with no probe (``None`` or an empty string) are pooled under
    the ``"<unassigned>"`` name. Probes with equal totals keep the order
    in which they first appear in ``events``.
    """
    # probe name -> [event count, summed wall_ms, hits, misses]
    tallies: dict[str, list[Any]] = {}
    for event in events:
        name = event.probe or "<unassigned>"
        tally = tallies.get(name)
        if tally is None:
            tally = tallies[name] = [0, 0.0, 0, 0]
        tally[0] += 1
        tally[1] += event.wall_ms
        if event.hit:
            tally[2] += 1
        else:
            tally[3] += 1

    summaries = [
        ProbeSummary(
            probe=name,
            n_events=count,
            total_ms=wall_ms,
            cache_hits=hits,
            cache_misses=misses,
        )
        for name, (count, wall_ms, hits, misses) in tallies.items()
    ]
    # Negated key + stable sort: descending by time, ties in first-seen order.
    return sorted(summaries, key=lambda summary: -summary.total_ms)
| 154 | |
| 155 | |
def per_view_summary(events: list[TraceEvent]) -> list[ViewSummary]:
    """Aggregate events per ``view_id``, heaviest total wall time first.

    Views with equal totals keep the order in which they first appear in
    ``events``.
    """
    counts: dict[str, int] = {}
    totals: dict[str, float] = {}
    for event in events:
        view = event.view_id
        counts[view] = counts.get(view, 0) + 1
        totals[view] = totals.get(view, 0.0) + event.wall_ms

    # Both dicts gain keys in the same (first-seen) order.
    summaries = [
        ViewSummary(view_id=view, n_events=counts[view], total_ms=totals[view])
        for view in counts
    ]
    return sorted(summaries, key=lambda summary: -summary.total_ms)
| 168 | |
| 169 | |
def slowest_events(events: list[TraceEvent], *, k: int = 10) -> list[TraceEvent]:
    """Return the ``k`` events with the largest ``wall_ms``, slowest first.

    Cache hits dominate the event count but rarely the wall-time tail —
    a miss that took 2 seconds is far more actionable than ten hits of
    0.1 ms each. We include both hits and misses in the sort and let
    the caller read ``hit`` per event to interpret.

    Events with equal ``wall_ms`` keep their original relative order.
    Fewer than ``k`` events are returned when the trace is shorter than
    ``k``; a zero or negative ``k`` yields an empty list.
    """
    if k <= 0:
        # A negative ``k`` would otherwise slice ``[:k]`` and return all
        # but the last |k| events rather than none.
        return []
    return sorted(events, key=lambda e: -e.wall_ms)[:k]
| 179 | |
| 180 | |
def build_report(events: list[TraceEvent], *, slowest_k: int = 10) -> TraceReport:
    """Assemble the complete :class:`TraceReport` for ``events``.

    Computes the overall totals here and delegates the per-probe,
    per-view and slowest-``slowest_k`` breakdowns to the dedicated
    summary helpers. An empty event list yields a hit rate of 0.0.
    """
    event_count = len(events)
    wall_total = sum(event.wall_ms for event in events)
    hit_count = sum(1 for event in events if event.hit)

    if events:
        hit_rate = hit_count / event_count
    else:
        hit_rate = 0.0

    probes = per_probe_summary(events)
    views = per_view_summary(events)
    tail = slowest_events(events, k=slowest_k)
    return TraceReport(
        total_events=event_count,
        total_wall_ms=wall_total,
        overall_hit_rate=hit_rate,
        per_probe=probes,
        per_view=views,
        slowest=tail,
    )
| 194 | |
| 195 | |
| 196 | # ---------------------------------------------------------------------- |
| 197 | # Renderers |
| 198 | # ---------------------------------------------------------------------- |
| 199 | |
| 200 | |
def render_terminal(report: TraceReport, *, console: Console | None = None) -> None:
    """Print ``report`` to the terminal as a headline plus up to three tables.

    Output goes to ``console`` when one is given, otherwise to a fresh
    :class:`rich.console.Console`. The per-probe and per-view tables are
    always printed; the slowest-events table only when the report has
    any slowest events.
    """
    out = console or Console()

    def make_table(title: str) -> Table:
        # All tables share the same borderless, bold-header look.
        return Table(
            title=title,
            show_header=True,
            header_style="bold",
            box=None,
            padding=(0, 1),
        )

    headline = (
        f"sway trace — {report.total_events} events, "
        f"total {report.total_wall_ms / 1000.0:.2f}s, "
        f"cache hit rate {report.overall_hit_rate:.1%}"
    )
    out.print(Text(headline, style="bold"))
    out.print()

    # --- per-probe ---
    probes = make_table("per-probe")
    probes.add_column("probe", style="cyan")
    for heading in ("events", "wall_ms", "hits", "hit_rate"):
        probes.add_column(heading, justify="right")
    for summary in report.per_probe:
        cells = (
            summary.probe,
            str(summary.n_events),
            f"{summary.total_ms:,.1f}",
            str(summary.cache_hits),
            f"{summary.hit_rate:.1%}",
        )
        probes.add_row(*cells)
    out.print(probes)
    out.print()

    # --- per-view ---
    views = make_table("per-view")
    views.add_column("view_id", style="magenta")
    for heading in ("events", "wall_ms"):
        views.add_column(heading, justify="right")
    for view in report.per_view:
        views.add_row(view.view_id, str(view.n_events), f"{view.total_ms:,.1f}")
    out.print(views)
    out.print()

    # --- slowest events (omitted entirely when there are none) ---
    if not report.slowest:
        return
    slow = make_table(f"slowest {len(report.slowest)} events")
    slow.add_column("probe")
    slow.add_column("view")
    slow.add_column("op")
    slow.add_column("wall_ms", justify="right")
    slow.add_column("hit")
    for event in report.slowest:
        hit_label = "yes" if event.hit else "no"
        slow.add_row(
            event.probe or "—",
            event.view_id,
            event.op,
            f"{event.wall_ms:,.1f}",
            hit_label,
        )
    out.print(slow)
| 266 | |
| 267 | |
def render_markdown(report: TraceReport) -> str:
    """Render ``report`` as a Markdown document and return it as a string.

    The document has a title line with the overall totals, a
    ``per-probe`` table, a ``per-view`` table and — only when the report
    has slowest events — a ``slowest N events`` table.
    """
    chunks: list[str] = [
        f"# sway trace — {report.total_events} events, "
        f"total {report.total_wall_ms / 1000.0:.2f}s, "
        f"hit rate {report.overall_hit_rate:.1%}\n\n",
        "## per-probe\n\n",
        "| probe | events | wall_ms | hits | hit_rate |\n|---|---:|---:|---:|---:|\n",
    ]
    for summary in report.per_probe:
        chunks.append(
            f"| {summary.probe} | {summary.n_events} | {summary.total_ms:,.1f} | "
            f"{summary.cache_hits} | {summary.hit_rate:.1%} |\n"
        )

    chunks.append("\n## per-view\n\n")
    chunks.append("| view_id | events | wall_ms |\n|---|---:|---:|\n")
    for view in report.per_view:
        chunks.append(f"| {view.view_id} | {view.n_events} | {view.total_ms:,.1f} |\n")

    if report.slowest:
        chunks.append(f"\n## slowest {len(report.slowest)} events\n\n")
        chunks.append("| probe | view | op | wall_ms | hit |\n|---|---|---|---:|---|\n")
        for event in report.slowest:
            hit_label = "yes" if event.hit else "no"
            chunks.append(
                f"| {event.probe or '—'} | {event.view_id} | {event.op} | "
                f"{event.wall_ms:,.1f} | {hit_label} |\n"
            )
    return "".join(chunks)
| 296 | |
| 297 | |
def render_json(report: TraceReport) -> str:
    """Serialize ``report`` to an indented JSON string.

    Keys mirror the dataclass attribute names; each per-probe entry
    additionally carries the derived ``hit_rate``.
    """

    def pick(obj: Any, names: tuple[str, ...]) -> dict[str, Any]:
        # Copy the named attributes, in order, into a plain dict.
        return {name: getattr(obj, name) for name in names}

    probe_keys = ("probe", "n_events", "total_ms", "cache_hits", "cache_misses", "hit_rate")
    view_keys = ("view_id", "n_events", "total_ms")
    event_keys = ("ts", "probe", "view_id", "prompt_hash", "top_k", "op", "wall_ms", "hit")

    document: dict[str, Any] = pick(
        report, ("total_events", "total_wall_ms", "overall_hit_rate")
    )
    document["per_probe"] = [pick(summary, probe_keys) for summary in report.per_probe]
    document["per_view"] = [pick(view, view_keys) for view in report.per_view]
    document["slowest"] = [pick(event, event_keys) for event in report.slowest]
    return json.dumps(document, indent=2)
| 338 | |
| 339 | |
# Names exported by a star-import of this module; kept in sorted order
# (classes first, since uppercase sorts before lowercase).
__all__ = [
    "ProbeSummary",
    "TraceEvent",
    "TraceReport",
    "ViewSummary",
    "build_report",
    "load",
    "per_probe_summary",
    "per_view_summary",
    "render_json",
    "render_markdown",
    "render_terminal",
    "slowest_events",
]