Python · 5934 bytes Raw Blame History
1 """VL preprocessor tensor cache.
2
Keyed on `(blob_sha, processor_sha, target_height, target_width)` — a
blob-bytes change, a processor upgrade, or a resize-policy bump each
invalidates the entry. Orthogonal to the tokenized-section cache:
different inputs, different consumers, different keys.
7
8 Layout: `<vl-cache>/<blob_sha[:2]>/<blob_sha>.<proc_sha[:12]>.<h>x<w>.npz`.
9 Contents: single numpy array stored under the key `pixel_values`.
10 Atomic write via `dlm.io.atomic.write_bytes` so a half-written file
11 never surfaces to a concurrent reader.
12
13 Processor identity (`processor_sha`) is derived from the subset of
14 attributes that materially change pixel output: `image_size`,
15 `image_mean`, `image_std`, and the class name. That's enough to
16 invalidate when a user upgrades HF transformers + the processor bumps
17 its normalization constants; full byte-level fingerprinting of the
18 processor isn't practical (processors aren't as JSON-clean as fast
19 tokenizers are).
20 """
21
from __future__ import annotations

import contextlib
import hashlib
import io
import json
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

import numpy as np

from dlm.io.atomic import write_bytes
35
36 _FINGERPRINT_ATTR: Final[str] = "_dlm_processor_sha256"
37
38
@dataclass(frozen=True)
class VlCacheKey:
    """Composite key for one preprocessed image tensor.

    A change in any component — blob bytes, processor identity, or
    target resolution — maps to a distinct cache entry.
    """

    blob_sha: str
    processor_sha: str
    target_height: int
    target_width: int

    def as_filename(self) -> str:
        """Stable per-entry filename under the shard directory."""
        size_tag = f"{self.target_height}x{self.target_width}"
        return f"{self.blob_sha}.{self.processor_sha[:12]}.{size_tag}.npz"

    def shard(self) -> str:
        """Directory shard: the first two hex chars of `blob_sha`."""
        return self.blob_sha[:2]
58
59
class VlCache:
    """On-disk cache for preprocessed image tensors.

    Lazy-initialized: constructing a `VlCache` does not create the
    directory. The first `put` creates the root + shard on demand.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        """Cache root directory (may not exist yet)."""
        return self._root

    def path_for(self, key: VlCacheKey) -> Path:
        """On-disk path for `key`: `<root>/<shard>/<filename>`."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: VlCacheKey) -> np.ndarray | None:
        """Return the cached tensor, or `None` on miss.

        A corrupt or unreadable entry is treated as a miss so the caller
        can transparently re-preprocess instead of crashing.
        """
        path = self.path_for(key)
        if not path.exists():
            return None
        try:
            with np.load(path) as npz:
                # Copy so the array outlives the closed npz file handle.
                arr: np.ndarray = npz["pixel_values"].copy()
                return arr
        except (OSError, KeyError, ValueError, zipfile.BadZipFile):
            # Corrupt cache entry — `.npz` files are zip archives, so a
            # truncated file with an intact magic header raises
            # `zipfile.BadZipFile`, which the narrower
            # (OSError, KeyError, ValueError) tuple previously let escape.
            # Treat all of these as a miss so the trainer can
            # re-preprocess. The stale file stays on disk for `dlm cache
            # clear` to sweep rather than racing a delete here.
            return None

    def put(self, key: VlCacheKey, tensor: np.ndarray) -> Path:
        """Atomically write `tensor` under `key`; return the on-disk path.

        Serializes to an in-memory buffer first so `write_bytes` can
        publish the file atomically — a concurrent `get` never observes
        a half-written entry.
        """
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        buffer = io.BytesIO()
        np.savez(buffer, pixel_values=tensor)
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: VlCacheKey) -> bool:
        """True if an entry for `key` is present on disk (may be corrupt)."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil

            shutil.rmtree(self._root)
110
111
def processor_sha256(processor: Any) -> str:
    """Canonical sha256 of the identity-bearing subset of a processor.

    HF `AutoProcessor` instances aren't JSON-serializable, so the
    fingerprint covers only the attributes that actually drive pixel
    output: `image_size` (or the `size` mapping), `image_mean`,
    `image_std`, normalization/rescale flags, `resample`, and the class
    name. A bump in any of these invalidates the cache exactly like a
    tokenizer-fingerprint change does for text.

    The result is pinned on the processor instance via a private
    attribute so repeat calls within a run are O(1).
    """
    cached: str | None = getattr(processor, _FINGERPRINT_ATTR, None)
    if cached is not None:
        return cached

    # Composite processors expose the image half under `image_processor`;
    # bare image processors are used as-is.
    image_processor = getattr(processor, "image_processor", processor)

    state: dict[str, object] = {"class": processor.__class__.__name__}
    for name in ("image_size", "size", "image_mean", "image_std",
                 "rescale_factor", "resample"):
        state[name] = _readable(getattr(image_processor, name, None))
    for flag in ("do_normalize", "do_rescale"):
        state[flag] = bool(getattr(image_processor, flag, True))

    canonical = json.dumps(state, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    # Best-effort pin; frozen/slotted processors may refuse the setattr.
    with contextlib.suppress(AttributeError, TypeError):
        object.__setattr__(processor, _FINGERPRINT_ATTR, digest)
    return digest
145
146
147 def _readable(value: object) -> object:
148 """Coerce a value into a JSON-serializable form.
149
150 HF processors use mixed types — ints, floats, lists, dicts, enum
151 members (`PILImageResampling`). Stringify exotic types so the
152 fingerprint stays stable across HF version bumps that rewrap an
153 int as an enum member.
154 """
155 if value is None:
156 return None
157 if isinstance(value, bool | int | float | str):
158 return value
159 if isinstance(value, list | tuple):
160 return [_readable(v) for v in value]
161 if isinstance(value, dict):
162 return {str(k): _readable(v) for k, v in sorted(value.items())}
163 return str(value)