| 1 |
"""Audio inference helpers — prompt shaping, waveform loading, generation.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
import sys |
| 6 |
from pathlib import Path |
| 7 |
from types import ModuleType |
| 8 |
|
| 9 |
import numpy as np |
| 10 |
import pytest |
| 11 |
import torch |
| 12 |
|
| 13 |
from dlm.inference.audio_generate import format_audio_prompt, generate_audio, load_audios |
| 14 |
|
| 15 |
|
| 16 |
class TestFormatAudioPrompt:
    """Prompt shaping: audio placeholder tokens are injected only when absent."""

    def test_respects_user_placed_audio_token(self) -> None:
        # A prompt that already carries the token is passed through untouched.
        original = "Please compare <audio> and explain."
        shaped = format_audio_prompt(original, audio_token="<audio>", num_audios=2)
        assert shaped == original

    def test_prepends_one_token_per_audio(self) -> None:
        # Token-free prompts gain one leading placeholder per audio clip,
        # separated from the text by a newline.
        shaped = format_audio_prompt("describe", audio_token="<audio>", num_audios=2)
        assert shaped == "<audio><audio>\ndescribe"

    def test_empty_prompt_emits_tokens_only(self) -> None:
        # With no text at all, the result is just the run of placeholders.
        shaped = format_audio_prompt("", audio_token="<audio>", num_audios=3)
        assert shaped == "<audio><audio><audio>"
| 31 |
|
| 32 |
|
| 33 |
class TestLoadAudios:
    """Waveform loading: existence checks, mono downmix, sample-rate policy."""

    def test_missing_file_raises(self, tmp_path: Path) -> None:
        # A nonexistent path must fail fast, before any decoder is touched.
        missing = tmp_path / "missing.wav"
        with pytest.raises(FileNotFoundError, match="audio not found"):
            load_audios([missing], target_sample_rate=16_000)

    def test_downmixes_stereo_to_mono(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "stereo.wav"
        wav.write_bytes(b"stub")

        # Replace the soundfile module so no real audio decoding happens.
        stub_sf = ModuleType("soundfile")

        def _read(_path: object, dtype: object, always_2d: object) -> tuple[np.ndarray, int]:
            frames = np.array([[1.0, 3.0], [5.0, 7.0]], dtype=np.float32)
            return frames, 16_000

        stub_sf.read = _read
        monkeypatch.setitem(sys.modules, "soundfile", stub_sf)

        [mono] = load_audios([wav], target_sample_rate=16_000)
        assert mono.dtype == np.float32
        # Per-frame channel means: (1+3)/2 and (5+7)/2.
        assert mono.tolist() == pytest.approx([2.0, 6.0])

    def test_sample_rate_mismatch_refused_without_auto_resample(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")

        stub_sf = ModuleType("soundfile")
        stub_sf.read = lambda _path, dtype, always_2d: (np.array([1.0], dtype=np.float32), 22_050)
        monkeypatch.setitem(sys.modules, "soundfile", stub_sf)

        # 22.05 kHz input must be rejected when resampling is disabled.
        with pytest.raises(ValueError, match="does not match pinned 16000 Hz"):
            load_audios([wav], target_sample_rate=16_000, auto_resample=False)

    def test_sample_rate_mismatch_resamples_when_enabled(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")

        stub_sf = ModuleType("soundfile")
        stub_sf.read = lambda _path, dtype, always_2d: (
            np.array([1.0, 2.0], dtype=np.float32),
            22_050,
        )
        monkeypatch.setitem(sys.modules, "soundfile", stub_sf)
        # Stub the resampler: the test only verifies that it is routed through,
        # not the quality of the resampling itself.
        monkeypatch.setattr(
            "dlm.data.audio_resample.resample",
            lambda mono, src_sr, dst_sr: np.array([9.0, 8.0], dtype=np.float32),
        )

        [resampled] = load_audios([wav], target_sample_rate=16_000, auto_resample=True)
        assert resampled.tolist() == pytest.approx([9.0, 8.0])
| 87 |
|
| 88 |
|
| 89 |
class _Inputs(dict[str, torch.Tensor]): |
| 90 |
def to(self, device: object) -> _Inputs: |
| 91 |
return self |
| 92 |
|
| 93 |
|
| 94 |
class TestGenerateAudio:
    """End-to-end stub run: only tokens beyond the prompt are decoded."""

    def test_generate_audio_decodes_response_only_tokens(self) -> None:
        class _StubTokenizer:
            pad_token_id = 99

            def decode(self, tokens: torch.Tensor, skip_special_tokens: bool = True) -> str:
                # The three prompt ids must have been stripped already.
                assert tokens.tolist() == [4, 5]
                return "transcript"

        class _StubProcessor:
            def __init__(self) -> None:
                self.tokenizer = _StubTokenizer()

            def __call__(
                self,
                *,
                audios: list[np.ndarray],
                text: str,
                sampling_rate: int,
                return_tensors: str,
            ) -> _Inputs:
                # The prompt must arrive already shaped with the audio token.
                assert len(audios) == 1
                assert text == "<audio>\nwhat happened?"
                assert sampling_rate == 16_000
                return _Inputs({"input_ids": torch.tensor([[1, 2, 3]])})

        class _StubModel:
            device = torch.device("cpu")

            def generate(self, **kwargs: object) -> torch.Tensor:
                # The tokenizer's padding id must be forwarded to generate().
                assert kwargs["pad_token_id"] == 99
                # Echo the 3 prompt ids and append two generated tokens.
                return torch.tensor([[1, 2, 3, 4, 5]])

        result = generate_audio(
            _StubModel(),
            _StubProcessor(),
            "what happened?",
            [np.array([1.0], dtype=np.float32)],
            audio_token="<audio>",
            sample_rate=16_000,
            max_new_tokens=2,
            temperature=0.0,
        )
        assert result == "transcript"