"""Shared forward-pass cache + tracing + stats (Sprint 07).

The audit observed that probes re-compute base distributions that
earlier probes already produced — ``delta_kl``, ``adapter_ablation``,
``prompt_collapse``, and the null-adapter calibration matrix all hit
overlapping prompts on the base view, and nothing was sharing the
result. This module is the shared plumbing:

- :class:`ForwardCache` — a bounded LRU keyed on
  ``(op, view_id, prompt_hash, top_k)`` that backend views consult
  before computing a forward pass.
- :class:`TraceWriter` — appends one JSONL line per scoring call
  (cache hits and misses alike) when enabled; total no-op when
  disabled.
- :class:`BackendInstrumentation` — what a backend instance owns (one
  cache + one tracer + one stats counter). Views call :meth:`cached`
  to route their scoring methods through all three.

Every scoring method on every backend view is the same few lines of
glue::

    def next_token_dist(self, prompt, *, top_k=256):
        return self._inst.cached(
            "next_token_dist", self.id, prompt, top_k,
            lambda: self._compute_next_token_dist(prompt, top_k=top_k),
        )

Thread-safety scope is intentionally narrow — the HF backend's
``_active`` toggle is still single-threaded (B19 design note); the
cache itself is protected by a lock so concurrent lookups and
insertions don't corrupt the LRU ordering even in the experimental
concurrency mode.
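
End-to-end wiring, sketched; ``MyBackend`` and its ``base_view`` are
placeholders for whatever backend object owns the
:class:`BackendInstrumentation` and routes its scoring calls through
it::

    inst = BackendInstrumentation(
        cache_maxsize=512,
        trace_path=Path("runs/forward_trace.jsonl"),
    )
    backend = MyBackend(instrumentation=inst)
    dist = backend.base_view.next_token_dist("The capital of France is")
    print(inst.stats.to_dict())  # hit_rate, forward_passes, ...
    inst.close()  # flush and close the JSONL trace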
"""

from __future__ import annotations

import json
import threading
import time
from collections import OrderedDict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar

T = TypeVar("T")

#: The operations the cache keys on. Kept tight — each op has a
#: distinct return shape (TokenDist vs RollingLogprob vs float) and
#: the backend view is the one that decides which to call.
CacheOp = str  # one of "next_token_dist" | "rolling_logprob" | "logprob_of"


@dataclass(slots=True)
class BackendStats:
    """Counters a backend publishes to the suite result for a run.

    Fields are cumulative across every view transition; the runner
    snapshots them after the suite finishes and attaches the snapshot
    to :attr:`~dlm_sway.core.result.SuiteResult.backend_stats`.
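
    A worked example of the derived ratio (numbers are illustrative)::

        >>> BackendStats(cache_hits=30, cache_misses=10).hit_rate
        0.75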
    """

    cache_hits: int = 0
    cache_misses: int = 0
    forward_passes: int = 0
    #: Total wall seconds spent *inside* backend scoring methods
    #: (forward-pass compute + cache hits; excludes probe-side work).
    scoring_wall_s: float = 0.0
    #: S23 batched-backend-exec counters. ``batches_sent`` is the number
    #: of batched forward calls the backend actually issued (i.e. misses
    #: routed through the batched path; single-prompt calls and cache
    #: hits don't increment). ``batched_prompts`` is the sum of batch
    #: sizes, so ``batched_prompts / batches_sent`` = mean batch size.
    #: A run with zero batches (older probes / opt-out probes only)
    #: reports the avg as 0 via :attr:`avg_batch_size`.
    batches_sent: int = 0
    batched_prompts: int = 0
    max_batch_size: int = 0

    @property
    def total_lookups(self) -> int:
        return self.cache_hits + self.cache_misses

    @property
    def hit_rate(self) -> float:
        total = self.total_lookups
        return (self.cache_hits / total) if total > 0 else 0.0

    @property
    def avg_batch_size(self) -> float:
        return (self.batched_prompts / self.batches_sent) if self.batches_sent > 0 else 0.0

    def to_dict(self) -> dict[str, float | int]:
        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "forward_passes": self.forward_passes,
            "scoring_wall_s": self.scoring_wall_s,
            "hit_rate": self.hit_rate,
            "batches_sent": self.batches_sent,
            "batched_prompts": self.batched_prompts,
            "max_batch_size": self.max_batch_size,
            "avg_batch_size": self.avg_batch_size,
        }


class ForwardCache:
    """Bounded LRU for backend forward passes, keyed on view + prompt.

    We don't use :func:`functools.lru_cache` because the cacheable
    unit isn't a function — it's a ``(op, view_id, prompt_hash, top_k)``
    tuple that spans multiple scoring methods. An explicit
    :class:`~collections.OrderedDict` keeps insertion order and lets
    us bump entries to the end on hit.

    The key includes ``view_id`` (e.g. ``"base"``, ``"ft"``,
    ``"scaled_1.25"``, ``"null_42"``) so a view transition is a
    natural cache miss — no explicit invalidation needed. If a future
    fix to the HF toggle corrupts base/ft weights in place, a new
    view-id string would surface the drift immediately.
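
    Standalone usage sketch (``expensive_forward`` is a placeholder
    for whatever the backend actually computes; the key tuple mirrors
    what :meth:`BackendInstrumentation.cached` builds)::

        cache = ForwardCache(maxsize=128)
        key = ("next_token_dist", "base", "abc123def456", 256)
        value = cache.get(key)
        if value is _MISS:
            value = expensive_forward()
            cache.put(key, value)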
    """

    def __init__(self, maxsize: int = 256) -> None:
        if maxsize <= 0:
            raise ValueError(f"maxsize must be positive; got {maxsize}")
        self._maxsize = maxsize
        self._store: OrderedDict[tuple[Any, ...], Any] = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: tuple[Any, ...]) -> Any:
        with self._lock:
            try:
                value = self._store[key]
            except KeyError:
                return _MISS
            self._store.move_to_end(key)
            return value

    def put(self, key: tuple[Any, ...], value: Any) -> None:
        with self._lock:
            self._store[key] = value
            self._store.move_to_end(key)
            while len(self._store) > self._maxsize:
                self._store.popitem(last=False)

    def __len__(self) -> int:
        return len(self._store)

    def clear(self) -> None:
        with self._lock:
            self._store.clear()


#: Sentinel used to tell ``ForwardCache.get`` misses apart from legit
#: ``None`` values. Downstream callers check ``result is _MISS``.
_MISS: Any = object()


@dataclass(slots=True)
class _TraceEvent:
    ts: float
    probe: str | None
    view_id: str
    prompt_hash: str
    top_k: int
    op: str
    wall_ms: float
    hit: bool


class TraceWriter:
    """Append-only JSONL writer for forward-pass traces.

    Disabled mode (``path=None``) is a zero-overhead no-op: the
    :meth:`write` method short-circuits on the first branch and the
    backend pays one ``if`` per forward pass. Enabled mode opens the
    file once and appends per event; thread-safe via a simple lock.
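
    One event per line. An illustrative record (values made up;
    wrapped here for readability, the writer emits each record on a
    single line)::

        {"ts": 1700000000.0, "probe": "delta_kl", "view_id": "base",
         "prompt_hash": "abc123def456", "top_k": 256,
         "op": "next_token_dist", "wall_ms": 41.3, "hit": false}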
    """

    def __init__(self, path: Path | None) -> None:
        self._path = path
        self._fh: Any = None
        self._lock = threading.Lock()
        if path is not None:
            path.parent.mkdir(parents=True, exist_ok=True)
            self._fh = path.open("a", encoding="utf-8")

    def write(self, event: _TraceEvent) -> None:
        if self._fh is None:
            return
        payload = {
            "ts": event.ts,
            "probe": event.probe,
            "view_id": event.view_id,
            "prompt_hash": event.prompt_hash,
            "top_k": event.top_k,
            "op": event.op,
            "wall_ms": event.wall_ms,
            "hit": event.hit,
        }
        with self._lock:
            self._fh.write(json.dumps(payload) + "\n")
            self._fh.flush()

    def close(self) -> None:
        if self._fh is not None:
            self._fh.close()
            self._fh = None


class BackendInstrumentation:
    """What a backend instance owns: one cache, one tracer, one stats counter.

    Backend ``__init__`` constructs this; every view holds a ref.
    The runner peeks at ``.stats`` at suite-end and ships the snapshot
    in :attr:`SuiteResult.backend_stats`.

    Context about ``current_probe``: a backend view has no idea which
    probe is calling it, but the runner does. The runner calls
    :meth:`set_current_probe` before each probe runs so trace events
    carry the right label; views don't care.
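
    Runner-side sketch (``inst`` is this object; ``probes``,
    ``probe.run`` and ``backend`` are hypothetical runner-side names,
    only ``set_current_probe`` and ``stats`` belong to this class)::

        for probe in probes:
            inst.set_current_probe(probe.name)
            probe.run(backend)  # views call inst.cached(...) inside
        inst.set_current_probe(None)
        snapshot = inst.stats.to_dict()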
    """

    def __init__(
        self,
        *,
        cache_maxsize: int = 256,
        trace_path: Path | None = None,
    ) -> None:
        self.cache = ForwardCache(maxsize=cache_maxsize)
        self.trace = TraceWriter(trace_path)
        self.stats = BackendStats()
        self._current_probe: str | None = None

    def set_current_probe(self, name: str | None) -> None:
        """Called by the runner between probes. Label for trace events."""
        self._current_probe = name

    def close(self) -> None:
        self.trace.close()

    def cached(
        self,
        op: CacheOp,
        view_id: str,
        prompt: str,
        top_k: int,
        compute: Callable[[], T],
    ) -> T:
        """Route a backend scoring call through the cache + tracer.

        On cache hit, returns the cached value and increments
        ``cache_hits``. On miss, runs ``compute()``, stores the
        result, increments ``cache_misses`` + ``forward_passes``,
        and (when tracing is enabled) appends a JSONL event.
        """
        prompt_hash = _prompt_hash(prompt)
        key = (op, view_id, prompt_hash, top_k)
        hit_start = time.perf_counter()
        cached_value = self.cache.get(key)
        if cached_value is not _MISS:
            self.stats.cache_hits += 1
            wall = time.perf_counter() - hit_start
            self.stats.scoring_wall_s += wall
            self.trace.write(
                _TraceEvent(
                    ts=time.time(),
                    probe=self._current_probe,
                    view_id=view_id,
                    prompt_hash=prompt_hash,
                    top_k=top_k,
                    op=op,
                    wall_ms=wall * 1000.0,
                    hit=True,
                )
            )
            return cached_value  # type: ignore[no-any-return]

        # Miss: compute, store, record.
        compute_start = time.perf_counter()
        value = compute()
        wall = time.perf_counter() - compute_start
        self.cache.put(key, value)
        self.stats.cache_misses += 1
        self.stats.forward_passes += 1
        self.stats.scoring_wall_s += wall
        self.trace.write(
            _TraceEvent(
                ts=time.time(),
                probe=self._current_probe,
                view_id=view_id,
                prompt_hash=prompt_hash,
                top_k=top_k,
                op=op,
                wall_ms=wall * 1000.0,
                hit=False,
            )
        )
        return value

    def cached_batch(
        self,
        op: CacheOp,
        view_id: str,
        prompts: Sequence[str],
        top_k: int,
        compute_misses: Callable[[list[int]], list[Any]],
    ) -> list[Any]:
        """Route a batched scoring call through the cache + tracer.

        Per-prompt cache lookup happens first; entries already in the
        cache are served from it without ever entering the batch. The
        ``compute_misses`` callback receives the list of indices into
        ``prompts`` that missed, and is expected to return a list of
        results *in the same order* — the backend is free to pad, call
        ``model.forward`` once, and split the logits per row.

        Counters incremented:
        - ``cache_hits`` per cached prompt
        - ``cache_misses`` + ``forward_passes`` per missed prompt
        - ``batches_sent`` once per actual forward (only when
          ``compute_misses`` is called, i.e. at least one miss)
        - ``batched_prompts`` by the miss count
        - ``max_batch_size`` updated to ``max(prev, miss_count)``

        Trace events are emitted per prompt so the JSONL trace keeps
        its per-prompt granularity regardless of how many rows the
        backend packed into one GPU call.
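
        View-side sketch (``self._score_batch`` is a hypothetical
        backend helper that pads the missed prompts and runs a single
        forward pass; it must return results in miss order)::

            def next_token_dists(self, prompts, *, top_k=256):
                return self._inst.cached_batch(
                    "next_token_dist", self.id, prompts, top_k,
                    lambda idxs: self._score_batch(
                        [prompts[i] for i in idxs], top_k=top_k
                    ),
                )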
        """
        results: list[Any] = [None] * len(prompts)
        miss_indices: list[int] = []
        prompt_hashes: list[str] = [_prompt_hash(p) for p in prompts]

        # Pass 1: cache lookups.
        for i, prompt_hash in enumerate(prompt_hashes):
            key = (op, view_id, prompt_hash, top_k)
            hit_start = time.perf_counter()
            cached_value = self.cache.get(key)
            if cached_value is not _MISS:
                self.stats.cache_hits += 1
                wall = time.perf_counter() - hit_start
                self.stats.scoring_wall_s += wall
                self.trace.write(
                    _TraceEvent(
                        ts=time.time(),
                        probe=self._current_probe,
                        view_id=view_id,
                        prompt_hash=prompt_hash,
                        top_k=top_k,
                        op=op,
                        wall_ms=wall * 1000.0,
                        hit=True,
                    )
                )
                results[i] = cached_value
            else:
                miss_indices.append(i)

        # Pass 2: one forward call for the miss subset.
        if miss_indices:
            compute_start = time.perf_counter()
            miss_values = compute_misses(miss_indices)
            if len(miss_values) != len(miss_indices):
                raise RuntimeError(
                    f"batched compute returned {len(miss_values)} values for "
                    f"{len(miss_indices)} misses — backend bug"
                )
            wall = time.perf_counter() - compute_start
            # Split the batch wall time evenly across misses for the
            # per-prompt trace events below; scoring_wall_s gets the
            # whole batch wall once, since one forward call has no
            # true per-prompt attribution.
            per_miss_wall = wall / len(miss_indices)
            self.stats.scoring_wall_s += wall
            self.stats.batches_sent += 1
            self.stats.batched_prompts += len(miss_indices)
            if len(miss_indices) > self.stats.max_batch_size:
                self.stats.max_batch_size = len(miss_indices)
            for miss_pos, idx in enumerate(miss_indices):
                value = miss_values[miss_pos]
                key = (op, view_id, prompt_hashes[idx], top_k)
                self.cache.put(key, value)
                self.stats.cache_misses += 1
                self.stats.forward_passes += 1
                self.trace.write(
                    _TraceEvent(
                        ts=time.time(),
                        probe=self._current_probe,
                        view_id=view_id,
                        prompt_hash=prompt_hashes[idx],
                        top_k=top_k,
                        op=op,
                        wall_ms=per_miss_wall * 1000.0,
                        hit=False,
                    )
                )
                results[idx] = value

        return results


def _prompt_hash(prompt: str) -> str:
    """Stable short hash for cache keys + trace logging.

    SHA-1 truncated to 12 hex chars: enough to avoid collisions at
    suite scale (thousands of distinct prompts per run), small enough
    to keep cache memory down and JSONL lines readable.
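
    For example, ``_prompt_hash("hello")`` yields ``"aaf4c61ddcc5"``,
    the first 12 hex digits of the SHA-1 of the UTF-8 bytes.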
    """
    import hashlib

    return hashlib.sha1(prompt.encode("utf-8"), usedforsecurity=False).hexdigest()[:12]


__all__ = [
    "BackendInstrumentation",
    "BackendStats",
    "ForwardCache",
    "TraceWriter",
]