| 1 |
"""Audio preprocessor tensor cache. |
| 2 |
|
| 3 |
Mirrors `vl_cache.py`. Keyed on |
| 4 |
`(blob_sha, processor_sha, sample_rate, max_length_ms, auto_resample)` — a blob
| 5 |
bytes change, a processor / feature-extractor upgrade, a sample-rate |
| 6 |
pin change, or a duration-cap change each invalidate the entry. |
| 7 |
Orthogonal to the tokenized-section cache (which is keyed on |
| 8 |
tokenizer sha, not audio processor sha). |
| 9 |
|
| 10 |
Layout: `<audio-cache>/<blob_sha[:2]>/<blob_sha>.<proc_sha[:12]>.<sr>.<ms>[.rs].npz`.
| 11 |
Contents: a single numpy array stored under the key `input_features`. |
| 12 |
Atomic write via `dlm.io.atomic.write_bytes` so a half-written file |
| 13 |
never surfaces to a concurrent reader. |
| 14 |
|
| 15 |
Processor identity (`processor_sha`) fingerprints the subset of |
| 16 |
feature-extractor attributes that materially change output features: |
| 17 |
`sampling_rate`, `feature_size`, `n_fft`, `hop_length`, `chunk_length`,
`padding_value`, `return_attention_mask`, and the processor / extractor
class names. Full byte-level fingerprinting of the HF processor
| 19 |
isn't practical (they aren't JSON-clean); these fields match what the |
| 20 |
Qwen2-Audio + Whisper-family feature extractors expose and what drifts
| 21 |
between upstream revisions. |
| 22 |
""" |
| 23 |
|
| 24 |
from __future__ import annotations

import contextlib
import hashlib
import io
import json
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

import numpy as np

from dlm.io.atomic import write_bytes
| 37 |
|
| 38 |
_FINGERPRINT_ATTR: Final[str] = "_dlm_audio_processor_sha256" |
| 39 |
|
| 40 |
|
| 41 |
@dataclass(frozen=True)
class AudioCacheKey:
    """Identifies one preprocessed audio tensor in the on-disk cache.

    Cache identity is the full tuple of fields: the audio blob's content
    hash, the processor fingerprint, the pinned sample rate, the duration
    cap, and whether auto-resampling was requested. `auto_resample`
    participates in the key deliberately — when the source rate disagrees
    with the target, resampled and non-resampled runs feed the processor
    different inputs, so their cached entries must never be served
    interchangeably.
    """

    blob_sha: str
    processor_sha: str
    sample_rate: int
    max_length_ms: int
    auto_resample: bool = False

    def as_filename(self) -> str:
        """Deterministic per-entry filename within the shard directory."""
        stem = ".".join(
            (
                self.blob_sha,
                self.processor_sha[:12],
                str(self.sample_rate),
                str(self.max_length_ms),
            )
        )
        suffix = ".rs.npz" if self.auto_resample else ".npz"
        return stem + suffix

    def shard(self) -> str:
        """Directory shard: the first two hex characters of `blob_sha`."""
        return self.blob_sha[:2]
| 68 |
|
| 69 |
|
| 70 |
class AudioCache:
    """On-disk cache for preprocessed audio feature tensors.

    Lazy-initialized: constructing an `AudioCache` does not create the
    directory. The first `put` creates the root + shard on demand.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        return self._root

    def path_for(self, key: AudioCacheKey) -> Path:
        """Absolute on-disk path for `key`: `<root>/<shard>/<filename>`."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: AudioCacheKey) -> np.ndarray | None:
        """Return the cached tensor, or `None` on miss.

        Any corrupt entry is treated as a miss; `dlm cache clear` sweeps.
        """
        path = self.path_for(key)
        if not path.exists():
            return None
        try:
            with np.load(path) as npz:
                arr: np.ndarray = npz["input_features"].copy()
                return arr
        except (OSError, KeyError, ValueError, zipfile.BadZipFile):
            # BadZipFile: a file that still starts with the zip magic but
            # is truncated/corrupt makes np.load raise zipfile.BadZipFile,
            # which OSError/ValueError do not cover — without it a corrupt
            # entry would crash the reader instead of registering a miss.
            return None

    def put(self, key: AudioCacheKey, tensor: np.ndarray) -> Path:
        """Atomically write `tensor` under `key`; return the on-disk path."""
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize fully in memory so write_bytes can publish atomically.
        buffer = io.BytesIO()
        np.savez(buffer, input_features=tensor)
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: AudioCacheKey) -> bool:
        """True when an entry for `key` is present on disk."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil

            shutil.rmtree(self._root)
| 118 |
|
| 119 |
|
| 120 |
@dataclass(frozen=True)
class WaveformCacheKey:
    """Key for the training-hot-path waveform cache.

    Unlike `AudioCacheKey` (feature-level, processor-dependent), this keys
    a pre-processor artifact, so no `processor_sha` is present — every
    Qwen2-Audio / Whisper / Wav2Vec2 processor at the same pinned
    sample_rate + duration cap sees the same decoded + truncated waveform.
    The feature-extractor still runs per batch; what the cache skips is
    soundfile decode + mono-mixing + truncation on repeat epochs (the
    dominant per-batch CPU cost on a small audio corpus).

    `auto_resample` participates in the key to keep native-rate entries
    apart from resampled ones — a 48 kHz file cached without resampling
    is not interchangeable with the same file resampled to 16 kHz.
    """

    blob_sha: str
    sample_rate: int
    max_length_ms: int
    auto_resample: bool = False

    def as_filename(self) -> str:
        """Deterministic per-entry filename within the shard directory."""
        stem = ".".join(
            (self.blob_sha, str(self.sample_rate), str(self.max_length_ms))
        )
        if self.auto_resample:
            stem += ".rs"
        return stem + ".wav.npz"

    def shard(self) -> str:
        """Directory shard: the first two hex characters of `blob_sha`."""
        return self.blob_sha[:2]
| 148 |
|
| 149 |
|
| 150 |
class WaveformCache:
    """On-disk cache for decoded + mono-mixed + truncated waveforms.

    Parallel to `AudioCache` but keyed without `processor_sha` — the
    cached value is the pre-processor waveform (1-D float32 mono).
    Training's per-batch audio work is dominated by this decode step
    on a small corpus; caching turns a multi-epoch training run into
    a read-once + extract-each-epoch pattern.

    Stored as npz under key `waveform` so the on-disk layout stays
    distinct from `AudioCache`'s `input_features`.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        return self._root

    def path_for(self, key: WaveformCacheKey) -> Path:
        """Absolute on-disk path for `key`: `<root>/<shard>/<filename>`."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: WaveformCacheKey) -> np.ndarray | None:
        """Return the cached waveform, or `None` on miss.

        Any corrupt entry is treated as a miss, matching `AudioCache.get`.
        """
        path = self.path_for(key)
        if not path.exists():
            return None
        try:
            with np.load(path) as npz:
                arr: np.ndarray = npz["waveform"].copy()
                return arr
        except (OSError, KeyError, ValueError, zipfile.BadZipFile):
            # BadZipFile: np.load raises it for a truncated/corrupt file
            # that still carries the zip magic; without it a corrupt entry
            # would crash the reader instead of registering a miss.
            return None

    def put(self, key: WaveformCacheKey, waveform: np.ndarray) -> Path:
        """Atomically write `waveform` under `key`; return the on-disk path."""
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize fully in memory so write_bytes can publish atomically.
        buffer = io.BytesIO()
        np.savez(buffer, waveform=waveform)
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: WaveformCacheKey) -> bool:
        """True when an entry for `key` is present on disk."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil

            shutil.rmtree(self._root)
| 200 |
|
| 201 |
|
| 202 |
def processor_sha256(processor: Any) -> str: |
| 203 |
"""Canonical sha256 of the identity-bearing subset of an audio processor. |
| 204 |
|
| 205 |
The feature extractor (exposed as `processor.feature_extractor` on |
| 206 |
HF's `AutoProcessor` wrappers) carries the pre-processing params. |
| 207 |
Fingerprint a stable subset so an upstream bump that rewrites log-mel |
| 208 |
windowing invalidates every cached entry. |
| 209 |
|
| 210 |
Pinned on the processor instance via a private attribute for O(1) |
| 211 |
repeat calls within a run. |
| 212 |
""" |
| 213 |
pinned: str | None = getattr(processor, _FINGERPRINT_ATTR, None) |
| 214 |
if pinned is not None: |
| 215 |
return pinned |
| 216 |
|
| 217 |
feature_extractor = getattr(processor, "feature_extractor", processor) |
| 218 |
state: dict[str, object] = { |
| 219 |
"class": processor.__class__.__name__, |
| 220 |
"fe_class": feature_extractor.__class__.__name__, |
| 221 |
"sampling_rate": _readable(getattr(feature_extractor, "sampling_rate", None)), |
| 222 |
"feature_size": _readable(getattr(feature_extractor, "feature_size", None)), |
| 223 |
"n_fft": _readable(getattr(feature_extractor, "n_fft", None)), |
| 224 |
"hop_length": _readable(getattr(feature_extractor, "hop_length", None)), |
| 225 |
"chunk_length": _readable(getattr(feature_extractor, "chunk_length", None)), |
| 226 |
"padding_value": _readable(getattr(feature_extractor, "padding_value", None)), |
| 227 |
"return_attention_mask": bool(getattr(feature_extractor, "return_attention_mask", False)), |
| 228 |
} |
| 229 |
canonical = json.dumps(state, sort_keys=True, default=str) |
| 230 |
sha = hashlib.sha256(canonical.encode("utf-8")).hexdigest() |
| 231 |
with contextlib.suppress(AttributeError, TypeError): |
| 232 |
object.__setattr__(processor, _FINGERPRINT_ATTR, sha) |
| 233 |
return sha |
| 234 |
|
| 235 |
|
| 236 |
def _readable(value: object) -> object: |
| 237 |
"""Coerce a value into a JSON-serializable form (mirror of vl_cache).""" |
| 238 |
if value is None: |
| 239 |
return None |
| 240 |
if isinstance(value, bool | int | float | str): |
| 241 |
return value |
| 242 |
if isinstance(value, list | tuple): |
| 243 |
return [_readable(v) for v in value] |
| 244 |
if isinstance(value, dict): |
| 245 |
return {str(k): _readable(v) for k, v in sorted(value.items())} |
| 246 |
return str(value) |