# (Removed GitHub page-header artifact — "Python · 13653 bytes Raw Blame History" is UI chrome, not source.)
1 """Per-store tokenized-section cache.
2
3 Avoid re-tokenizing unchanged directive-sourced files on every
4 `dlm train` run. At 50K+ files this is the difference between an
5 hour of retokenization and seconds of cache warm-up.
6
7 Integration status: the trainer consults this cache during the
8 pre-tokenization pass between `build_dataset` and SFTTrainer. Rows
9 are tokenized to `input_ids` / `attention_mask`, cached by
10 `(section_id, tokenizer_sha256, sequence_len)`, and then handed to
11 SFTTrainer in pre-processed form so repeated runs can skip most
12 tokenizer work without changing training behavior.
13
14 Layout (per store):
15
16 ~/.dlm/store/<dlm_id>/tokenized-cache/
17 manifest.json { version, tokenizer_sha256, total_bytes,
18 entries: {key_str: {size, last_access_ts,
19 shard, filename}} }
20 entries/
21 <section_id[:2]>/ sharded to avoid 50K files in one dir
22 <...>.npz numpy save of input_ids + attention_mask
23
24 Entries are keyed by `(section_id, tokenizer_sha, sequence_len)`. A
25 change to any input produces a new key — stale entries are garbage-
26 collected by `prune` when their last-access age exceeds the cutoff.
27
28 Atomicity: `put` writes to a tmp file + `os.replace`; the manifest
29 update is the last step, so a mid-put SIGTERM leaves no torn entries
(the tmp file may orphan harmlessly until a later put overwrites it).
31
32 LRU eviction fires on `put` when `total_bytes + incoming_size >
33 max_bytes`. Oldest `last_access_ts` wins; **current-run entries are
34 protected** so a cold cache doesn't self-starve.
35 """
36
from __future__ import annotations

import json
import logging
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np

from dlm.directives.cache_key import CacheKey
49
50 _LOG = logging.getLogger(__name__)
51
52 _CACHE_VERSION = 1
53 _MANIFEST_FILENAME = "manifest.json"
54 _ENTRIES_DIR = "entries"
55 _DEFAULT_MAX_BYTES = 10 * 1024 * 1024 * 1024 # 10 GiB
56
57
@dataclass(frozen=True)
class CachedTokens:
    """Tokenizer output pulled from the cache.

    `input_ids` and `attention_mask` are 1D numpy int arrays matching
    what `tokenizer(text, truncation=True, padding=False, max_length=seq_len)`
    would return. Callers convert to torch tensors at the dataloader
    boundary.
    """

    # 1D int array of token ids (see class docstring for provenance).
    input_ids: np.ndarray
    # 1D int array paired with `input_ids`; presumably the usual HF
    # convention (1 = attend) — the cache stores it opaquely either way.
    attention_mask: np.ndarray
70
71
@dataclass
class _Entry:
    """Manifest row for one cached tokenization.

    `key_str` is the canonical string form of the CacheKey. `size`
    is bytes on disk (best-effort stat). `last_access_ts` is Unix
    seconds as a float.
    """

    # Canonical "<section_id>|<tokenizer_sha>|<sequence_len>" string (see _key_str).
    key_str: str
    # Entry file size in bytes; 0 when the stat at put-time failed.
    size: int
    # Unix seconds of the last get/put; drives LRU eviction and prune.
    last_access_ts: float
    # Shard subdirectory name under entries/ (section_id[:2] per the module docstring).
    shard: str
    # Entry filename within the shard directory (an .npz archive).
    filename: str
    # Tokenizer sha recorded per entry (also embedded in `key_str`).
    tokenizer_sha: str
87
88
89 def _key_str(key: CacheKey) -> str:
90 """Canonical str form: used as manifest dict key."""
91 return f"{key.section_id}|{key.tokenizer_sha}|{key.sequence_len}"
92
93
class TokenizedCache:
    """Per-store tokenized-section cache.

    Open via `TokenizedCache.open(store.tokenized_cache_dir)`. The
    constructor eagerly loads the manifest (cheap: one JSON file with
    N entries). `get` and `put` touch disk for the actual tensors.
    """

    def __init__(
        self,
        root: Path,
        *,
        manifest: dict[str, _Entry],
        tokenizer_sha_hint: str | None = None,
        max_bytes: int = _DEFAULT_MAX_BYTES,
    ) -> None:
        """Wire an already-loaded manifest to `root`; prefer `TokenizedCache.open`.

        `tokenizer_sha_hint` is the top-level sha from the manifest (if
        any); it is only echoed back by `save_manifest` when the caller
        does not supply one.
        """
        self._root = root
        self._manifest = manifest
        self._tokenizer_sha_hint = tokenizer_sha_hint
        self._max_bytes = max_bytes
        # Entries inserted or touched during this session — protected
        # from LRU eviction to avoid cold-cache self-starvation.
        self._touched_this_run: set[str] = set()
        # Counters for end-of-run metrics.
        self._hits = 0
        self._misses = 0

    # ---- Open / construct --------------------------------------------

    @classmethod
    def open(cls, root: Path, *, max_bytes: int = _DEFAULT_MAX_BYTES) -> TokenizedCache:
        """Open (or create) a cache at `root`.

        Creates the directory layout idempotently. Missing manifest →
        fresh empty cache. Corrupt manifest → log a WARN and start
        fresh, leaving any orphaned entry files to `prune` later.
        """
        root.mkdir(parents=True, exist_ok=True)
        (root / _ENTRIES_DIR).mkdir(exist_ok=True)
        manifest_path = root / _MANIFEST_FILENAME

        if not manifest_path.is_file():
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        try:
            raw = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError) as exc:
            _LOG.warning(
                "cache: manifest at %s unreadable (%s); starting fresh",
                manifest_path,
                exc,
            )
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        if not isinstance(raw, dict) or raw.get("version") != _CACHE_VERSION:
            _LOG.warning(
                "cache: manifest version mismatch at %s; starting fresh",
                manifest_path,
            )
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        entries_raw = raw.get("entries", {})
        if not isinstance(entries_raw, dict):
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        # Rebuild rows defensively: a malformed row is skipped (its
        # section just re-tokenizes) rather than failing the open.
        manifest: dict[str, _Entry] = {}
        for key_str, row in entries_raw.items():
            if not isinstance(row, dict):
                continue
            try:
                manifest[key_str] = _Entry(
                    key_str=key_str,
                    size=int(row["size"]),
                    last_access_ts=float(row["last_access_ts"]),
                    shard=str(row["shard"]),
                    filename=str(row["filename"]),
                    tokenizer_sha=str(row.get("tokenizer_sha", "")),
                )
            except (KeyError, TypeError, ValueError) as exc:
                _LOG.warning("cache: skipping malformed entry %s: %s", key_str, exc)

        return cls(
            root=root,
            manifest=manifest,
            tokenizer_sha_hint=raw.get("tokenizer_sha256"),
            max_bytes=max_bytes,
        )

    # ---- Properties --------------------------------------------------

    @property
    def root(self) -> Path:
        """Cache root directory (holds the manifest and `entries/`)."""
        return self._root

    @property
    def total_bytes(self) -> int:
        """Sum of manifest entry sizes. O(n) — recomputed per call."""
        return sum(e.size for e in self._manifest.values())

    @property
    def entry_count(self) -> int:
        """Number of entries currently in the manifest."""
        return len(self._manifest)

    @property
    def hits(self) -> int:
        """Successful `get` calls this session."""
        return self._hits

    @property
    def misses(self) -> int:
        """Failed `get` calls this session (absent, missing file, or corrupt)."""
        return self._misses

    @property
    def hit_rate(self) -> float:
        """Hits / (hits + misses) for this session; 0.0 before any lookup."""
        total = self._hits + self._misses
        return self._hits / total if total else 0.0

    # ---- Get / Put ---------------------------------------------------

    def get(self, key: CacheKey) -> CachedTokens | None:
        """Return cached tokens for `key` or None on miss.

        A hit refreshes `last_access_ts` and protects the entry from
        eviction for the rest of the run. A missing or corrupt entry
        file is logged, its manifest row is dropped, and the call
        counts as a miss so the caller re-tokenizes.
        """
        key_str = _key_str(key)
        entry = self._manifest.get(key_str)
        if entry is None:
            self._misses += 1
            return None

        path = self._entry_path(entry)
        if not path.is_file():
            # Manifest drift (file deleted under us) — treat as miss,
            # remove the stale manifest row so we don't re-hit it.
            _LOG.warning("cache: entry file missing for %s; re-tokenizing", key_str)
            del self._manifest[key_str]
            self._misses += 1
            return None

        try:
            with np.load(path) as data:
                # Copy out of the lazily-mapped archive so the arrays
                # stay valid after the file handle closes.
                tokens = CachedTokens(
                    input_ids=np.array(data["input_ids"], copy=True),
                    attention_mask=np.array(data["attention_mask"], copy=True),
                )
        except (OSError, ValueError, KeyError, zipfile.BadZipFile) as exc:
            # BadZipFile added: np.load raises it for a truncated .npz
            # (it is not an OSError/ValueError), and it must take the
            # same drop-and-re-tokenize path as other corruption.
            _LOG.warning("cache: corrupt entry %s (%s); re-tokenizing", key_str, exc)
            del self._manifest[key_str]
            self._misses += 1
            return None

        entry.last_access_ts = time.time()
        self._touched_this_run.add(key_str)
        self._hits += 1
        return tokens

    def put(self, key: CacheKey, tokens: CachedTokens) -> None:
        """Write `tokens` to the cache under `key`. Evicts if needed.

        Write protocol: save to `<filename>.tmp` in the shard dir, then
        atomically rename onto the final name, then update the
        in-memory manifest. A crash mid-put can orphan the tmp file but
        never leaves a torn entry visible under its final name.
        """
        key_str = _key_str(key)
        shard = key.shard()
        filename = key.as_filename()
        shard_dir = self._root / _ENTRIES_DIR / shard
        shard_dir.mkdir(parents=True, exist_ok=True)
        final_path = shard_dir / filename
        # FIX: derive the tmp name from the entry filename. The previous
        # constant tmp name meant every put in a shard shared one tmp
        # file, so concurrent puts could clobber each other's
        # half-written archives before the rename.
        tmp_path = shard_dir / f"{filename}.tmp"

        # Save to tmp, then atomic rename. np.savez_compressed writes
        # directly to the file handle, so we can't stream-then-rename
        # with a single call — open the tmp file manually.
        try:
            with tmp_path.open("wb") as f:
                np.savez_compressed(
                    f,
                    input_ids=tokens.input_ids,
                    attention_mask=tokens.attention_mask,
                )
        except OSError as exc:
            # Best-effort cache: a failed write degrades to re-tokenizing
            # next run rather than failing the training job.
            _LOG.warning("cache: write failed for %s: %s; dropping entry", key_str, exc)
            tmp_path.unlink(missing_ok=True)
            return

        try:
            size = tmp_path.stat().st_size
        except OSError:
            size = 0

        # Evict BEFORE replacing so we have the budget headroom.
        self._evict_if_needed(incoming_bytes=size)

        tmp_path.replace(final_path)

        self._manifest[key_str] = _Entry(
            key_str=key_str,
            size=size,
            last_access_ts=time.time(),
            shard=shard,
            filename=filename,
            tokenizer_sha=key.tokenizer_sha,
        )
        self._touched_this_run.add(key_str)

    # ---- Eviction / Prune / Clear ------------------------------------

    def _evict_if_needed(self, *, incoming_bytes: int) -> None:
        """Delete oldest entries until (total + incoming) ≤ max_bytes.

        Current-run entries are protected: a cold cache won't evict
        what it just put in to make room for the next put.
        """
        budget = self._max_bytes - incoming_bytes
        # PERF: hoist the O(n) manifest sum out of the loop and maintain
        # it incrementally — the previous version recomputed
        # `self.total_bytes` on every eviction iteration (O(n²)).
        remaining = self.total_bytes
        if remaining <= budget:
            return

        # Candidates: entries not touched this run, sorted by age.
        candidates = sorted(
            (e for e in self._manifest.values() if e.key_str not in self._touched_this_run),
            key=lambda e: e.last_access_ts,
        )
        evicted = 0
        freed = 0
        for entry in candidates:
            if remaining <= budget:
                break
            self._entry_path(entry).unlink(missing_ok=True)
            del self._manifest[entry.key_str]
            remaining -= entry.size
            evicted += 1
            freed += entry.size
        if evicted:
            _LOG.info(
                "cache: evicted %d entries (%d bytes) to stay under %d",
                evicted,
                freed,
                self._max_bytes,
            )

    def prune(self, *, older_than_seconds: float) -> int:
        """Delete entries whose `last_access_ts` is older than the cutoff.

        Returns the number of entries removed. Protected-set doesn't
        apply — `prune` is an explicit operator action, not a
        mid-put fallback.
        """
        cutoff = time.time() - older_than_seconds
        stale_keys = [e.key_str for e in self._manifest.values() if e.last_access_ts < cutoff]
        for key_str in stale_keys:
            entry = self._manifest[key_str]
            self._entry_path(entry).unlink(missing_ok=True)
            del self._manifest[key_str]
        if stale_keys:
            _LOG.info(
                "cache: pruned %d entries older than %ds", len(stale_keys), older_than_seconds
            )
        return len(stale_keys)

    def clear(self) -> int:
        """Delete every entry (files and manifest rows). Returns count removed."""
        count = len(self._manifest)
        for entry in list(self._manifest.values()):
            self._entry_path(entry).unlink(missing_ok=True)
        self._manifest.clear()
        self._touched_this_run.clear()
        return count

    # ---- Manifest persistence ----------------------------------------

    def save_manifest(self, *, tokenizer_sha: str | None = None) -> None:
        """Persist the manifest atomically (write tmp, then rename).

        Call at the end of a training run (or on explicit CLI
        commands). `tokenizer_sha` is stored at the top level so
        future opens can detect a tokenizer bump before reading
        entries.
        """
        manifest_path = self._root / _MANIFEST_FILENAME
        tmp_path = manifest_path.with_suffix(".json.tmp")
        payload: dict[str, Any] = {
            "version": _CACHE_VERSION,
            "tokenizer_sha256": tokenizer_sha or self._tokenizer_sha_hint or "",
            "total_bytes": self.total_bytes,
            "entries": {
                e.key_str: {
                    "size": e.size,
                    "last_access_ts": e.last_access_ts,
                    "shard": e.shard,
                    "filename": e.filename,
                    "tokenizer_sha": e.tokenizer_sha,
                }
                for e in self._manifest.values()
            },
        }
        # sort_keys keeps the on-disk manifest diff-stable across runs.
        tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(manifest_path)

    # ---- Helpers -----------------------------------------------------

    def _entry_path(self, entry: _Entry) -> Path:
        """Absolute path of `entry`'s .npz file under entries/<shard>/."""
        return self._root / _ENTRIES_DIR / entry.shard / entry.filename