| 1 |
"""Per-store tokenized-section cache. |
| 2 |
|
| 3 |
Avoid re-tokenizing unchanged directive-sourced files on every |
| 4 |
`dlm train` run. At 50K+ files this is the difference between an |
| 5 |
hour of retokenization and seconds of cache warm-up. |
| 6 |
|
| 7 |
Integration status: the trainer consults this cache during the |
| 8 |
pre-tokenization pass between `build_dataset` and SFTTrainer. Rows |
| 9 |
are tokenized to `input_ids` / `attention_mask`, cached by |
| 10 |
`(section_id, tokenizer_sha256, sequence_len)`, and then handed to |
| 11 |
SFTTrainer in pre-processed form so repeated runs can skip most |
| 12 |
tokenizer work without changing training behavior. |
| 13 |
|
| 14 |
Layout (per store): |
| 15 |
|
| 16 |
~/.dlm/store/<dlm_id>/tokenized-cache/ |
| 17 |
manifest.json { version, tokenizer_sha256, total_bytes, |
| 18 |
entries: {key_str: {size, last_access_ts, |
| 19 |
shard, filename}} } |
| 20 |
entries/ |
| 21 |
<section_id[:2]>/ sharded to avoid 50K files in one dir |
| 22 |
<...>.npz numpy save of input_ids + attention_mask |
| 23 |
|
| 24 |
Entries are keyed by `(section_id, tokenizer_sha, sequence_len)`. A |
| 25 |
change to any input produces a new key — stale entries are garbage- |
| 26 |
collected by `prune` when their last-access age exceeds the cutoff. |
| 27 |
|
| 28 |
Atomicity: `put` writes to a tmp file + `os.replace`; the manifest |
| 29 |
update is the last step, so a mid-put SIGTERM leaves no torn entries |
| 30 |
(the tmp file may orphan, `prune` sweeps it). |
| 31 |
|
| 32 |
LRU eviction fires on `put` when `total_bytes + incoming_size > |
| 33 |
max_bytes`. Oldest `last_access_ts` wins; **current-run entries are |
| 34 |
protected** so a cold cache doesn't self-starve. |
| 35 |
""" |
| 36 |
|
| 37 |
from __future__ import annotations |
| 38 |
|
| 39 |
import json |
| 40 |
import logging |
| 41 |
import time |
| 42 |
from dataclasses import dataclass |
| 43 |
from pathlib import Path |
| 44 |
from typing import Any |
| 45 |
|
| 46 |
import numpy as np |
| 47 |
|
| 48 |
from dlm.directives.cache_key import CacheKey |
| 49 |
|
| 50 |
_LOG = logging.getLogger(__name__)

# Manifest schema version: `TokenizedCache.open` starts fresh on any mismatch.
_CACHE_VERSION = 1
# File/directory names inside the cache root (see module docstring for layout).
_MANIFEST_FILENAME = "manifest.json"
_ENTRIES_DIR = "entries"
# Default LRU budget for all cached entries combined.
_DEFAULT_MAX_BYTES = 10 * 1024 * 1024 * 1024  # 10 GiB
| 56 |
|
| 57 |
|
| 58 |
@dataclass(frozen=True)
class CachedTokens:
    """Immutable pair of tokenizer outputs, as stored in (or read from) the cache.

    Both fields are 1D numpy integer arrays shaped exactly as
    `tokenizer(text, truncation=True, padding=False, max_length=seq_len)`
    would produce them. Conversion to torch tensors happens later, at
    the dataloader boundary — never here.
    """

    # Token ids for one section.
    input_ids: np.ndarray
    # Companion attention mask, same layout as `input_ids`.
    attention_mask: np.ndarray
| 70 |
|
| 71 |
|
| 72 |
@dataclass |
| 73 |
class _Entry: |
| 74 |
"""Manifest row for one cached tokenization. |
| 75 |
|
| 76 |
`key_str` is the canonical string form of the CacheKey. `size` |
| 77 |
is bytes on disk (best-effort stat). `last_access_ts` is Unix |
| 78 |
seconds as a float. |
| 79 |
""" |
| 80 |
|
| 81 |
key_str: str |
| 82 |
size: int |
| 83 |
last_access_ts: float |
| 84 |
shard: str |
| 85 |
filename: str |
| 86 |
tokenizer_sha: str |
| 87 |
|
| 88 |
|
| 89 |
def _key_str(key: CacheKey) -> str: |
| 90 |
"""Canonical str form: used as manifest dict key.""" |
| 91 |
return f"{key.section_id}|{key.tokenizer_sha}|{key.sequence_len}" |
| 92 |
|
| 93 |
|
| 94 |
class TokenizedCache:
    """Per-store tokenized-section cache.

    Open via `TokenizedCache.open(store.tokenized_cache_dir)`. The
    constructor eagerly loads the manifest (cheap: one JSON file with
    N entries). `get` and `put` touch disk for the actual tensors.
    """

    def __init__(
        self,
        root: Path,
        *,
        manifest: dict[str, _Entry],
        tokenizer_sha_hint: str | None = None,
        max_bytes: int = _DEFAULT_MAX_BYTES,
    ) -> None:
        self._root = root
        self._manifest = manifest
        self._tokenizer_sha_hint = tokenizer_sha_hint
        self._max_bytes = max_bytes
        # Entries inserted or touched during this session — protected
        # from LRU eviction to avoid cold-cache self-starvation.
        self._touched_this_run: set[str] = set()
        # Counters for end-of-run metrics.
        self._hits = 0
        self._misses = 0

    # ---- Open / construct --------------------------------------------

    @classmethod
    def open(cls, root: Path, *, max_bytes: int = _DEFAULT_MAX_BYTES) -> TokenizedCache:
        """Open (or create) a cache at `root`.

        Creates the directory layout idempotently. Missing manifest →
        fresh empty cache. Corrupt manifest → log a WARN and start
        fresh, leaving any orphaned entry files to `prune` later.
        """
        root.mkdir(parents=True, exist_ok=True)
        (root / _ENTRIES_DIR).mkdir(exist_ok=True)
        manifest_path = root / _MANIFEST_FILENAME

        if not manifest_path.is_file():
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        try:
            raw = json.loads(manifest_path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError) as exc:
            _LOG.warning(
                "cache: manifest at %s unreadable (%s); starting fresh",
                manifest_path,
                exc,
            )
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        if not isinstance(raw, dict) or raw.get("version") != _CACHE_VERSION:
            _LOG.warning(
                "cache: manifest version mismatch at %s; starting fresh",
                manifest_path,
            )
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        entries_raw = raw.get("entries", {})
        if not isinstance(entries_raw, dict):
            return cls(root=root, manifest={}, max_bytes=max_bytes)

        # Per-row validation: a single malformed row is skipped (with a
        # warning), not allowed to poison the whole manifest.
        manifest: dict[str, _Entry] = {}
        for key_str, row in entries_raw.items():
            if not isinstance(row, dict):
                continue
            try:
                manifest[key_str] = _Entry(
                    key_str=key_str,
                    size=int(row["size"]),
                    last_access_ts=float(row["last_access_ts"]),
                    shard=str(row["shard"]),
                    filename=str(row["filename"]),
                    tokenizer_sha=str(row.get("tokenizer_sha", "")),
                )
            except (KeyError, TypeError, ValueError) as exc:
                _LOG.warning("cache: skipping malformed entry %s: %s", key_str, exc)

        return cls(
            root=root,
            manifest=manifest,
            tokenizer_sha_hint=raw.get("tokenizer_sha256"),
            max_bytes=max_bytes,
        )

    # ---- Properties --------------------------------------------------

    @property
    def root(self) -> Path:
        return self._root

    @property
    def total_bytes(self) -> int:
        """Sum of manifest entry sizes (recomputed; manifest is the truth)."""
        return sum(e.size for e in self._manifest.values())

    @property
    def entry_count(self) -> int:
        return len(self._manifest)

    @property
    def hits(self) -> int:
        return self._hits

    @property
    def misses(self) -> int:
        return self._misses

    @property
    def hit_rate(self) -> float:
        """Fraction of `get` calls served from cache; 0.0 before any call."""
        total = self._hits + self._misses
        return self._hits / total if total else 0.0

    # ---- Get / Put ---------------------------------------------------

    def get(self, key: CacheKey) -> CachedTokens | None:
        """Return cached tokens for `key` or None on miss."""
        key_str = _key_str(key)
        entry = self._manifest.get(key_str)
        if entry is None:
            self._misses += 1
            return None

        path = self._entry_path(entry)
        if not path.is_file():
            # Manifest drift (file deleted under us) — treat as miss,
            # remove the stale manifest row so we don't re-hit it.
            _LOG.warning("cache: entry file missing for %s; re-tokenizing", key_str)
            del self._manifest[key_str]
            self._misses += 1
            return None

        try:
            with np.load(path) as data:
                # Copy out of the npz so the arrays survive file close.
                tokens = CachedTokens(
                    input_ids=np.array(data["input_ids"], copy=True),
                    attention_mask=np.array(data["attention_mask"], copy=True),
                )
        except (OSError, ValueError, KeyError) as exc:
            _LOG.warning("cache: corrupt entry %s (%s); re-tokenizing", key_str, exc)
            del self._manifest[key_str]
            self._misses += 1
            return None

        entry.last_access_ts = time.time()
        self._touched_this_run.add(key_str)
        self._hits += 1
        return tokens

    def put(self, key: CacheKey, tokens: CachedTokens) -> None:
        """Write `tokens` to the cache under `key`. Evicts if needed."""
        key_str = _key_str(key)
        shard = key.shard()
        filename = key.as_filename()
        shard_dir = self._root / _ENTRIES_DIR / shard
        shard_dir.mkdir(parents=True, exist_ok=True)
        final_path = shard_dir / filename
        # BUG FIX: the tmp name was a constant placeholder string, so
        # every put into the same shard reused one tmp file and
        # interleaved puts could clobber each other's in-flight write.
        # Derive it from the entry filename so each key gets its own.
        tmp_path = shard_dir / f"{filename}.tmp"

        # Save to tmp, then atomic rename. np.savez_compressed writes
        # directly to the file handle, so we can't stream-then-rename
        # with a single call — open the tmp file manually.
        try:
            with tmp_path.open("wb") as f:
                np.savez_compressed(
                    f,
                    input_ids=tokens.input_ids,
                    attention_mask=tokens.attention_mask,
                )
        except OSError as exc:
            _LOG.warning("cache: write failed for %s: %s; dropping entry", key_str, exc)
            tmp_path.unlink(missing_ok=True)
            return

        try:
            size = tmp_path.stat().st_size
        except OSError:
            size = 0

        # Evict BEFORE replacing so we have the budget headroom.
        self._evict_if_needed(incoming_bytes=size)

        tmp_path.replace(final_path)

        self._manifest[key_str] = _Entry(
            key_str=key_str,
            size=size,
            last_access_ts=time.time(),
            shard=shard,
            filename=filename,
            tokenizer_sha=key.tokenizer_sha,
        )
        self._touched_this_run.add(key_str)

    # ---- Eviction / Prune / Clear ------------------------------------

    def _evict_if_needed(self, *, incoming_bytes: int) -> None:
        """Delete oldest entries until (total + incoming) ≤ max_bytes.

        Current-run entries are protected: a cold cache won't evict
        what it just put in to make room for the next put.
        """
        budget = self._max_bytes - incoming_bytes
        # Track the total as a running number instead of re-summing the
        # whole manifest on every loop iteration (was accidentally O(n²)).
        total = self.total_bytes
        if total <= budget:
            return

        # Candidates: entries not touched this run, sorted by age.
        candidates = sorted(
            (e for e in self._manifest.values() if e.key_str not in self._touched_this_run),
            key=lambda e: e.last_access_ts,
        )
        evicted = 0
        freed = 0
        for entry in candidates:
            if total <= budget:
                break
            path = self._entry_path(entry)
            path.unlink(missing_ok=True)
            del self._manifest[entry.key_str]
            total -= entry.size
            evicted += 1
            freed += entry.size
        if evicted:
            _LOG.info(
                "cache: evicted %d entries (%d bytes) to stay under %d",
                evicted,
                freed,
                self._max_bytes,
            )

    def prune(self, *, older_than_seconds: float) -> int:
        """Delete entries whose `last_access_ts` is older than the cutoff.

        Returns the number of entries removed. Protected-set doesn't
        apply — `prune` is an explicit operator action, not a
        mid-put fallback.
        """
        cutoff = time.time() - older_than_seconds
        stale_keys = [e.key_str for e in self._manifest.values() if e.last_access_ts < cutoff]
        for key_str in stale_keys:
            entry = self._manifest[key_str]
            self._entry_path(entry).unlink(missing_ok=True)
            del self._manifest[key_str]
        if stale_keys:
            _LOG.info(
                "cache: pruned %d entries older than %ds", len(stale_keys), older_than_seconds
            )
        return len(stale_keys)

    def clear(self) -> int:
        """Delete every entry. Returns count removed."""
        count = len(self._manifest)
        for entry in list(self._manifest.values()):
            self._entry_path(entry).unlink(missing_ok=True)
        self._manifest.clear()
        self._touched_this_run.clear()
        return count

    # ---- Manifest persistence ----------------------------------------

    def save_manifest(self, *, tokenizer_sha: str | None = None) -> None:
        """Persist the manifest atomically.

        Call at the end of a training run (or on explicit CLI
        commands). `tokenizer_sha` is stored at the top level so
        future opens can detect a tokenizer bump before reading
        entries.
        """
        manifest_path = self._root / _MANIFEST_FILENAME
        tmp_path = manifest_path.with_suffix(".json.tmp")
        payload: dict[str, Any] = {
            "version": _CACHE_VERSION,
            "tokenizer_sha256": tokenizer_sha or self._tokenizer_sha_hint or "",
            "total_bytes": self.total_bytes,
            "entries": {
                e.key_str: {
                    "size": e.size,
                    "last_access_ts": e.last_access_ts,
                    "shard": e.shard,
                    "filename": e.filename,
                    "tokenizer_sha": e.tokenizer_sha,
                }
                for e in self._manifest.values()
            },
        }
        # Write-tmp + atomic replace: readers never observe a torn manifest.
        tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(manifest_path)

    # ---- Helpers -----------------------------------------------------

    def _entry_path(self, entry: _Entry) -> Path:
        """Absolute path of an entry's .npz file under the sharded layout."""
        return self._root / _ENTRIES_DIR / entry.shard / entry.filename