Python · 9250 bytes Raw Blame History
1 """Custom data collator for audio-language training.
2
3 TRL 1.2 ships `DataCollatorForVisionLanguageModeling` for VL bases but
4 does **not** ship an audio equivalent. This module fills the gap: it
5 takes the path-based audio rows emitted by `sections_to_rows` and
6 turns a list-of-rows batch into the `input_ids / attention_mask /
7 labels / input_features / feature_attention_mask` dict that
8 `Qwen2AudioForConditionalGeneration` (and any future audio-LM class
9 with a similar processor contract) expects.
10
11 Design choices:
12
13 - Rows carry `audio_path` + `audio_blob_sha` (not decoded waveforms).
14 This keeps the HF `Dataset` rows small and lets the collator decide
15 whether to decode per-batch or hit the cache.
16 - The optional `WaveformCache` memoizes the
17 `soundfile decode → mono-mix → truncate` pipeline on disk, keyed on
18 `(blob_sha, sample_rate, max_length_ms)`. The HF
19 processor's feature extractor still runs every batch — caching its
20 output would require re-implementing Qwen2-Audio's text-expansion
21 logic (the processor derives per-audio placeholder counts from
22 `feature_attention_mask`). The waveform cache covers the step that
23 actually dominates per-batch CPU time on a small corpus (decoding
24 a 30 s .wav is a few hundred ms; re-running on every epoch adds
25 up).
26 - Labels = `input_ids` with pad positions masked to `-100`. This is
27 full-sequence training (the model predicts every non-pad token
28 including the audio placeholder expansion). Instruction-tuning
29 variants that mask the audio + prompt to train only the response
30 land as a follow-up; the simpler shape is enough to get signal from
31 a small audio corpus and matches several published recipes.
32 - The HF processor owns audio-token-placeholder expansion — we pass
33 `text` verbatim and the processor replaces `<|AUDIO|>` with the
34 correct number of placeholder tokens derived from the audio frame
35 count. Our pinned `max_length_seconds` keeps that count stable.
36
37 The collator is deliberately not a `@dataclass` because TRL's trainer
38 callbacks sometimes introspect collator attributes; keeping plain
39 `__init__` state avoids surprises.
40 """
41
42 from __future__ import annotations
43
44 import logging
45 from pathlib import Path
46 from typing import Any
47
48 import numpy as np
49
50 from dlm.data.audio_cache import WaveformCache, WaveformCacheKey
51
# Module logger — NOTE(review): no call sites in the code visible here;
# presumably reserved for future cache/decode diagnostics.
_LOG = logging.getLogger(__name__)

_IGNORE_INDEX = -100  # HF convention — CrossEntropyLoss skips these positions
55
56
class AudioLmCollator:
    """Collator for path-based audio rows → HF model-ready batch dict.

    Parameters
    ----------
    processor:
        Loaded `AutoProcessor` (e.g. Qwen2AudioProcessor). Must expose
        `.tokenizer` with a pad token set.
    sample_rate:
        Target sample rate in Hz (from `AudioPreprocessorPlan`).
        Rows whose native rate disagrees raise by default (same gate
        as `preprocess_audio`). Pass `auto_resample=True` to resample
        on the fly via `dlm.data.audio_resample`.
    max_length_seconds:
        Per-clip duration cap in seconds (from
        `AudioPreprocessorPlan`). Longer waveforms are truncated.
    max_length:
        Optional token-length cap for the text side (post-expansion).
        `None` uses the processor's built-in limit. When set, the
        processor is called with `truncation=True` as well — HF
        tokenizers do not truncate on `max_length` alone unless
        truncation is explicitly activated.
    waveform_cache:
        Optional `WaveformCache` for memoizing decoded + mono-mixed
        + truncated waveforms across training epochs. `None` decodes
        fresh every batch (the pre-deferred behavior). Cache keys
        carry `auto_resample` so native-rate and resampled entries
        don't collide.
    auto_resample:
        Opt-in flag flipped by `training.audio.auto_resample=True`.
        When True, SR-mismatched files resample to `sample_rate`
        instead of raising. Requires soxr or scipy; absence surfaces
        as `AudioResampleUnavailable` at first mismatched decode.

    Raises
    ------
    ValueError
        At construction when the processor has no `.tokenizer` or the
        tokenizer has no pad token (run `prepare_tokenizer` first).
    """

    def __init__(
        self,
        *,
        processor: Any,
        sample_rate: int,
        max_length_seconds: float,
        max_length: int | None = None,
        waveform_cache: WaveformCache | None = None,
        auto_resample: bool = False,
    ) -> None:
        self._processor = processor
        self._sample_rate = sample_rate
        self._max_length_seconds = max_length_seconds
        self._max_length = max_length
        self._waveform_cache = waveform_cache
        self._auto_resample = auto_resample
        # Milliseconds as an int: `WaveformCacheKey` carries this form so
        # cache keys hash/compare stably across runs (float seconds would
        # be round-trip-fragile).
        self._max_length_ms = int(round(max_length_seconds * 1000))

        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is None:
            raise ValueError(
                "AudioLmCollator: processor has no `.tokenizer` attribute — "
                "cannot resolve pad token id"
            )
        self._pad_token_id = tokenizer.pad_token_id
        if self._pad_token_id is None:
            raise ValueError(
                "AudioLmCollator: tokenizer has no pad_token_id — "
                "prepare_tokenizer must run before the collator is built"
            )

    def __call__(self, rows: list[dict[str, Any]]) -> dict[str, Any]:
        """Turn a list of dataset rows into a model-ready batch dict.

        Raises
        ------
        ValueError
            On an empty batch, or on a row missing `audio_path`/`text`.
        """
        if not rows:
            raise ValueError("AudioLmCollator: received an empty batch")
        texts: list[str] = []
        waveforms: list[np.ndarray] = []
        for row in rows:
            if "audio_path" not in row or "text" not in row:
                raise ValueError(
                    "AudioLmCollator: row is missing required keys "
                    f"({set(row.keys())}); expected audio_path + text"
                )
            blob_sha = row.get("audio_blob_sha")
            waveforms.append(self._load_waveform(Path(row["audio_path"]), blob_sha=blob_sha))
            texts.append(row["text"])

        # One processor call over the whole batch: it handles padding
        # across both the token side (pad to max_length) and the audio-
        # feature side (pad to longest spectrogram). `return_tensors="pt"`
        # gives us torch tensors matching the model's expected dtypes.
        #
        # FIX: `truncation=True` must accompany `max_length` — HF
        # tokenizers ignore `max_length` for truncation purposes unless
        # truncation is explicitly enabled, so the previous call shape
        # silently skipped the documented token-length cap.
        extra: dict[str, Any] = {}
        if self._max_length is not None:
            extra = {"max_length": self._max_length, "truncation": True}
        batch = self._processor(
            text=texts,
            audios=waveforms,
            sampling_rate=self._sample_rate,
            return_tensors="pt",
            padding=True,
            **extra,
        )

        import torch as _torch

        input_ids: _torch.Tensor = batch["input_ids"]
        labels = input_ids.clone()
        if "attention_mask" in batch:
            # FIX: mask padding via the attention mask rather than by
            # matching pad_token_id. prepare_tokenizer commonly aliases
            # pad to eos; an id-match would then also erase genuine EOS
            # supervision from the labels.
            labels[batch["attention_mask"] == 0] = _IGNORE_INDEX
        else:
            # Fallback for processors that don't return an attention
            # mask: the original pad-id match.
            labels[labels == self._pad_token_id] = _IGNORE_INDEX
        batch["labels"] = labels
        return dict(batch)

    def _load_waveform(self, path: Path, *, blob_sha: str | None = None) -> np.ndarray:
        """Decode one audio blob into a mono float32 waveform.

        When `waveform_cache` is configured and `blob_sha` is provided,
        hits the on-disk cache keyed on
        `(blob_sha, sample_rate, max_length_ms, auto_resample)`. Cache
        miss → decode via `soundfile`, mono-mix, (optionally) resample,
        truncate, populate cache.

        Refuses on sample-rate mismatch unless `auto_resample` is set
        (same gate as `preprocess_audio`). Stereo-to-mono by channel
        averaging. Truncates to the configured duration.
        """
        # Cache lookup: only when both the cache is configured and the
        # row carries a blob sha (older row shapes may not have one;
        # skip the cache rather than break them).
        cache_key: WaveformCacheKey | None = None
        if self._waveform_cache is not None and blob_sha is not None:
            cache_key = WaveformCacheKey(
                blob_sha=blob_sha,
                sample_rate=self._sample_rate,
                max_length_ms=self._max_length_ms,
                auto_resample=self._auto_resample,
            )
            hit = self._waveform_cache.get(cache_key)
            if hit is not None:
                return hit

        import soundfile as sf  # type: ignore[import-untyped]

        data, native_sr = sf.read(str(path), dtype="float32", always_2d=False)

        # Mono before resample: see audio_preprocessor._run_processor
        # for the same rationale (mixing after resampling can smear
        # channel-specific transients the filter needs to preserve).
        if data.ndim > 1:
            data = data.mean(axis=1).astype(np.float32, copy=False)
        mono: np.ndarray = np.ascontiguousarray(data, dtype=np.float32)

        if native_sr != self._sample_rate:
            if not self._auto_resample:
                raise ValueError(
                    f"AudioLmCollator: audio {path.name!r} native sample_rate="
                    f"{native_sr} Hz != pinned {self._sample_rate} Hz. "
                    "Set `training.audio.auto_resample: true` to resample "
                    "on the fly, or re-encode with "
                    f"`ffmpeg -i <in> -ar {self._sample_rate} <out>`."
                )
            from dlm.data.audio_resample import resample

            mono = resample(mono, src_sr=native_sr, dst_sr=self._sample_rate)

        # Truncate AFTER any resample so the cap is measured at the
        # pinned rate, not the native one.
        max_samples = int(round(self._max_length_seconds * self._sample_rate))
        if mono.shape[0] > max_samples:
            mono = mono[:max_samples]

        if cache_key is not None and self._waveform_cache is not None:
            self._waveform_cache.put(cache_key, mono)

        return mono