tenseleyflow/documentlanguagemodel / 1c7561e


Cover inference edge branches

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA: 1c7561eba8ff40bbde71f0be01f69cabf84f8d6a
Parents: 62c3366
Tree: ad20a71

5 changed files

Status  File                                                  Added  Removed
A       tests/unit/inference/test_audio_generate.py            137        0
M       tests/unit/inference/test_backend_select.py             35        0
M       tests/unit/inference/test_gate.py                        36        0
M       tests/unit/inference/test_mlx_adapter_conversion.py     14        0
A       tests/unit/inference/test_mlx_backend.py               125        0
tests/unit/inference/test_audio_generate.py (added)
@@ -0,0 +1,137 @@
+"""Audio inference helpers — prompt shaping, waveform loading, generation."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from types import ModuleType
+
+import numpy as np
+import pytest
+import torch
+
+from dlm.inference.audio_generate import format_audio_prompt, generate_audio, load_audios
+
+
+class TestFormatAudioPrompt:
+    def test_respects_user_placed_audio_token(self) -> None:
+        prompt = "Please compare <audio> and explain."
+        assert format_audio_prompt(prompt, audio_token="<audio>", num_audios=2) == prompt
+
+    def test_prepends_one_token_per_audio(self) -> None:
+        assert (
+            format_audio_prompt("describe", audio_token="<audio>", num_audios=2)
+            == "<audio><audio>\ndescribe"
+        )
+
+    def test_empty_prompt_emits_tokens_only(self) -> None:
+        assert (
+            format_audio_prompt("", audio_token="<audio>", num_audios=3) == "<audio><audio><audio>"
+        )
+
+
+class TestLoadAudios:
+    def test_missing_file_raises(self, tmp_path: Path) -> None:
+        with pytest.raises(FileNotFoundError, match="audio not found"):
+            load_audios([tmp_path / "missing.wav"], target_sample_rate=16_000)
+
+    def test_downmixes_stereo_to_mono(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        path = tmp_path / "stereo.wav"
+        path.write_bytes(b"stub")
+
+        fake_sf = ModuleType("soundfile")
+        fake_sf.read = lambda _path, dtype, always_2d: (
+            np.array([[1.0, 3.0], [5.0, 7.0]], dtype=np.float32),
+            16_000,
+        )
+        monkeypatch.setitem(sys.modules, "soundfile", fake_sf)
+
+        [waveform] = load_audios([path], target_sample_rate=16_000)
+        assert waveform.dtype == np.float32
+        assert waveform.tolist() == pytest.approx([2.0, 6.0])
+
+    def test_sample_rate_mismatch_refused_without_auto_resample(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        path = tmp_path / "native.wav"
+        path.write_bytes(b"stub")
+
+        fake_sf = ModuleType("soundfile")
+        fake_sf.read = lambda _path, dtype, always_2d: (np.array([1.0], dtype=np.float32), 22_050)
+        monkeypatch.setitem(sys.modules, "soundfile", fake_sf)
+
+        with pytest.raises(ValueError, match="does not match pinned 16000 Hz"):
+            load_audios([path], target_sample_rate=16_000, auto_resample=False)
+
+    def test_sample_rate_mismatch_resamples_when_enabled(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        path = tmp_path / "native.wav"
+        path.write_bytes(b"stub")
+
+        fake_sf = ModuleType("soundfile")
+        fake_sf.read = lambda _path, dtype, always_2d: (
+            np.array([1.0, 2.0], dtype=np.float32),
+            22_050,
+        )
+        monkeypatch.setitem(sys.modules, "soundfile", fake_sf)
+        monkeypatch.setattr(
+            "dlm.data.audio_resample.resample",
+            lambda mono, src_sr, dst_sr: np.array([9.0, 8.0], dtype=np.float32),
+        )
+
+        [waveform] = load_audios([path], target_sample_rate=16_000, auto_resample=True)
+        assert waveform.tolist() == pytest.approx([9.0, 8.0])
+
+
+class _Inputs(dict[str, torch.Tensor]):
+    def to(self, device: object) -> _Inputs:
+        return self
+
+
+class TestGenerateAudio:
+    def test_generate_audio_decodes_response_only_tokens(self) -> None:
+        class _Tokenizer:
+            pad_token_id = 99
+
+            def decode(self, tokens: torch.Tensor, skip_special_tokens: bool = True) -> str:
+                assert tokens.tolist() == [4, 5]
+                return "transcript"
+
+        class _Processor:
+            def __init__(self) -> None:
+                self.tokenizer = _Tokenizer()
+
+            def __call__(
+                self,
+                *,
+                audios: list[np.ndarray],
+                text: str,
+                sampling_rate: int,
+                return_tensors: str,
+            ) -> _Inputs:
+                assert len(audios) == 1
+                assert text == "<audio>\nwhat happened?"
+                assert sampling_rate == 16_000
+                return _Inputs({"input_ids": torch.tensor([[1, 2, 3]])})
+
+        class _Model:
+            device = torch.device("cpu")
+
+            def generate(self, **kwargs: object) -> torch.Tensor:
+                assert kwargs["pad_token_id"] == 99
+                return torch.tensor([[1, 2, 3, 4, 5]])
+
+        out = generate_audio(
+            _Model(),
+            _Processor(),
+            "what happened?",
+            [np.array([1.0], dtype=np.float32)],
+            audio_token="<audio>",
+            sample_rate=16_000,
+            max_new_tokens=2,
+            temperature=0.0,
+        )
+        assert out == "transcript"
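Review note: the prompt-shaping contract pinned by TestFormatAudioPrompt is small enough to sketch. This is a reconstruction from the assertions above, not the committed dlm.inference.audio_generate source; the keyword-only signature is an assumption:

def format_audio_prompt(prompt: str, *, audio_token: str, num_audios: int) -> str:
    # A user-placed token wins; never double-insert.
    if audio_token in prompt:
        return prompt
    tokens = audio_token * num_audios  # one token per clip, in order
    # An empty prompt yields tokens only, with no trailing newline.
    return f"{tokens}\n{prompt}" if prompt else tokens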
tests/unit/inference/test_backend_select.py (modified)
@@ -9,6 +9,7 @@ import pytest
 from dlm.inference.backends.select import (
     UnsupportedBackendError,
     build_backend,
+    is_apple_silicon,
     select_backend,
 )
 
@@ -77,6 +78,12 @@ class TestBuildBackend:
         backend = build_backend("pytorch", MagicMock())
         assert isinstance(backend, PyTorchBackend)
 
+    def test_mlx_returns_mlx_backend(self) -> None:
+        from dlm.inference.backends.mlx_backend import MlxBackend
+
+        backend = build_backend("mlx", MagicMock())
+        assert isinstance(backend, MlxBackend)
+
     def test_unknown_backend_raises(self) -> None:
         with pytest.raises(ValueError, match="unknown backend"):
             build_backend("haskell", MagicMock())  # type: ignore[arg-type]
@@ -94,3 +101,31 @@ class TestMlxAvailableDoesNotImportMlx:
         ):
             assert sel.mlx_available() is False
             m_find.assert_not_called()
+
+    def test_mlx_available_checks_both_packages_on_apple_silicon(self) -> None:
+        from dlm.inference.backends import select as sel
+
+        with (
+            patch.object(sel, "is_apple_silicon", return_value=True),
+            patch.object(
+                sel.importlib.util, "find_spec", side_effect=[object(), object()]
+            ) as m_find,
+        ):
+            assert sel.mlx_available() is True
+            assert m_find.call_count == 2
+
+
+class TestPlatformHelper:
+    def test_is_apple_silicon_true_only_for_darwin_arm64(self) -> None:
+        with (
+            patch("dlm.inference.backends.select.sys.platform", "darwin"),
+            patch("dlm.inference.backends.select.platform.machine", return_value="arm64"),
+        ):
+            assert is_apple_silicon() is True
+
+    def test_is_apple_silicon_false_for_other_hosts(self) -> None:
+        with (
+            patch("dlm.inference.backends.select.sys.platform", "linux"),
+            patch("dlm.inference.backends.select.platform.machine", return_value="x86_64"),
+        ):
+            assert is_apple_silicon() is False
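Review note: taken together, these tests pin a lazy availability probe: mlx is never imported, only find_spec lookups run, and only on Apple silicon. A minimal sketch consistent with the patch targets above; the two package names probed are a guess, since the tests only require that exactly two specs are checked:

import importlib.util
import platform
import sys

def is_apple_silicon() -> bool:
    # True only for macOS on arm64 hardware.
    return sys.platform == "darwin" and platform.machine() == "arm64"

def mlx_available() -> bool:
    # Cheap spec lookups only, and only where MLX can actually run.
    if not is_apple_silicon():
        return False
    return all(
        importlib.util.find_spec(name) is not None for name in ("mlx", "mlx_lm")
    )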
tests/unit/inference/test_gate.py (modified)
@@ -88,6 +88,26 @@ class _StubBaseModel:
     forward = __call__
 
 
+class _NoMaskTokenizer(_StubTokenizer):
+    def __call__(
+        self,
+        prompt: str,
+        *,
+        return_tensors: str = "pt",
+        truncation: bool = True,
+        max_length: int = 512,
+    ) -> dict[str, object]:
+        import torch
+
+        ids = torch.randint(0, 100, (1, self._seq_len))
+        return {"input_ids": ids}
+
+
+class _NoParamBaseModel(_StubBaseModel):
+    def parameters(self):  # type: ignore[no-untyped-def]
+        return iter(())
+
+
 def _train_gate_on_store(
     tmp_path: Path,
     *,
@@ -158,6 +178,22 @@ class TestEmbedPrompt:
         e2 = embed_prompt(prompt="hello world", tokenizer=tokenizer, base_model=model)
         assert not torch.allclose(e1, e2)
 
+    def test_falls_back_to_cpu_when_model_has_no_parameters(self) -> None:
+        embedding = embed_prompt(
+            prompt="hello",
+            tokenizer=_StubTokenizer(),
+            base_model=_NoParamBaseModel(hidden_dim=8),
+        )
+        assert embedding.shape == (8,)
+
+    def test_mean_pools_without_attention_mask(self) -> None:
+        embedding = embed_prompt(
+            prompt="hello",
+            tokenizer=_NoMaskTokenizer(),
+            base_model=_StubBaseModel(hidden_dim=8),
+        )
+        assert embedding.shape == (8,)
+
 
 class TestLoadGateHandle:
     def test_uniform_handle_from_cold_start(self, tmp_path: Path) -> None:
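Review note: the two new edge branches in embed_prompt are a device fallback (a parameterless model offers no device hint, so run on CPU) and a pooling fallback (no attention mask, so mean-pool the whole sequence). A sketch of those branches as the stubs imply them; the output attribute and shapes are assumptions, not the committed function body:

import torch

def embed_prompt(*, prompt, tokenizer, base_model) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Branch 1: no parameters -> no device hint -> fall back to CPU.
    first_param = next(iter(base_model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    hidden = base_model(input_ids=inputs["input_ids"].to(device)).last_hidden_state
    mask = inputs.get("attention_mask")
    if mask is None:
        # Branch 2: no mask -> plain mean over the sequence axis.
        return hidden.mean(dim=1).squeeze(0)
    mask = mask.to(device).unsqueeze(-1).to(hidden.dtype)
    return ((hidden * mask).sum(dim=1) / mask.sum(dim=1)).squeeze(0)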
tests/unit/inference/test_mlx_adapter_conversion.py (modified)
@@ -77,3 +77,17 @@ class TestMapAllKeys:
         ]
         with pytest.raises(MlxConversionError, match="map to the same"):
             map_all_keys(collision)
+
+
+class TestBuildMlxAdapterConfig:
+    def test_non_positive_layer_count_rejected(self) -> None:
+        from dlm.inference.mlx_adapter import build_mlx_adapter_config
+
+        with pytest.raises(MlxConversionError, match="expected >=1"):
+            build_mlx_adapter_config(
+                {
+                    "r": 8,
+                    "target_modules": ["q_proj"],
+                },
+                0,
+            )
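Review note: this test only pins the guard clause, not the happy path. A sketch of the validation it exercises; the shape of the returned payload is a guess at an mlx_lm-style adapter config and is not confirmed by this commit:

from dlm.inference.mlx_adapter import MlxConversionError

def build_mlx_adapter_config(peft_config: dict, num_layers: int) -> dict:
    # Reject layer counts that cannot describe a real model.
    if num_layers < 1:
        raise MlxConversionError(f"got num_layers={num_layers}, expected >=1")
    return {
        "num_layers": num_layers,
        "lora_parameters": {
            "rank": peft_config["r"],
            "keys": peft_config["target_modules"],
        },
    }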
tests/unit/inference/test_mlx_backend.py (added)
@@ -0,0 +1,125 @@
+"""MLX backend helpers and lightweight backend-path coverage."""
+
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+from types import ModuleType, SimpleNamespace
+
+import pytest
+
+from dlm.base_models import BASE_MODELS
+from dlm.inference.backends.mlx_backend import MlxBackend, _resolve_base_num_hidden_layers
+from dlm.inference.errors import AdapterNotFoundError
+from dlm.inference.mlx_adapter import MlxConversionError
+
+
+class TestResolveBaseNumHiddenLayers:
+    def test_prefers_transformers_auto_config(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setattr(
+            "transformers.AutoConfig.from_pretrained",
+            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=24),
+        )
+        assert _resolve_base_num_hidden_layers("org/demo") == 24
+
+    def test_falls_back_to_cached_config_json(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        snapshot = tmp_path / "snapshots" / ("a" * 40)
+        snapshot.mkdir(parents=True)
+        (snapshot / "config.json").write_text(
+            json.dumps({"num_hidden_layers": 18}), encoding="utf-8"
+        )
+        monkeypatch.setattr(
+            "transformers.AutoConfig.from_pretrained",
+            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=None),
+        )
+        monkeypatch.setattr("huggingface_hub.snapshot_download", lambda **kwargs: str(snapshot))
+        assert _resolve_base_num_hidden_layers("org/demo") == 18
+
+    def test_cache_lookup_errors_raise_conversion_error(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setattr(
+            "transformers.AutoConfig.from_pretrained",
+            lambda hf_id, local_files_only=True: (_ for _ in ()).throw(RuntimeError("boom")),
+        )
+        monkeypatch.setattr(
+            "huggingface_hub.snapshot_download",
+            lambda **kwargs: (_ for _ in ()).throw(OSError("missing")),
+        )
+        with pytest.raises(MlxConversionError, match="could not resolve num_hidden_layers"):
+            _resolve_base_num_hidden_layers("org/demo")
+
+    def test_missing_num_hidden_layers_raises_conversion_error(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        snapshot = tmp_path / "snapshots" / ("a" * 40)
+        snapshot.mkdir(parents=True)
+        (snapshot / "config.json").write_text("{}", encoding="utf-8")
+        monkeypatch.setattr(
+            "transformers.AutoConfig.from_pretrained",
+            lambda hf_id, local_files_only=True: SimpleNamespace(num_hidden_layers=None),
+        )
+        monkeypatch.setattr("huggingface_hub.snapshot_download", lambda **kwargs: str(snapshot))
+        with pytest.raises(MlxConversionError, match="has no usable num_hidden_layers"):
+            _resolve_base_num_hidden_layers("org/demo")
+
+
+class TestMlxBackend:
+    def test_generate_before_load_raises(self) -> None:
+        backend = MlxBackend(SimpleNamespace())
+        with pytest.raises(RuntimeError, match="before load"):
+            backend.generate("hello")
+
+    def test_load_missing_adapter_raises(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        backend = MlxBackend(SimpleNamespace())
+        monkeypatch.setattr(
+            "dlm.inference.loader.resolve_adapter_path",
+            lambda store, adapter_name=None: tmp_path / "missing",
+        )
+        with pytest.raises(AdapterNotFoundError, match="does not exist"):
+            backend.load(BASE_MODELS["smollm2-135m"], SimpleNamespace(root=tmp_path))
+
+    def test_load_generate_and_unload_happy_path(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        adapter_dir = tmp_path / "adapter"
+        adapter_dir.mkdir()
+        staged_dir = tmp_path / "staged"
+
+        backend = MlxBackend(SimpleNamespace())
+        monkeypatch.setattr(
+            "dlm.inference.loader.resolve_adapter_path",
+            lambda store, adapter_name=None: adapter_dir,
+        )
+        monkeypatch.setattr(
+            "dlm.inference.backends.mlx_backend.stage_mlx_adapter_dir",
+            lambda peft_adapter_dir, dst_dir, *, base_hf_id: staged_dir,
+        )
+
+        fake_mlx = ModuleType("mlx_lm")
+        fake_mlx.load = lambda hf_id, adapter_path: ("model", "tokenizer")
+        fake_mlx.generate = lambda model, tokenizer, *, prompt, max_tokens, sampler, verbose: (
+            "mlx output"
+        )
+        fake_sample_utils = ModuleType("mlx_lm.sample_utils")
+        fake_sample_utils.make_sampler = lambda temp, top_p, top_k: {
+            "temp": temp,
+            "top_p": top_p,
+            "top_k": top_k,
+        }
+        monkeypatch.setitem(sys.modules, "mlx_lm", fake_mlx)
+        monkeypatch.setitem(sys.modules, "mlx_lm.sample_utils", fake_sample_utils)
+
+        backend.load(BASE_MODELS["smollm2-135m"], SimpleNamespace(root=tmp_path))
+        assert backend.generate(
+            "prompt", max_new_tokens=4, temperature=0.5, top_p=0.9, top_k=12
+        ) == ("mlx output")
+        backend.unload()
+        assert backend._workdir is None
+        assert backend._model is None
+        assert backend._tokenizer is None
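Review note: the most intricate logic covered here is the layer-count resolver's fallback chain: transformers AutoConfig first, then the cached config.json from the local Hugging Face snapshot, with every failure funneled into MlxConversionError. A sketch consistent with the four tests; the error messages match the pytest.raises patterns above, but the real helper's structure is inferred, not copied:

import json
from pathlib import Path

from dlm.inference.mlx_adapter import MlxConversionError

def _resolve_base_num_hidden_layers(hf_id: str) -> int:
    import transformers

    # Preferred path: offline AutoConfig lookup.
    try:
        config = transformers.AutoConfig.from_pretrained(hf_id, local_files_only=True)
        if getattr(config, "num_hidden_layers", None) is not None:
            return int(config.num_hidden_layers)
    except Exception:
        pass  # fall through to the raw cache

    # Fallback: read config.json straight out of the local HF snapshot.
    try:
        import huggingface_hub

        snapshot = huggingface_hub.snapshot_download(repo_id=hf_id, local_files_only=True)
        raw = json.loads((Path(snapshot) / "config.json").read_text(encoding="utf-8"))
    except OSError as exc:
        raise MlxConversionError(
            f"could not resolve num_hidden_layers for {hf_id}"
        ) from exc

    if raw.get("num_hidden_layers") is None:
        raise MlxConversionError(f"{hf_id} has no usable num_hidden_layers")
    return int(raw["num_hidden_layers"])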