"""Custom data collator for audio-language training.

TRL 1.2 ships `DataCollatorForVisionLanguageModeling` for VL bases but
does **not** ship an audio equivalent. This module fills the gap: it
takes the path-based audio rows emitted by `sections_to_rows` and
turns a list-of-rows batch into the `input_ids / attention_mask /
labels / input_features / feature_attention_mask` dict that
`Qwen2AudioForConditionalGeneration` (and any future audio-LM class
with a similar processor contract) expects.

Design choices:

- Rows carry `audio_path` + `audio_blob_sha` (not decoded waveforms).
  This keeps the HF `Dataset` rows small and lets the collator decide
  whether to decode per-batch or hit the cache.
- The optional `WaveformCache` memoizes the
  `soundfile decode → mono-mix → truncate` pipeline on disk, keyed on
  `(blob_sha, sample_rate, max_length_ms)`. The HF
  processor's feature extractor still runs every batch — caching its
  output would require re-implementing Qwen2-Audio's text-expansion
  logic (the processor derives per-audio placeholder counts from
  `feature_attention_mask`). The waveform cache covers the step that
  actually dominates per-batch CPU time on a small corpus (decoding
  a 30 s .wav is a few hundred ms; re-running on every epoch adds
  up).
- Labels = `input_ids` with pad positions masked to `-100`. This is
  full-sequence training (the model predicts every non-pad token
  including the audio placeholder expansion). Instruction-tuning
  variants that mask the audio + prompt to train only the response
  land as a follow-up; the simpler shape is enough to get signal from
  a small audio corpus and matches several published recipes.
- The HF processor owns audio-token-placeholder expansion — we pass
  `text` verbatim and the processor replaces `<|AUDIO|>` with the
  correct number of placeholder tokens derived from the audio frame
  count. Our pinned `max_length_seconds` keeps that count stable.

The collator is deliberately not a `@dataclass` because TRL's trainer
callbacks sometimes introspect collator attributes; keeping plain
`__init__` state avoids surprises.
"""
| 41 |
|
| 42 |
from __future__ import annotations |
| 43 |
|
| 44 |
import logging |
| 45 |
from pathlib import Path |
| 46 |
from typing import Any |
| 47 |
|
| 48 |
import numpy as np |
| 49 |
|
| 50 |
from dlm.data.audio_cache import WaveformCache, WaveformCacheKey |
| 51 |
|
| 52 |
# Module-level logger, named after the module per stdlib convention.
_LOG = logging.getLogger(__name__)

_IGNORE_INDEX = -100  # HF convention — CrossEntropyLoss skips these positions
| 55 |
|
| 56 |
|
| 57 |
class AudioLmCollator:
    """Collator for path-based audio rows → HF model-ready batch dict.

    Parameters
    ----------
    processor:
        Loaded `AutoProcessor` (e.g. Qwen2AudioProcessor). Must expose
        `.tokenizer` with a pad token set.
    sample_rate:
        Target sample rate in Hz (from `AudioPreprocessorPlan`).
        Rows whose native rate disagrees raise by default (same gate
        as `preprocess_audio`). Pass `auto_resample=True` to resample
        on the fly via `dlm.data.audio_resample`.
    max_length_seconds:
        Per-clip duration cap in seconds (from
        `AudioPreprocessorPlan`). Longer waveforms are truncated.
    max_length:
        Optional token-length cap for the text side (post-expansion).
        `None` uses the processor's built-in limit.
    waveform_cache:
        Optional `WaveformCache` for memoizing decoded + mono-mixed
        + truncated waveforms across training epochs. `None` decodes
        fresh every batch (the pre-deferred behavior). Cache keys
        carry `auto_resample` so native-rate and resampled entries
        don't collide.
    auto_resample:
        Opt-in flag flipped by `training.audio.auto_resample=True`.
        When True, SR-mismatched files resample to `sample_rate`
        instead of raising. Requires soxr or scipy; absence surfaces
        as `AudioResampleUnavailable` at first mismatched decode.

    Raises
    ------
    ValueError
        At construction when the processor has no `.tokenizer` or the
        tokenizer has no `pad_token_id`.
    """

    def __init__(
        self,
        *,
        processor: Any,
        sample_rate: int,
        max_length_seconds: float,
        max_length: int | None = None,
        waveform_cache: WaveformCache | None = None,
        auto_resample: bool = False,
    ) -> None:
        self._processor = processor
        self._sample_rate = sample_rate
        self._max_length_seconds = max_length_seconds
        self._max_length = max_length
        self._waveform_cache = waveform_cache
        self._auto_resample = auto_resample
        # Round to ms once here so every WaveformCacheKey derived from this
        # collator agrees exactly on the duration component.
        self._max_length_ms = int(round(max_length_seconds * 1000))

        # Fail fast at construction — a missing pad token would otherwise
        # surface mid-training as a confusing processor/label error.
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is None:
            raise ValueError(
                "AudioLmCollator: processor has no `.tokenizer` attribute — "
                "cannot resolve pad token id"
            )
        self._pad_token_id = tokenizer.pad_token_id
        if self._pad_token_id is None:
            raise ValueError(
                "AudioLmCollator: tokenizer has no pad_token_id — "
                "prepare_tokenizer must run before the collator is built"
            )

    def __call__(self, rows: list[dict[str, Any]]) -> dict[str, Any]:
        """Turn a list of dataset rows into a model-ready batch dict.

        Each row must carry `audio_path` + `text` (and optionally
        `audio_blob_sha` for waveform-cache keying). Returns the
        processor's batch dict with a `labels` tensor added.

        Raises
        ------
        ValueError
            On an empty batch or a row missing a required key.
        """
        if not rows:
            raise ValueError("AudioLmCollator: received an empty batch")
        texts: list[str] = []
        waveforms: list[np.ndarray] = []
        for row in rows:
            if "audio_path" not in row or "text" not in row:
                raise ValueError(
                    "AudioLmCollator: row is missing required keys "
                    f"({set(row.keys())}); expected audio_path + text"
                )
            blob_sha = row.get("audio_blob_sha")
            waveforms.append(self._load_waveform(Path(row["audio_path"]), blob_sha=blob_sha))
            texts.append(row["text"])

        # One processor call over the whole batch: it handles padding
        # across both the token side (pad to max_length) and the audio-
        # feature side (pad to longest spectrogram). `return_tensors="pt"`
        # gives us torch tensors matching the model's expected dtypes.
        #
        # Two fixes vs. the earlier revision: the gate is `is not None`
        # (truthiness dropped an explicit `max_length=0`), and we pass
        # `truncation=True` alongside `max_length` — HF tokenizers ignore
        # `max_length` under `padding=True` unless truncation is enabled,
        # so the cap previously never applied.
        length_kwargs: dict[str, Any] = {}
        if self._max_length is not None:
            length_kwargs = {"max_length": self._max_length, "truncation": True}
        batch = self._processor(
            text=texts,
            audios=waveforms,
            sampling_rate=self._sample_rate,
            return_tensors="pt",
            padding=True,
            **length_kwargs,
        )

        import torch as _torch  # lazy: keep module importable without torch

        input_ids: _torch.Tensor = batch["input_ids"]
        labels = input_ids.clone()
        # Mask padding out of the loss. Prefer the attention mask when the
        # processor returns one: comparing against pad_token_id also masks
        # *real* occurrences of that id — notably the final EOS when the
        # tokenizer reuses EOS as pad — which would stop the model from
        # ever learning to terminate. Fall back to the id comparison only
        # when no mask is available.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = _IGNORE_INDEX
        else:
            labels[labels == self._pad_token_id] = _IGNORE_INDEX
        batch["labels"] = labels
        return dict(batch)

    def _load_waveform(self, path: Path, *, blob_sha: str | None = None) -> np.ndarray:
        """Decode one audio blob into a mono float32 waveform.

        When `waveform_cache` is configured and `blob_sha` is provided,
        hits the on-disk cache keyed on
        `(blob_sha, sample_rate, max_length_ms, auto_resample)`. Cache
        miss → decode via `soundfile`, mono-mix, (optionally) resample,
        truncate, populate cache.

        Raises
        ------
        ValueError
            On sample-rate mismatch when `auto_resample` is off (same
            gate as `preprocess_audio`).
        """
        # Cache lookup: only when both the cache is configured and the
        # row carries a blob sha (older row shapes may not have one;
        # skip the cache rather than break them).
        cache_key: WaveformCacheKey | None = None
        if self._waveform_cache is not None and blob_sha is not None:
            cache_key = WaveformCacheKey(
                blob_sha=blob_sha,
                sample_rate=self._sample_rate,
                max_length_ms=self._max_length_ms,
                auto_resample=self._auto_resample,
            )
            hit = self._waveform_cache.get(cache_key)
            if hit is not None:
                return hit

        import soundfile as sf  # type: ignore[import-untyped]  # lazy: optional heavy dep

        data, native_sr = sf.read(str(path), dtype="float32", always_2d=False)

        # Mono before resample: see audio_preprocessor._run_processor
        # for the same rationale (mixing after resampling can smear
        # channel-specific transients the filter needs to preserve).
        if data.ndim > 1:
            data = data.mean(axis=1).astype(np.float32, copy=False)
        mono: np.ndarray = np.ascontiguousarray(data, dtype=np.float32)

        if native_sr != self._sample_rate:
            if not self._auto_resample:
                raise ValueError(
                    f"AudioLmCollator: audio {path.name!r} native sample_rate="
                    f"{native_sr} Hz != pinned {self._sample_rate} Hz. "
                    "Set `training.audio.auto_resample: true` to resample "
                    "on the fly, or re-encode with "
                    f"`ffmpeg -i <in> -ar {self._sample_rate} <out>`."
                )
            from dlm.data.audio_resample import resample

            mono = resample(mono, src_sr=native_sr, dst_sr=self._sample_rate)

        # Truncate AFTER any resample so the cap is measured at the pinned
        # rate (sample count = seconds × target Hz).
        max_samples = int(round(self._max_length_seconds * self._sample_rate))
        if mono.shape[0] > max_samples:
            mono = mono[:max_samples]

        if cache_key is not None and self._waveform_cache is not None:
            self._waveform_cache.put(cache_key, mono)

        return mono