1 """Audio preprocessor tensor cache.
2
3 Mirrors `vl_cache.py`. Keyed on
4 `(blob_sha, processor_sha, sample_rate, max_length_seconds)` — a blob
5 bytes change, a processor / feature-extractor upgrade, a sample-rate
6 pin change, or a duration-cap change each invalidate the entry.
7 Orthogonal to the tokenized-section cache (which is keyed on
8 tokenizer sha, not audio processor sha).
9
10 Layout: `<audio-cache>/<blob_sha[:2]>/<blob_sha>.<proc_sha[:12]>.<sr>.<ms>.npz`.
11 Contents: a single numpy array stored under the key `input_features`.
12 Atomic write via `dlm.io.atomic.write_bytes` so a half-written file
13 never surfaces to a concurrent reader.
14
15 Processor identity (`processor_sha`) fingerprints the subset of
16 feature-extractor attributes that materially change output features:
17 `sampling_rate`, `feature_size`, `n_fft`, `hop_length`, `padding_value`,
18 and the class name. Full byte-level fingerprinting of the HF processor
19 isn't practical (they aren't JSON-clean); these fields match what the
20 Qwen2-Audio + Whisper-family feature extractors expose and what drift
21 between upstream revisions.
22 """
23
24 from __future__ import annotations
25
26 import contextlib
27 import hashlib
28 import io
29 import json
30 from dataclasses import dataclass
31 from pathlib import Path
32 from typing import Any, Final
33
34 import numpy as np
35
36 from dlm.io.atomic import write_bytes
37
38 _FINGERPRINT_ATTR: Final[str] = "_dlm_audio_processor_sha256"
39
40
@dataclass(frozen=True)
class AudioCacheKey:
    """Composite key for one preprocessed audio tensor.

    `auto_resample` is part of the key (not only the preprocessor
    path): an entry cached without resampling must never be handed to
    a caller that requested auto-resample, because the processor sees
    different inputs whenever the source rate disagreed with the
    target.
    """

    blob_sha: str
    processor_sha: str
    sample_rate: int
    max_length_ms: int
    auto_resample: bool = False

    def as_filename(self) -> str:
        """Stable per-entry filename under the shard."""
        suffix = ".rs" if self.auto_resample else ""
        parts = (
            self.blob_sha,
            self.processor_sha[:12],
            str(self.sample_rate),
            f"{self.max_length_ms}{suffix}",
        )
        return ".".join(parts) + ".npz"

    def shard(self) -> str:
        """Directory shard: the first two hex chars of blob_sha."""
        return self.blob_sha[:2]
68
69
class AudioCache:
    """On-disk cache for preprocessed audio feature tensors.

    Lazy-initialized: constructing an `AudioCache` does not create the
    directory. The first `put` creates the root + shard on demand.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        """Cache root directory (may not exist yet)."""
        return self._root

    def path_for(self, key: AudioCacheKey) -> Path:
        """On-disk path for `key` (existence not checked)."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: AudioCacheKey) -> np.ndarray | None:
        """Return the cached tensor, or `None` on miss.

        EAFP: open directly instead of pre-checking `exists()` — the
        check/open pair is racy against a concurrent cache clear, and
        a missing file already surfaces as OSError (FileNotFoundError)
        in the handler below.
        """
        try:
            with np.load(self.path_for(key)) as npz:
                # Copy so the array outlives the closed npz file handle.
                arr: np.ndarray = npz["input_features"].copy()
            return arr
        except (OSError, KeyError, ValueError):
            # Missing or corrupt entry — treat as miss; `dlm cache clear` sweeps.
            return None

    def put(self, key: AudioCacheKey, tensor: np.ndarray) -> Path:
        """Atomically write `tensor` under `key`; return the on-disk path."""
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        buffer = io.BytesIO()
        np.savez(buffer, input_features=tensor)
        # write_bytes is atomic (tmp + rename), so a concurrent reader
        # never observes a partially written npz.
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: AudioCacheKey) -> bool:
        """True if an entry for `key` is currently on disk."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil  # cold path — keep the import lazy

            shutil.rmtree(self._root)
118
119
@dataclass(frozen=True)
class WaveformCacheKey:
    """Key for the training-hot-path waveform cache.

    Unlike `AudioCacheKey` (feature-level, processor-dependent), the
    cached value here is pre-processor, so the key carries no
    `processor_sha`: any Qwen2-Audio / Whisper / Wav2Vec2 processor
    pinned to the same sample_rate + duration cap sees the same
    decoded + truncated waveform. The feature extractor still runs
    per batch; the cache only skips soundfile decode + mono-mixing +
    truncation on repeat epochs (the dominant per-batch CPU cost on
    a small audio corpus).

    `auto_resample` is part of the key to keep native-rate entries
    separate from resampled ones — a 48 kHz file cached without
    resampling is not interchangeable with the same file resampled
    to 16 kHz.
    """

    blob_sha: str
    sample_rate: int
    max_length_ms: int
    auto_resample: bool = False

    def as_filename(self) -> str:
        """Stable per-entry filename under the shard."""
        base = f"{self.blob_sha}.{self.sample_rate}.{self.max_length_ms}"
        if self.auto_resample:
            base += ".rs"
        return base + ".wav.npz"

    def shard(self) -> str:
        """Directory shard: the first two hex chars of blob_sha."""
        return self.blob_sha[:2]
148
149
class WaveformCache:
    """On-disk cache for decoded + mono-mixed + truncated waveforms.

    Parallel to `AudioCache` but keyed without `processor_sha` — the
    cached value is the pre-processor waveform (1-D float32 mono).
    Training's per-batch audio work is dominated by this decode step
    on a small corpus; caching turns a multi-epoch training run into
    a read-once + extract-each-epoch pattern.

    Stored as npz under key `waveform` so the on-disk layout stays
    distinct from `AudioCache`'s `input_features`.
    """

    def __init__(self, root: Path) -> None:
        self._root = root

    @property
    def root(self) -> Path:
        """Cache root directory (may not exist yet)."""
        return self._root

    def path_for(self, key: WaveformCacheKey) -> Path:
        """On-disk path for `key` (existence not checked)."""
        return self._root / key.shard() / key.as_filename()

    def get(self, key: WaveformCacheKey) -> np.ndarray | None:
        """Return the cached waveform, or `None` on miss.

        EAFP: open directly instead of pre-checking `exists()` — the
        check/open pair is racy against a concurrent cache clear, and
        a missing file already surfaces as OSError (FileNotFoundError)
        in the handler below.
        """
        try:
            with np.load(self.path_for(key)) as npz:
                # Copy so the array outlives the closed npz file handle.
                arr: np.ndarray = npz["waveform"].copy()
            return arr
        except (OSError, KeyError, ValueError):
            # Missing or corrupt entry — treat as a miss.
            return None

    def put(self, key: WaveformCacheKey, waveform: np.ndarray) -> Path:
        """Atomically write `waveform` under `key`; return the on-disk path."""
        path = self.path_for(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        buffer = io.BytesIO()
        np.savez(buffer, waveform=waveform)
        # Atomic write: a concurrent reader never sees a partial file.
        write_bytes(path, buffer.getvalue())
        return path

    def exists(self, key: WaveformCacheKey) -> bool:
        """True if an entry for `key` is currently on disk."""
        return self.path_for(key).exists()

    def clear(self) -> None:
        """Delete the entire cache tree. Test + opt-in user action only."""
        if self._root.exists():
            import shutil  # cold path — keep the import lazy

            shutil.rmtree(self._root)
200
201
def processor_sha256(processor: Any) -> str:
    """Canonical sha256 of the identity-bearing subset of an audio processor.

    The feature extractor (exposed as `processor.feature_extractor` on
    HF's `AutoProcessor` wrappers) carries the pre-processing params.
    Fingerprint a stable subset so an upstream bump that rewrites log-mel
    windowing invalidates every cached entry.

    Pinned on the processor instance via a private attribute for O(1)
    repeat calls within a run.
    """
    cached: str | None = getattr(processor, _FINGERPRINT_ATTR, None)
    if cached is not None:
        return cached

    # Fall back to the processor itself when there is no wrapped extractor.
    fe = getattr(processor, "feature_extractor", processor)
    state: dict[str, object] = {
        "class": processor.__class__.__name__,
        "fe_class": fe.__class__.__name__,
        "return_attention_mask": bool(getattr(fe, "return_attention_mask", False)),
    }
    for attr in (
        "sampling_rate",
        "feature_size",
        "n_fft",
        "hop_length",
        "chunk_length",
        "padding_value",
    ):
        state[attr] = _readable(getattr(fe, attr, None))

    # sort_keys makes the JSON canonical; default=str is belt-and-braces
    # on top of _readable's coercion.
    digest = hashlib.sha256(
        json.dumps(state, sort_keys=True, default=str).encode("utf-8")
    ).hexdigest()
    # object.__setattr__ bypasses frozen/slotted wrappers; if the instance
    # truly can't take attributes, just skip the memoization.
    with contextlib.suppress(AttributeError, TypeError):
        object.__setattr__(processor, _FINGERPRINT_ATTR, digest)
    return digest
234
235
236 def _readable(value: object) -> object:
237 """Coerce a value into a JSON-serializable form (mirror of vl_cache)."""
238 if value is None:
239 return None
240 if isinstance(value, bool | int | float | str):
241 return value
242 if isinstance(value, list | tuple):
243 return [_readable(v) for v in value]
244 if isinstance(value, dict):
245 return {str(k): _readable(v) for k, v in sorted(value.items())}
246 return str(value)