Cover inference edge branches
Authored by
mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 1c7561eba8ff40bbde71f0be01f69cabf84f8d6a
Parents: 62c3366
Tree: ad20a71

| Status | File | + | - |
|---|---|---|---|
| A |
tests/unit/inference/test_audio_generate.py
|
137 | 0 |
| M |
tests/unit/inference/test_backend_select.py
|
35 | 0 |
| M |
tests/unit/inference/test_gate.py
|
36 | 0 |
| M |
tests/unit/inference/test_mlx_adapter_conversion.py
|
14 | 0 |
| A |
tests/unit/inference/test_mlx_backend.py
|
125 | 0 |
tests/unit/inference/test_audio_generate.py (added) @@ -0,0 +1,137 @@
| 1 | +"""Audio inference helpers — prompt shaping, waveform loading, generation.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | +import sys | |
| 6 | +from pathlib import Path | |
| 7 | +from types import ModuleType | |
| 8 | + | |
| 9 | +import numpy as np | |
| 10 | +import pytest | |
| 11 | +import torch | |
| 12 | + | |
| 13 | +from dlm.inference.audio_generate import format_audio_prompt, generate_audio, load_audios | |
| 14 | + | |
| 15 | + | |
class TestFormatAudioPrompt:
    """Placement rules for the audio token in formatted prompts."""

    def test_respects_user_placed_audio_token(self) -> None:
        # A prompt that already contains the token must pass through untouched.
        text = "Please compare <audio> and explain."
        result = format_audio_prompt(text, audio_token="<audio>", num_audios=2)
        assert result == text

    def test_prepends_one_token_per_audio(self) -> None:
        # Without a user-placed token, one token per audio is prepended
        # followed by a newline and the original prompt.
        result = format_audio_prompt("describe", audio_token="<audio>", num_audios=2)
        assert result == "<audio><audio>\ndescribe"

    def test_empty_prompt_emits_tokens_only(self) -> None:
        # An empty prompt yields just the run of tokens, no trailing newline.
        result = format_audio_prompt("", audio_token="<audio>", num_audios=3)
        assert result == "<audio><audio><audio>"
| 31 | + | |
| 32 | + | |
class TestLoadAudios:
    """Filesystem, channel-layout, and sample-rate handling in load_audios."""

    @staticmethod
    def _fake_soundfile(read_result: tuple[np.ndarray, int]) -> ModuleType:
        # Minimal stand-in for the soundfile module: read() ignores its
        # arguments and returns the canned (data, sample_rate) pair.
        module = ModuleType("soundfile")
        module.read = lambda _path, dtype, always_2d: read_result
        return module

    def test_missing_file_raises(self, tmp_path: Path) -> None:
        missing = tmp_path / "missing.wav"
        with pytest.raises(FileNotFoundError, match="audio not found"):
            load_audios([missing], target_sample_rate=16_000)

    def test_downmixes_stereo_to_mono(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "stereo.wav"
        wav.write_bytes(b"stub")
        stereo = np.array([[1.0, 3.0], [5.0, 7.0]], dtype=np.float32)
        monkeypatch.setitem(
            sys.modules, "soundfile", self._fake_soundfile((stereo, 16_000))
        )

        [mono] = load_audios([wav], target_sample_rate=16_000)
        # Per-frame channel mean: (1+3)/2 and (5+7)/2.
        assert mono.dtype == np.float32
        assert mono.tolist() == pytest.approx([2.0, 6.0])

    def test_sample_rate_mismatch_refused_without_auto_resample(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")
        native = np.array([1.0], dtype=np.float32)
        monkeypatch.setitem(
            sys.modules, "soundfile", self._fake_soundfile((native, 22_050))
        )

        with pytest.raises(ValueError, match="does not match pinned 16000 Hz"):
            load_audios([wav], target_sample_rate=16_000, auto_resample=False)

    def test_sample_rate_mismatch_resamples_when_enabled(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        wav = tmp_path / "native.wav"
        wav.write_bytes(b"stub")
        native = np.array([1.0, 2.0], dtype=np.float32)
        monkeypatch.setitem(
            sys.modules, "soundfile", self._fake_soundfile((native, 22_050))
        )
        # Stub the resampler so the test only checks the dispatch path.
        monkeypatch.setattr(
            "dlm.data.audio_resample.resample",
            lambda mono, src_sr, dst_sr: np.array([9.0, 8.0], dtype=np.float32),
        )

        [resampled] = load_audios([wav], target_sample_rate=16_000, auto_resample=True)
        assert resampled.tolist() == pytest.approx([9.0, 8.0])
| 87 | + | |
| 88 | + | |
| 89 | +class _Inputs(dict[str, torch.Tensor]): | |
| 90 | + def to(self, device: object) -> _Inputs: | |
| 91 | + return self | |
| 92 | + | |
| 93 | + | |
class TestGenerateAudio:
    """Stubbed end-to-end pass through generate_audio."""

    def test_generate_audio_decodes_response_only_tokens(self) -> None:
        class _Tokenizer:
            # Forwarded as pad_token_id into model.generate below.
            pad_token_id = 99

            def decode(self, tokens: torch.Tensor, skip_special_tokens: bool = True) -> str:
                # Only tokens generated past the 3-token prompt should remain.
                assert tokens.tolist() == [4, 5]
                return "transcript"

        class _Processor:
            def __init__(self) -> None:
                self.tokenizer = _Tokenizer()

            def __call__(
                self,
                *,
                audios: list[np.ndarray],
                text: str,
                sampling_rate: int,
                return_tensors: str,
            ) -> _Inputs:
                # generate_audio should have injected the audio token
                # and forwarded the pinned sampling rate.
                assert len(audios) == 1
                assert text == "<audio>\nwhat happened?"
                assert sampling_rate == 16_000
                return _Inputs({"input_ids": torch.tensor([[1, 2, 3]])})

        class _Model:
            device = torch.device("cpu")

            def generate(self, **kwargs: object) -> torch.Tensor:
                assert kwargs["pad_token_id"] == 99
                # Echo the prompt ids plus two "new" tokens.
                return torch.tensor([[1, 2, 3, 4, 5]])

        result = generate_audio(
            _Model(),
            _Processor(),
            "what happened?",
            [np.array([1.0], dtype=np.float32)],
            audio_token="<audio>",
            sample_rate=16_000,
            max_new_tokens=2,
            temperature=0.0,
        )
        assert result == "transcript"
tests/unit/inference/test_backend_select.py (modified) @@ -9,6 +9,7 @@ import pytest
| 9 | 9 | from dlm.inference.backends.select import ( |
| 10 | 10 | UnsupportedBackendError, |
| 11 | 11 | build_backend, |
| 12 | + is_apple_silicon, | |
| 12 | 13 | select_backend, |
| 13 | 14 | ) |
| 14 | 15 | |
@@ -77,6 +78,12 @@ class TestBuildBackend: | ||
| 77 | 78 | backend = build_backend("pytorch", MagicMock()) |
| 78 | 79 | assert isinstance(backend, PyTorchBackend) |
| 79 | 80 | |
| 81 | + def test_mlx_returns_mlx_backend(self) -> None: | |
| 82 | + from dlm.inference.backends.mlx_backend import MlxBackend | |
| 83 | + | |
| 84 | + backend = build_backend("mlx", MagicMock()) | |
| 85 | + assert isinstance(backend, MlxBackend) | |
| 86 | + | |
| 80 | 87 | def test_unknown_backend_raises(self) -> None: |
| 81 | 88 | with pytest.raises(ValueError, match="unknown backend"): |
| 82 | 89 | build_backend("haskell", MagicMock()) # type: ignore[arg-type] |
@@ -94,3 +101,31 @@ class TestMlxAvailableDoesNotImportMlx: | ||
| 94 | 101 | ): |
| 95 | 102 | assert sel.mlx_available() is False |
| 96 | 103 | m_find.assert_not_called() |
| 104 | + | |
| 105 | + def test_mlx_available_checks_both_packages_on_apple_silicon(self) -> None: | |
| 106 | + from dlm.inference.backends import select as sel | |
| 107 | + | |
| 108 | + with ( | |
| 109 | + patch.object(sel, "is_apple_silicon", return_value=True), | |
| 110 | + patch.object( | |
| 111 | + sel.importlib.util, "find_spec", side_effect=[object(), object()] | |
| 112 | + ) as m_find, | |
| 113 | + ): | |
| 114 | + assert sel.mlx_available() is True | |
| 115 | + assert m_find.call_count == 2 | |
| 116 | + | |
| 117 | + | |
class TestPlatformHelper:
    """Host detection via is_apple_silicon()."""

    def test_is_apple_silicon_true_only_for_darwin_arm64(self) -> None:
        # Both the OS and the machine architecture must match.
        with patch("dlm.inference.backends.select.sys.platform", "darwin"):
            with patch(
                "dlm.inference.backends.select.platform.machine", return_value="arm64"
            ):
                assert is_apple_silicon() is True

    def test_is_apple_silicon_false_for_other_hosts(self) -> None:
        with patch("dlm.inference.backends.select.sys.platform", "linux"):
            with patch(
                "dlm.inference.backends.select.platform.machine", return_value="x86_64"
            ):
                assert is_apple_silicon() is False
tests/unit/inference/test_gate.py (modified) @@ -88,6 +88,26 @@ class _StubBaseModel:
| 88 | 88 | forward = __call__ |
| 89 | 89 | |
| 90 | 90 | |
class _NoMaskTokenizer(_StubTokenizer):
    """Tokenizer stub whose output omits the attention_mask entry."""

    def __call__(
        self,
        prompt: str,
        *,
        return_tensors: str = "pt",
        truncation: bool = True,
        max_length: int = 512,
    ) -> dict[str, object]:
        import torch

        # Return only input_ids so callers exercise the
        # no-attention-mask pooling path.
        token_ids = torch.randint(0, 100, (1, self._seq_len))
        return {"input_ids": token_ids}
| 105 | + | |
class _NoParamBaseModel(_StubBaseModel):
    """Base-model stub that exposes no parameters at all."""

    def parameters(self):  # type: ignore[no-untyped-def]
        # An empty iterator forces device inference to fall back to CPU.
        return iter(())
| 110 | + | |
| 91 | 111 | def _train_gate_on_store( |
| 92 | 112 | tmp_path: Path, |
| 93 | 113 | *, |
@@ -158,6 +178,22 @@ class TestEmbedPrompt: | ||
| 158 | 178 | e2 = embed_prompt(prompt="hello world", tokenizer=tokenizer, base_model=model) |
| 159 | 179 | assert not torch.allclose(e1, e2) |
| 160 | 180 | |
| 181 | + def test_falls_back_to_cpu_when_model_has_no_parameters(self) -> None: | |
| 182 | + embedding = embed_prompt( | |
| 183 | + prompt="hello", | |
| 184 | + tokenizer=_StubTokenizer(), | |
| 185 | + base_model=_NoParamBaseModel(hidden_dim=8), | |
| 186 | + ) | |
| 187 | + assert embedding.shape == (8,) | |
| 188 | + | |
| 189 | + def test_mean_pools_without_attention_mask(self) -> None: | |
| 190 | + embedding = embed_prompt( | |
| 191 | + prompt="hello", | |
| 192 | + tokenizer=_NoMaskTokenizer(), | |
| 193 | + base_model=_StubBaseModel(hidden_dim=8), | |
| 194 | + ) | |
| 195 | + assert embedding.shape == (8,) | |
| 196 | + | |
| 161 | 197 | |
| 162 | 198 | class TestLoadGateHandle: |
| 163 | 199 | def test_uniform_handle_from_cold_start(self, tmp_path: Path) -> None: |
tests/unit/inference/test_mlx_adapter_conversion.py (modified) @@ -77,3 +77,17 @@ class TestMapAllKeys:
| 77 | 77 | ] |
| 78 | 78 | with pytest.raises(MlxConversionError, match="map to the same"): |
| 79 | 79 | map_all_keys(collision) |
| 80 | + | |
| 81 | + | |
class TestBuildMlxAdapterConfig:
    """Input validation in build_mlx_adapter_config."""

    def test_non_positive_layer_count_rejected(self) -> None:
        # Imported lazily, matching the file's other deferred imports.
        from dlm.inference.mlx_adapter import build_mlx_adapter_config

        peft_config = {"r": 8, "target_modules": ["q_proj"]}
        with pytest.raises(MlxConversionError, match="expected >=1"):
            build_mlx_adapter_config(peft_config, 0)
tests/unit/inference/test_mlx_backend.py (added) @@ -0,0 +1,125 @@
| 1 | +"""MLX backend helpers and lightweight backend-path coverage.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | +import json | |
| 6 | +import sys | |
| 7 | +from pathlib import Path | |
| 8 | +from types import ModuleType, SimpleNamespace | |
| 9 | + | |
| 10 | +import pytest | |
| 11 | + | |
| 12 | +from dlm.base_models import BASE_MODELS | |
| 13 | +from dlm.inference.backends.mlx_backend import MlxBackend, _resolve_base_num_hidden_layers | |
| 14 | +from dlm.inference.errors import AdapterNotFoundError | |
| 15 | +from dlm.inference.mlx_adapter import MlxConversionError | |
| 16 | + | |
| 17 | + | |
class TestResolveBaseNumHiddenLayers:
    """Resolution order: transformers AutoConfig first, then cached config.json."""

    @staticmethod
    def _write_snapshot(tmp_path: Path, payload: str) -> Path:
        # Mimic a huggingface_hub cache snapshot directory layout.
        snapshot = tmp_path / "snapshots" / ("a" * 40)
        snapshot.mkdir(parents=True)
        (snapshot / "config.json").write_text(payload, encoding="utf-8")
        return snapshot

    def test_prefers_transformers_auto_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setattr(
            "transformers.AutoConfig.from_pretrained",
            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=24),
        )
        assert _resolve_base_num_hidden_layers("org/demo") == 24

    def test_falls_back_to_cached_config_json(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        snapshot = self._write_snapshot(tmp_path, json.dumps({"num_hidden_layers": 18}))
        monkeypatch.setattr(
            "transformers.AutoConfig.from_pretrained",
            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=None),
        )
        monkeypatch.setattr(
            "huggingface_hub.snapshot_download", lambda **kwargs: str(snapshot)
        )
        assert _resolve_base_num_hidden_layers("org/demo") == 18

    def test_cache_lookup_errors_raise_conversion_error(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Both resolution paths fail -> a single MlxConversionError surfaces.
        def _auto_config_boom(hf_id, local_files_only=True):
            raise RuntimeError("boom")

        def _snapshot_missing(**kwargs):
            raise OSError("missing")

        monkeypatch.setattr("transformers.AutoConfig.from_pretrained", _auto_config_boom)
        monkeypatch.setattr("huggingface_hub.snapshot_download", _snapshot_missing)
        with pytest.raises(MlxConversionError, match="could not resolve num_hidden_layers"):
            _resolve_base_num_hidden_layers("org/demo")

    def test_missing_num_hidden_layers_raises_conversion_error(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        snapshot = self._write_snapshot(tmp_path, "{}")
        monkeypatch.setattr(
            "transformers.AutoConfig.from_pretrained",
            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=None),
        )
        monkeypatch.setattr(
            "huggingface_hub.snapshot_download", lambda **kwargs: str(snapshot)
        )
        with pytest.raises(MlxConversionError, match="has no usable num_hidden_layers"):
            _resolve_base_num_hidden_layers("org/demo")
| 68 | + | |
| 69 | + | |
class TestMlxBackend:
    """Load/generate/unload lifecycle of MlxBackend with stubbed mlx_lm."""

    def test_generate_before_load_raises(self) -> None:
        with pytest.raises(RuntimeError, match="before load"):
            MlxBackend(SimpleNamespace()).generate("hello")

    def test_load_missing_adapter_raises(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Point the loader at a path that was never created.
        monkeypatch.setattr(
            "dlm.inference.loader.resolve_adapter_path",
            lambda store, adapter_name=None: tmp_path / "missing",
        )
        backend = MlxBackend(SimpleNamespace())
        with pytest.raises(AdapterNotFoundError, match="does not exist"):
            backend.load(BASE_MODELS["smollm2-135m"], SimpleNamespace(root=tmp_path))

    def test_load_generate_and_unload_happy_path(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        peft_dir = tmp_path / "adapter"
        peft_dir.mkdir()
        converted_dir = tmp_path / "staged"

        monkeypatch.setattr(
            "dlm.inference.loader.resolve_adapter_path",
            lambda store, adapter_name=None: peft_dir,
        )
        monkeypatch.setattr(
            "dlm.inference.backends.mlx_backend.stage_mlx_adapter_dir",
            lambda peft_adapter_dir, dst_dir, *, base_hf_id: converted_dir,
        )

        # Stub out mlx_lm entirely so the test never needs Apple hardware.
        mlx_stub = ModuleType("mlx_lm")
        mlx_stub.load = lambda hf_id, adapter_path: ("model", "tokenizer")
        mlx_stub.generate = (
            lambda model, tokenizer, *, prompt, max_tokens, sampler, verbose: "mlx output"
        )
        sampler_stub = ModuleType("mlx_lm.sample_utils")
        sampler_stub.make_sampler = lambda temp, top_p, top_k: {
            "temp": temp,
            "top_p": top_p,
            "top_k": top_k,
        }
        monkeypatch.setitem(sys.modules, "mlx_lm", mlx_stub)
        monkeypatch.setitem(sys.modules, "mlx_lm.sample_utils", sampler_stub)

        backend = MlxBackend(SimpleNamespace())
        backend.load(BASE_MODELS["smollm2-135m"], SimpleNamespace(root=tmp_path))
        text = backend.generate(
            "prompt", max_new_tokens=4, temperature=0.5, top_p=0.9, top_k=12
        )
        assert text == "mlx output"

        backend.unload()
        # Unload must clear all cached state, including the staging workdir.
        assert backend._model is None
        assert backend._tokenizer is None
        assert backend._workdir is None