1 """Shared forward-pass cache + tracing + stats (Sprint 07).
2
3 The audit observed that probes re-compute base distributions earlier
4 probes already produced — ``delta_kl``, ``adapter_ablation``,
5 ``prompt_collapse``, and the null-adapter calibration matrix all hit
6 overlapping prompts on the base view, and nothing was sharing the
7 result. This module is the shared plumbing:
8
9 - :class:`ForwardCache` — a bounded LRU keyed on
10 ``(op, view_id, prompt_hash, top_k)`` that backend views consult
11 before computing a forward pass.
12 - :class:`TraceWriter` — appends one JSONL line per forward pass when
13 enabled; total no-op when disabled.
14 - :class:`BackendInstrumentation` — what a backend instance owns (one
15 cache + one tracer). Views call :meth:`cached` to route their
16 scoring methods through both.
17
18 Every scoring method on every backend view is exactly three lines of
19 glue:
20
21 def next_token_dist(self, prompt, *, top_k=256):
22 return self._inst.cached(
23 "next_token_dist", self.id, prompt, top_k,
24 lambda: self._compute_next_token_dist(prompt, top_k=top_k),
25 )
26
27 Thread-safety scope is intentionally narrow — the HF backend's
28 ``_active`` toggle is still single-threaded (B19 design note); the
29 cache itself is protected by a lock so concurrent view entries
30 don't corrupt the LRU ordering even in the experimental concurrency
31 mode.
32 """

from __future__ import annotations

import hashlib
import json
import threading
import time
from collections import OrderedDict
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar

T = TypeVar("T")

#: The operations the cache keys on. Kept tight — each op has a
#: distinct return shape (TokenDist vs RollingLogprob vs float) and
#: the backend view is the one that decides which to call.
CacheOp = str  # one of "next_token_dist" | "rolling_logprob" | "logprob_of"


@dataclass(slots=True)
class BackendStats:
    """Counters a backend publishes to the suite result for a run.

    Fields are cumulative across every view transition; the runner
    snapshots them after the suite finishes and attaches the snapshot
    to :attr:`~dlm_sway.core.result.SuiteResult.backend_stats`.
    """

    cache_hits: int = 0
    cache_misses: int = 0
    forward_passes: int = 0
    #: Total wall seconds spent *inside* backend scoring methods
    #: (forward-pass compute + cache hits; excludes probe-side work).
    scoring_wall_s: float = 0.0
    #: S23 batched-backend-exec counters. ``batches_sent`` is the number
    #: of batched forward calls the backend actually issued (i.e. misses
    #: routed through the batched path; single-prompt calls and cache
    #: hits don't increment). ``batched_prompts`` is the sum of batch
    #: sizes, so ``batched_prompts / batches_sent`` = mean batch size.
    #: A run with zero batches (older probes / opt-out probes only)
    #: reports the avg as 0 via :attr:`avg_batch_size`.
    batches_sent: int = 0
    batched_prompts: int = 0
    max_batch_size: int = 0

    @property
    def total_lookups(self) -> int:
        return self.cache_hits + self.cache_misses

    @property
    def hit_rate(self) -> float:
        total = self.total_lookups
        return (self.cache_hits / total) if total > 0 else 0.0

    @property
    def avg_batch_size(self) -> float:
        return (self.batched_prompts / self.batches_sent) if self.batches_sent > 0 else 0.0

    def to_dict(self) -> dict[str, float | int]:
        return {
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "forward_passes": self.forward_passes,
            "scoring_wall_s": self.scoring_wall_s,
            "hit_rate": self.hit_rate,
            "batches_sent": self.batches_sent,
            "batched_prompts": self.batched_prompts,
            "max_batch_size": self.max_batch_size,
            "avg_batch_size": self.avg_batch_size,
        }
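
# Worked example (illustrative numbers): 12 hits and 4 misses give a
# 0.75 hit rate; one batch carrying all 4 misses gives a mean batch
# size of 4.0.
#
#     stats = BackendStats(cache_hits=12, cache_misses=4,
#                          forward_passes=4, batches_sent=1,
#                          batched_prompts=4, max_batch_size=4)
#     assert stats.hit_rate == 0.75 and stats.avg_batch_size == 4.0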


class ForwardCache:
    """Bounded LRU for backend forward passes, keyed on view + prompt.

    We don't use :func:`functools.lru_cache` because the cacheable
    unit isn't a function — it's an ``(op, view_id, prompt_hash,
    top_k)`` tuple that spans multiple scoring methods. An explicit
    :class:`~collections.OrderedDict` keeps insertion order and lets
    us bump entries to the end on hit.

    The key includes ``view_id`` (e.g. ``"base"``, ``"ft"``,
    ``"scaled_1.25"``, ``"null_42"``) so a view transition is a
    natural cache miss — no explicit invalidation needed. If a future
    fix to the HF toggle corrupts base/ft weights in place, a new
    view-id string would surface the drift immediately.
    """

    def __init__(self, maxsize: int = 256) -> None:
        if maxsize <= 0:
            raise ValueError(f"maxsize must be positive; got {maxsize}")
        self._maxsize = maxsize
        self._store: OrderedDict[tuple[Any, ...], Any] = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: tuple[Any, ...]) -> Any:
        """Return the cached value, or the ``_MISS`` sentinel on a miss."""
        with self._lock:
            try:
                value = self._store[key]
            except KeyError:
                return _MISS
            self._store.move_to_end(key)
            return value

    def put(self, key: tuple[Any, ...], value: Any) -> None:
        with self._lock:
            self._store[key] = value
            self._store.move_to_end(key)
            while len(self._store) > self._maxsize:
                self._store.popitem(last=False)

    def __len__(self) -> int:
        return len(self._store)

    def clear(self) -> None:
        with self._lock:
            self._store.clear()
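
# Example (illustrative): eviction follows true LRU order because hits
# bump entries to the MRU end before capacity is enforced.
#
#     cache = ForwardCache(maxsize=2)
#     cache.put(("next_token_dist", "base", "aaa", 256), 1.0)
#     cache.put(("next_token_dist", "base", "bbb", 256), 2.0)
#     cache.get(("next_token_dist", "base", "aaa", 256))       # hit: "aaa" becomes MRU
#     cache.put(("next_token_dist", "base", "ccc", 256), 3.0)  # evicts "bbb"
#     assert cache.get(("next_token_dist", "base", "bbb", 256)) is _MISS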


#: Sentinel used to tell ``ForwardCache.get`` misses apart from legit
#: ``None`` values. Downstream callers check ``result is _MISS``.
_MISS: Any = object()


@dataclass(slots=True)
class _TraceEvent:
    ts: float
    probe: str | None
    view_id: str
    prompt_hash: str
    top_k: int
    op: str
    wall_ms: float
    hit: bool


class TraceWriter:
    """Append-only JSONL writer for forward-pass traces.

    Disabled mode (``path=None``) is close to free: :meth:`write`
    short-circuits on its first branch, so the backend pays one event
    construction and one ``if`` per forward pass. Enabled mode opens
    the file once and appends per event; thread-safe via a simple lock.
    """

    def __init__(self, path: Path | None) -> None:
        self._path = path
        self._fh: Any = None
        self._lock = threading.Lock()
        if path is not None:
            path.parent.mkdir(parents=True, exist_ok=True)
            self._fh = path.open("a", encoding="utf-8")

    def write(self, event: _TraceEvent) -> None:
        if self._fh is None:
            return
        payload = {
            "ts": event.ts,
            "probe": event.probe,
            "view_id": event.view_id,
            "prompt_hash": event.prompt_hash,
            "top_k": event.top_k,
            "op": event.op,
            "wall_ms": event.wall_ms,
            "hit": event.hit,
        }
        with self._lock:
            self._fh.write(json.dumps(payload) + "\n")
            self._fh.flush()

    def close(self) -> None:
        if self._fh is not None:
            self._fh.close()
            self._fh = None
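
# Example (illustrative): trace lines are self-contained JSON objects,
# one per forward pass, so post-hoc analysis is a couple of lines. The
# file name below is a placeholder.
#
#     lines = Path("runs/trace.jsonl").read_text().splitlines()
#     events = [json.loads(ln) for ln in lines]
#     hit_rate = sum(e["hit"] for e in events) / len(events)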


class BackendInstrumentation:
    """What a backend instance owns: one cache, one tracer, one stats record.

    Backend ``__init__`` constructs this; every view holds a ref.
    The runner peeks at ``.stats`` at suite-end and ships the snapshot
    in :attr:`SuiteResult.backend_stats`.

    Context about ``current_probe``: a backend view has no idea which
    probe is calling it, but the runner does. The runner sets
    :meth:`set_current_probe` before each probe runs so trace events
    carry the right label; views don't care.
    """

    def __init__(
        self,
        *,
        cache_maxsize: int = 256,
        trace_path: Path | None = None,
    ) -> None:
        self.cache = ForwardCache(maxsize=cache_maxsize)
        self.trace = TraceWriter(trace_path)
        self.stats = BackendStats()
        self._current_probe: str | None = None

    def set_current_probe(self, name: str | None) -> None:
        """Called by the runner between probes. Label for trace events."""
        self._current_probe = name

    def close(self) -> None:
        self.trace.close()

    def cached(
        self,
        op: CacheOp,
        view_id: str,
        prompt: str,
        top_k: int,
        compute: Callable[[], T],
    ) -> T:
        """Route a backend scoring call through the cache + tracer.

        On cache hit, returns the cached value and increments
        ``cache_hits``. On miss, runs ``compute()``, stores the
        result, increments ``cache_misses`` + ``forward_passes``,
        and (when tracing is enabled) appends a JSONL event.
        """
        prompt_hash = _prompt_hash(prompt)
        key = (op, view_id, prompt_hash, top_k)
        hit_start = time.perf_counter()
        cached_value = self.cache.get(key)
        if cached_value is not _MISS:
            self.stats.cache_hits += 1
            wall = time.perf_counter() - hit_start
            self.stats.scoring_wall_s += wall
            self.trace.write(
                _TraceEvent(
                    ts=time.time(),
                    probe=self._current_probe,
                    view_id=view_id,
                    prompt_hash=prompt_hash,
                    top_k=top_k,
                    op=op,
                    wall_ms=wall * 1000.0,
                    hit=True,
                )
            )
            return cached_value  # type: ignore[no-any-return]

        # Miss: compute, store, record.
        compute_start = time.perf_counter()
        value = compute()
        wall = time.perf_counter() - compute_start
        self.cache.put(key, value)
        self.stats.cache_misses += 1
        self.stats.forward_passes += 1
        self.stats.scoring_wall_s += wall
        self.trace.write(
            _TraceEvent(
                ts=time.time(),
                probe=self._current_probe,
                view_id=view_id,
                prompt_hash=prompt_hash,
                top_k=top_k,
                op=op,
                wall_ms=wall * 1000.0,
                hit=False,
            )
        )
        return value
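
    # Example (illustrative): with the glue from the module docstring in
    # place, two identical calls cost one forward pass. ``view`` and
    # ``inst`` are hypothetical stand-ins for a backend view and its
    # instrumentation.
    #
    #     d1 = view.next_token_dist("Hello", top_k=256)  # miss: computes
    #     d2 = view.next_token_dist("Hello", top_k=256)  # hit: cached
    #     assert inst.stats.cache_hits == 1
    #     assert inst.stats.forward_passes == 1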

    def cached_batch(
        self,
        op: CacheOp,
        view_id: str,
        prompts: Sequence[str],
        top_k: int,
        compute_misses: Callable[[list[int]], list[Any]],
    ) -> list[Any]:
        """Route a batched scoring call through the cache + tracer.

        Per-prompt cache lookup happens first; entries already in the
        cache are served from it without ever entering the batch. The
        ``compute_misses`` callback receives the list of indices into
        ``prompts`` that missed, and is expected to return a list of
        results *in the same order* — the backend is free to pad, call
        ``model.forward`` once, and split the logits per row.

        Counters incremented:

        - ``cache_hits`` per cached prompt
        - ``cache_misses`` + ``forward_passes`` per missed prompt
        - ``batches_sent`` once per actual forward (only when
          ``compute_misses`` is called, i.e. at least one miss)
        - ``batched_prompts`` by the miss count
        - ``max_batch_size`` updated to ``max(prev, miss_count)``

        Trace events are emitted per prompt so the JSONL trace keeps
        its per-prompt granularity regardless of how many rows the
        backend packed into one GPU call.
        """
        results: list[Any] = [None] * len(prompts)
        miss_indices: list[int] = []
        prompt_hashes: list[str] = [_prompt_hash(p) for p in prompts]

        # Pass 1: cache lookups.
        for i, prompt_hash in enumerate(prompt_hashes):
            key = (op, view_id, prompt_hash, top_k)
            hit_start = time.perf_counter()
            cached_value = self.cache.get(key)
            if cached_value is not _MISS:
                self.stats.cache_hits += 1
                wall = time.perf_counter() - hit_start
                self.stats.scoring_wall_s += wall
                self.trace.write(
                    _TraceEvent(
                        ts=time.time(),
                        probe=self._current_probe,
                        view_id=view_id,
                        prompt_hash=prompt_hash,
                        top_k=top_k,
                        op=op,
                        wall_ms=wall * 1000.0,
                        hit=True,
                    )
                )
                results[i] = cached_value
            else:
                miss_indices.append(i)

        # Pass 2: one forward call for the miss subset.
        if miss_indices:
            compute_start = time.perf_counter()
            miss_values = compute_misses(miss_indices)
            if len(miss_values) != len(miss_indices):
                raise RuntimeError(
                    f"batched compute returned {len(miss_values)} values for "
                    f"{len(miss_indices)} misses — backend bug"
                )
            wall = time.perf_counter() - compute_start
            # ``scoring_wall_s`` gets the full batch wall once; the
            # per-prompt trace events split it evenly across misses,
            # since a batched forward has no per-prompt attribution.
            per_miss_wall = wall / len(miss_indices)
            self.stats.scoring_wall_s += wall
            self.stats.batches_sent += 1
            self.stats.batched_prompts += len(miss_indices)
            if len(miss_indices) > self.stats.max_batch_size:
                self.stats.max_batch_size = len(miss_indices)
            for miss_pos, idx in enumerate(miss_indices):
                value = miss_values[miss_pos]
                key = (op, view_id, prompt_hashes[idx], top_k)
                self.cache.put(key, value)
                self.stats.cache_misses += 1
                self.stats.forward_passes += 1
                self.trace.write(
                    _TraceEvent(
                        ts=time.time(),
                        probe=self._current_probe,
                        view_id=view_id,
                        prompt_hash=prompt_hashes[idx],
                        top_k=top_k,
                        op=op,
                        wall_ms=per_miss_wall * 1000.0,
                        hit=False,
                    )
                )
                results[idx] = value

        return results
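
# Hedged sketch of end-to-end wiring. ``_compute_next_token_dist_batch``
# is a hypothetical backend helper (one padded forward over the missed
# prompts, per-row results returned in order); the trace path is a
# placeholder.
#
#     inst = BackendInstrumentation(
#         cache_maxsize=512, trace_path=Path("runs/trace.jsonl"))
#
#     # Inside a view, the batched scoring method is pure glue:
#     def next_token_dist_batch(self, prompts, *, top_k=256):
#         return self._inst.cached_batch(
#             "next_token_dist", self.id, prompts, top_k,
#             lambda idxs: self._compute_next_token_dist_batch(
#                 [prompts[i] for i in idxs], top_k=top_k),
#         )
#
#     # The runner labels trace events per probe and closes at suite end:
#     inst.set_current_probe("delta_kl")
#     ...  # run the probe
#     inst.close()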


def _prompt_hash(prompt: str) -> str:
    """Stable short hash for cache keys + trace logging.

    SHA-1 truncated to 12 hex chars: enough to avoid collisions at
    suite scale (thousands of distinct prompts per run), small enough
    to keep cache memory down and JSONL lines readable.
    """
    return hashlib.sha1(prompt.encode("utf-8"), usedforsecurity=False).hexdigest()[:12]
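
# Collision back-of-envelope (informal): 12 hex chars is 48 bits, so the
# birthday bound gives p ≈ n**2 / 2**49. At n = 10,000 distinct prompts
# that is roughly 1e8 / 5.6e14 ≈ 2e-7, comfortably negligible at suite
# scale.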


__all__ = [
    "BackendInstrumentation",
    "BackendStats",
    "ForwardCache",
    "TraceWriter",
]