Python · 6527 bytes Raw Blame History
1 """Audio preprocessing: blob bytes → feature tensor via HF AutoProcessor.
2
3 Thin wrapper that loads an audio file via `soundfile`, reconciles with
4 the spec's pinned `sample_rate` (refuse by default; resample on opt-in),
truncates to `max_length_seconds`, and runs the HF processor. On-disk
caching is keyed on `(blob_sha, processor_sha, sample_rate,
max_length_ms, auto_resample)`. The duration is rounded to whole
milliseconds for the key. The flag lands on the key so cached
native-rate entries don't serve resample-opted-in callers.
9
10 Callers own the processor lifecycle — `AutoProcessor.from_pretrained`
11 is expensive, so loading it once at trainer startup and reusing across
12 sections is the expected pattern. The cache does the heavy lifting for
13 repeat runs on the same corpus.
14
15 **Sample-rate mismatch policy.** Default `auto_resample=False` preserves
the original contract: raise `AudioSampleRateMismatch` when the rates disagree.
17 `auto_resample=True` (flipped via `training.audio.auto_resample`)
18 routes through `dlm.data.audio_resample` which raises
19 `AudioResampleUnavailable` if neither soxr nor scipy is importable.
20 Both failure modes surface actionable errors rather than silently
21 training on the wrong rate.
22
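Heavy audio imports (`soundfile` and the `dlm.data.audio_resample`
helper) happen inside the functions that need them, so the module stays
cheap to import for CLI subcommands that don't touch audio. `numpy` is
the exception: it is imported at module level for the result types.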
26 """
27
28 from __future__ import annotations
29
30 from dataclasses import dataclass
31 from pathlib import Path
32 from typing import Any, Final
33
34 import numpy as np
35
36 from dlm.data.audio_cache import AudioCache, AudioCacheKey, processor_sha256
37 from dlm.data.errors import DataError
38
39
class AudioSampleRateMismatch(DataError):  # noqa: N818 — `*Mismatch` mirrors other DataError subclasses
    """Audio file sample rate doesn't match the base's pinned value.

    Raised by the preprocessing path only when resampling has not been
    opted into (`auto_resample=False`, the default). With
    `training.audio.auto_resample: true` the clip is resampled to the
    pinned rate instead.

    The error message echoes both rates and names the opt-in flag, so
    the user can do one of three things:

    - flip the flag;
    - re-encode with `ffmpeg -ar <target>`;
    - pick a base pinned to the clip's native rate.
    """
48
49
@dataclass(frozen=True)
class PreprocessedAudio:
    """Immutable outcome of preprocessing one audio clip.

    Pairs the processor's feature array with a flag recording where it
    came from. The flag lets callers report cache hit rates, in the same
    way the VL preprocessor does. The array is whatever the processor
    emitted for `input_features`; the module docstring describes it as a
    mel/log-mel tensor shaped `(num_mel_bins, num_frames)` for most
    audio-LM bases.

    NOTE(review): HF processors called with `return_tensors="np"`
    usually keep a leading batch axis, so the actual shape may be
    `(1, num_mel_bins, num_frames)`. Confirm before relying on rank.
    """

    # Feature tensor produced by the processor, or loaded back from the cache.
    input_features: np.ndarray
    # True iff the features were served from the on-disk cache instead of
    # being computed on this call.
    cache_hit: bool
62
63
# Callable that `preprocess_audio` uses to build cache keys (bound to
# `AudioCacheKey`). NOTE(review): routing through a module-level name
# rather than calling `AudioCacheKey` directly is presumably a seam for
# tests to swap the key type. TODO confirm before inlining.
_CACHE_KEY_FACTORY: Final = AudioCacheKey
65
66
def preprocess_audio(
    *,
    blob_path: Path,
    blob_sha: str,
    processor: Any,
    sample_rate: int,
    max_length_seconds: float,
    cache: AudioCache | None = None,
    auto_resample: bool = False,
) -> PreprocessedAudio:
    """Preprocess one audio blob into a feature tensor.

    `processor` is a pre-loaded HF processor. `sample_rate` and
    `max_length_seconds` come from the base's `AudioPreprocessorPlan`.
    They pin both the reconciliation gate *and* the cache key.

    On cache hit, returns the cached array without touching the
    processor. On miss, `_run_processor` does the work:

    1. Read the file via `soundfile`.
    2. Reconcile against the target rate: refuse when
       `auto_resample=False`, resample when `auto_resample=True`.
    3. Truncate to the max duration.
    4. Run the processor.

    The result is then written back through the cache.

    `cache=None` bypasses caching entirely (tests, ad-hoc prompts). No
    cache key is built and the processor is not hashed, so callers
    without a cache never pay for `processor_sha256`.

    The key stores the duration as whole milliseconds
    (`round(max_length_seconds * 1000)`), so durations that agree to the
    millisecond share cache entries. `auto_resample` is part of the key
    so native-rate entries never serve resample-opted-in callers.

    Raises `AudioSampleRateMismatch` when rates disagree and
    `auto_resample=False`; raises `AudioResampleUnavailable` when
    `auto_resample=True` but neither soxr nor scipy is importable.
    """
    key: AudioCacheKey | None = None
    if cache is not None:
        # Only hash the processor when there is a cache to consult;
        # otherwise the key would be built and thrown away.
        key = _CACHE_KEY_FACTORY(
            blob_sha=blob_sha,
            processor_sha=processor_sha256(processor),
            sample_rate=sample_rate,
            max_length_ms=int(round(max_length_seconds * 1000)),
            auto_resample=auto_resample,
        )
        hit = cache.get(key)
        if hit is not None:
            return PreprocessedAudio(input_features=hit, cache_hit=True)

    tensor = _run_processor(
        processor,
        blob_path,
        target_sample_rate=sample_rate,
        max_length_seconds=max_length_seconds,
        auto_resample=auto_resample,
    )

    if cache is not None and key is not None:
        cache.put(key, tensor)

    return PreprocessedAudio(input_features=tensor, cache_hit=False)
121
122
123 def _run_processor(
124 processor: Any,
125 blob_path: Path,
126 *,
127 target_sample_rate: int,
128 max_length_seconds: float,
129 auto_resample: bool = False,
130 ) -> np.ndarray:
131 """Drive the HF processor over one audio clip, return features.
132
133 Loads the waveform via `soundfile` as float32 mono (average
134 channels if stereo), reconciles against `target_sample_rate`
135 (refuse when `auto_resample=False`, resample when `True`),
136 truncates to `max_length_seconds * target_sample_rate` samples,
137 then passes through
138 `processor(audios=..., sampling_rate=..., return_tensors="np")`.
139 """
140 import soundfile as sf # type: ignore[import-untyped]
141
142 data, native_sr = sf.read(str(blob_path), dtype="float32", always_2d=False)
143
144 # Mono-ize first: resampling a stereo waveform then mixing can
145 # smear channel-specific transients; mixing before resampling
146 # keeps the resampler's anti-alias filter well-behaved.
147 if data.ndim > 1:
148 data = data.mean(axis=1).astype(np.float32, copy=False)
149 data = np.ascontiguousarray(data, dtype=np.float32)
150
151 if native_sr != target_sample_rate:
152 if not auto_resample:
153 raise AudioSampleRateMismatch(
154 f"audio {blob_path.name!r}: native sample_rate={native_sr} Hz "
155 f"does not match pinned {target_sample_rate} Hz. "
156 f"Set `training.audio.auto_resample: true` to resample on "
157 f"the fly, or re-encode manually with "
158 f"`ffmpeg -i <in> -ar {target_sample_rate} <out>`."
159 )
160 from dlm.data.audio_resample import resample
161
162 data = resample(data, src_sr=native_sr, dst_sr=target_sample_rate)
163
164 max_samples = int(round(max_length_seconds * target_sample_rate))
165 if data.shape[0] > max_samples:
166 data = data[:max_samples]
167
168 outputs = processor(
169 audios=data,
170 sampling_rate=target_sample_rate,
171 return_tensors="np",
172 )
173 input_features = outputs["input_features"]
174 if not isinstance(input_features, np.ndarray):
175 input_features = np.asarray(input_features, dtype=np.float32)
176 result: np.ndarray = input_features.astype(np.float32, copy=False)
177 return result