Python · 4744 bytes Raw Blame History
1 """Audio inference helpers — prompt shaping, waveform loading, generation."""
2
3 from __future__ import annotations
4
5 import sys
6 from pathlib import Path
7 from types import ModuleType
8
9 import numpy as np
10 import pytest
11 import torch
12
13 from dlm.inference.audio_generate import format_audio_prompt, generate_audio, load_audios
14
15
class TestFormatAudioPrompt:
    """Prompt shaping: where ``format_audio_prompt`` places audio placeholder tokens."""

    def test_respects_user_placed_audio_token(self) -> None:
        """A prompt that already contains the audio token must come back unchanged."""
        user_prompt = "Please compare <audio> and explain."
        shaped = format_audio_prompt(user_prompt, audio_token="<audio>", num_audios=2)
        assert shaped == user_prompt

    def test_prepends_one_token_per_audio(self) -> None:
        """With no token in the prompt, one token per audio precedes it on its own line."""
        shaped = format_audio_prompt("describe", audio_token="<audio>", num_audios=2)
        assert shaped == "<audio><audio>\ndescribe"

    def test_empty_prompt_emits_tokens_only(self) -> None:
        """An empty prompt yields just the tokens, with no trailing newline."""
        shaped = format_audio_prompt("", audio_token="<audio>", num_audios=3)
        assert shaped == "<audio><audio><audio>"
31
32
class TestLoadAudios:
    """Waveform loading: missing files, stereo downmix, and sample-rate handling."""

    @staticmethod
    def _stub_soundfile(
        monkeypatch: pytest.MonkeyPatch,
        samples: list,
        sample_rate: int,
    ) -> None:
        """Register a fake ``soundfile`` module whose ``read`` returns fixed data.

        ``read`` keeps the ``(_path, dtype, always_2d)`` signature and builds a
        fresh float32 array from ``samples`` on every call.
        """

        def _read(_path: object, dtype: object, always_2d: object) -> tuple[np.ndarray, int]:
            return np.array(samples, dtype=np.float32), sample_rate

        stub = ModuleType("soundfile")
        stub.read = _read
        monkeypatch.setitem(sys.modules, "soundfile", stub)

    def test_missing_file_raises(self, tmp_path: Path) -> None:
        """A path that does not exist is reported as ``FileNotFoundError``."""
        absent = tmp_path / "missing.wav"
        with pytest.raises(FileNotFoundError, match="audio not found"):
            load_audios([absent], target_sample_rate=16_000)

    def test_downmixes_stereo_to_mono(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Two-column input is averaged per row into a single float32 channel."""
        wav = tmp_path / "stereo.wav"
        wav.write_bytes(b"stub")
        self._stub_soundfile(monkeypatch, [[1.0, 3.0], [5.0, 7.0]], 16_000)

        (mono,) = load_audios([wav], target_sample_rate=16_000)

        assert mono.dtype == np.float32
        assert mono.tolist() == pytest.approx([2.0, 6.0])

    def test_sample_rate_mismatch_refused_without_auto_resample(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A 22.05 kHz file against a 16 kHz pin is rejected when resampling is off."""
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")
        self._stub_soundfile(monkeypatch, [1.0], 22_050)

        with pytest.raises(ValueError, match="does not match pinned 16000 Hz"):
            load_audios([wav], target_sample_rate=16_000, auto_resample=False)

    def test_sample_rate_mismatch_resamples_when_enabled(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With resampling on, the loader returns whatever the resampler produced."""
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")
        self._stub_soundfile(monkeypatch, [1.0, 2.0], 22_050)

        def _fake_resample(mono: object, src_sr: object, dst_sr: object) -> np.ndarray:
            # Values distinct from the stubbed read, so the output proves this ran.
            return np.array([9.0, 8.0], dtype=np.float32)

        monkeypatch.setattr("dlm.data.audio_resample.resample", _fake_resample)

        (resampled,) = load_audios([wav], target_sample_rate=16_000, auto_resample=True)

        assert resampled.tolist() == pytest.approx([9.0, 8.0])
87
88
89 class _Inputs(dict[str, torch.Tensor]):
90 def to(self, device: object) -> _Inputs:
91 return self
92
93
class TestGenerateAudio:
    """Generation driven through stub model, processor and tokenizer objects."""

    def test_generate_audio_decodes_response_only_tokens(self) -> None:
        """Only the ids that follow the three prompt ids may reach ``decode``."""

        class _StubModel:
            device = torch.device("cpu")

            def generate(self, **kwargs: object) -> torch.Tensor:
                # 99 is the pad id declared on the stub tokenizer below.
                assert kwargs["pad_token_id"] == 99
                # Prompt ids [1, 2, 3] followed by two further ids.
                return torch.tensor([[1, 2, 3, 4, 5]])

        class _StubTokenizer:
            pad_token_id = 99

            def decode(self, tokens: torch.Tensor, skip_special_tokens: bool = True) -> str:
                # The prompt ids must already have been stripped.
                assert tokens.tolist() == [4, 5]
                return "transcript"

        class _StubProcessor:
            def __init__(self) -> None:
                self.tokenizer = _StubTokenizer()

            def __call__(
                self,
                *,
                audios: list[np.ndarray],
                text: str,
                sampling_rate: int,
                return_tensors: str,
            ) -> _Inputs:
                assert len(audios) == 1
                assert text == "<audio>\nwhat happened?"
                assert sampling_rate == 16_000
                prompt_ids = torch.tensor([[1, 2, 3]])
                return _Inputs({"input_ids": prompt_ids})

        model = _StubModel()
        processor = _StubProcessor()
        clip = np.array([1.0], dtype=np.float32)

        transcript = generate_audio(
            model,
            processor,
            "what happened?",
            [clip],
            audio_token="<audio>",
            sample_rate=16_000,
            max_new_tokens=2,
            temperature=0.0,
        )

        assert transcript == "transcript"