"""VL preprocessor tensor cache.

Keyed on `(blob_sha, processor_sha, target_size)` — a blob-bytes
change, a processor upgrade, or a resize-policy bump each invalidate
the entry. Orthogonal to the tokenized-section cache: different
inputs, different consumers, different keys.

Layout: `<vl-cache>/<blob_sha[:2]>/<blob_sha>.<proc_sha[:12]>.<h>x<w>.npz`.
Contents: single numpy array stored under the key `pixel_values`.
Atomic write via `dlm.io.atomic.write_bytes` so a half-written file
never surfaces to a concurrent reader.

Processor identity (`processor_sha`) is derived from the subset of
attributes that materially change pixel output: `image_size`,
`image_mean`, `image_std`, and the class name. That's enough to
invalidate when a user upgrades HF transformers + the processor bumps
its normalization constants; full byte-level fingerprinting of the
processor isn't practical (processors aren't as JSON-clean as fast
tokenizers are).
"""
| 21 |
|
| 22 |
from __future__ import annotations

import contextlib
import hashlib
import io
import json
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

import numpy as np

from dlm.io.atomic import write_bytes
| 35 |
|
| 36 |
# Private attribute under which `processor_sha256` pins its computed
# digest on the processor instance, making repeat calls O(1) per run.
_FINGERPRINT_ATTR: Final[str] = "_dlm_processor_sha256"
| 37 |
|
| 38 |
|
| 39 |
@dataclass(frozen=True)
class VlCacheKey:
    """Composite key identifying one preprocessed image tensor."""

    blob_sha: str
    processor_sha: str
    target_height: int
    target_width: int

    def as_filename(self) -> str:
        """Stable per-entry filename under the shard directory."""
        hashes = f"{self.blob_sha}.{self.processor_sha[:12]}"
        dims = f"{self.target_height}x{self.target_width}"
        return f"{hashes}.{dims}.npz"

    def shard(self) -> str:
        """Directory shard: the first two hex chars of `blob_sha`."""
        return self.blob_sha[:2]
| 58 |
|
| 59 |
|
| 60 |
class VlCache:
    """On-disk cache for preprocessed image tensors.

    Lazy-initialized: constructing a `VlCache` does not create the
    directory. The first `put` creates the root + shard on demand.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        """Cache root directory (may not exist yet — see class docstring)."""
        return self._root

    def path_for(self, key: VlCacheKey) -> Path:
        """Full on-disk path for `key`: `<root>/<shard>/<filename>`."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: VlCacheKey) -> np.ndarray | None:
        """Return the cached tensor, or `None` on miss.

        Corrupt or unreadable entries are also reported as misses so
        the caller can transparently re-preprocess.
        """
        path = self.path_for(key)
        if not path.exists():
            return None
        try:
            with np.load(path) as npz:
                # Copy so the array outlives the closed NpzFile handle.
                arr: np.ndarray = npz["pixel_values"].copy()
                return arr
        except (OSError, KeyError, ValueError, zipfile.BadZipFile):
            # Corrupt cache entry — treat as miss so the trainer can
            # re-tokenize. `zipfile.BadZipFile` covers a file that kept
            # the zip magic but has truncated/garbled contents, which
            # `np.load` propagates unwrapped. The stale file stays on
            # disk for `dlm cache clear` to sweep rather than racing a
            # delete here.
            return None

    def put(self, key: VlCacheKey, tensor: np.ndarray) -> Path:
        """Atomically write `tensor` under `key`; return the on-disk path."""
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize to memory first so the atomic writer gets the full
        # payload in one shot — no partially-written .npz on disk.
        buffer = io.BytesIO()
        np.savez(buffer, pixel_values=tensor)
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: VlCacheKey) -> bool:
        """True if an entry for `key` is present (no integrity check)."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil

            shutil.rmtree(self._root)
| 110 |
|
| 111 |
|
| 112 |
def processor_sha256(processor: Any) -> str:
    """Canonical sha256 of the identity-bearing subset of a processor.

    HF `AutoProcessor` instances aren't JSON-serializable, so we
    fingerprint the attributes that actually drive pixel output:
    `image_size` (or `size` mapping), `image_mean`, `image_std`, and
    the class name. A future bump in any of these invalidates the
    cache exactly like a tokenizer-fingerprint change does for text.

    Pinned on the processor instance via a private attribute for O(1)
    repeat calls within a run.
    """
    cached: str | None = getattr(processor, _FINGERPRINT_ATTR, None)
    if cached is not None:
        return cached

    # Nested `image_processor` carries the pixel attributes on composite
    # processors; fall back to the processor itself otherwise.
    pixel_src = getattr(processor, "image_processor", processor)
    fingerprint_attrs = (
        "image_size",
        "size",
        "image_mean",
        "image_std",
        "rescale_factor",
        "resample",
    )
    state: dict[str, object] = {
        name: _readable(getattr(pixel_src, name, None)) for name in fingerprint_attrs
    }
    state["class"] = processor.__class__.__name__
    state["do_normalize"] = bool(getattr(pixel_src, "do_normalize", True))
    state["do_rescale"] = bool(getattr(pixel_src, "do_rescale", True))

    # `sort_keys` makes the digest independent of insertion order.
    canonical = json.dumps(state, sort_keys=True, default=str)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    with contextlib.suppress(AttributeError, TypeError):
        object.__setattr__(processor, _FINGERPRINT_ATTR, digest)
    return digest
| 145 |
|
| 146 |
|
| 147 |
def _readable(value: object) -> object: |
| 148 |
"""Coerce a value into a JSON-serializable form. |
| 149 |
|
| 150 |
HF processors use mixed types — ints, floats, lists, dicts, enum |
| 151 |
members (`PILImageResampling`). Stringify exotic types so the |
| 152 |
fingerprint stays stable across HF version bumps that rewrap an |
| 153 |
int as an enum member. |
| 154 |
""" |
| 155 |
if value is None: |
| 156 |
return None |
| 157 |
if isinstance(value, bool | int | float | str): |
| 158 |
return value |
| 159 |
if isinstance(value, list | tuple): |
| 160 |
return [_readable(v) for v in value] |
| 161 |
if isinstance(value, dict): |
| 162 |
return {str(k): _readable(v) for k, v in sorted(value.items())} |
| 163 |
return str(value) |