"""Focused `dlm prompt` edge coverage for the remaining text/VL/audio branches."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any

import pytest
from typer.testing import CliRunner

from dlm.base_models import BaseModelSpec
from dlm.cli.app import app

16 def _write_doc(path: Path, *, base_model: str = "demo-1b") -> None:
17 path.write_text(
18 f"---\ndlm_id: 01HZ4X7TGZM3J1A2B3C4D5E6F7\nbase_model: {base_model}\n---\nbody\n",
19 encoding="utf-8",
20 )
21
22
23 def _joined_output(result: object) -> str:
24 text = getattr(result, "output", "") + getattr(result, "stderr", "")
25 return " ".join(text.split())
26
27
def _spec(*, key: str = "demo-1b", modality: str = "text") -> BaseModelSpec:
    """Build a validated ``BaseModelSpec`` fixture for the requested modality."""
    payload: dict[str, object] = {
        "key": key,
        "hf_id": f"org/{key}",
        "revision": "0123456789abcdef0123456789abcdef01234567",
        "architecture": "DemoForCausalLM",
        "params": 1_000_000_000,
        "target_modules": ["q_proj", "v_proj"],
        "template": "chatml",
        "gguf_arch": "demo",
        "tokenizer_pre": "demo",
        "license_spdx": "Apache-2.0",
        "license_url": None,
        "requires_acceptance": False,
        "redistributable": True,
        "size_gb_fp16": 2.0,
        "context_length": 4096,
        "recommended_seq_len": 2048,
        "modality": modality,
    }
    # Multimodal specs additionally carry a preprocessor plan; plain text does not.
    modality_extras: dict[str, dict[str, object]] = {
        "vision-language": {
            "vl_preprocessor_plan": {
                "target_size": [224, 224],
                "image_token": "<image>",
                "num_image_tokens": 256,
            },
        },
        "audio-language": {
            "audio_preprocessor_plan": {
                "sample_rate": 16000,
                "audio_token": "<audio>",
                "num_audio_tokens": 64,
                "max_length_seconds": 30.0,
            },
        },
    }
    payload.update(modality_extras.get(modality, {}))
    return BaseModelSpec.model_validate(payload)
62
63
def _patch_prompt_runtime(
    monkeypatch: pytest.MonkeyPatch,
    *,
    spec: BaseModelSpec | None = None,
    dispatch: object | None = None,
) -> None:
    """Stub the resolve/doctor/modality seams that `dlm prompt` exercises.

    *spec* overrides the resolved base-model spec (defaults to ``_spec()``),
    and *dispatch* overrides the modality dispatch object.
    """

    def fake_resolve(*args: object, **kwargs: object) -> BaseModelSpec:
        # Evaluated per call so the default `_spec()` is built lazily,
        # matching the original `spec or _spec()` short-circuit.
        return spec or _spec()

    def fake_doctor() -> SimpleNamespace:
        return SimpleNamespace(capabilities=object())

    def fake_modality_for(model_spec: BaseModelSpec) -> object:
        return dispatch or SimpleNamespace(
            accepts_images=model_spec.modality == "vision-language",
            accepts_audio=model_spec.modality == "audio-language",
        )

    monkeypatch.setattr("dlm.base_models.resolve", fake_resolve)
    monkeypatch.setattr("dlm.hardware.doctor", fake_doctor)
    monkeypatch.setattr("dlm.modality.modality_for", fake_modality_for)
88
89
class TestPromptEdgeBranches:
    """Edge-case branches of `dlm prompt`: bad flags, gated base models,
    backend selection failures, and the vision/audio dispatch helpers."""

    def test_invalid_backend_value_exits_2(self, tmp_path: Path) -> None:
        """An unrecognized ``--backend`` value is a usage error (exit code 2)."""
        doc = tmp_path / "doc.dlm"
        _write_doc(doc)
        runner = CliRunner()

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--backend", "bogus"],
        )

        assert result.exit_code == 2, result.output
        assert "--backend must be" in _joined_output(result)

    def test_gated_base_without_recorded_acceptance_exits_1(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Resolving a gated base without recorded license acceptance exits 1."""
        from dlm.base_models.errors import GatedModelError

        doc = tmp_path / "doc.dlm"
        _write_doc(doc, base_model="gated-base")
        runner = CliRunner()

        # `(_ for _ in ()).throw(exc)` raises `exc` from inside a lambda,
        # which cannot contain a `raise` statement directly.
        monkeypatch.setattr(
            "dlm.base_models.resolve",
            lambda *args, **kwargs: (_ for _ in ()).throw(
                GatedModelError("org/gated-base", "https://license.example")
            ),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello"],
        )

        assert result.exit_code == 1, result.output
        assert "run `dlm train --i-accept-license` first" in _joined_output(result)

    def test_unsupported_backend_error_exits_2(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An UnsupportedBackendError from backend selection maps to exit 2."""
        from dlm.inference.backends.select import UnsupportedBackendError

        doc = tmp_path / "doc.dlm"
        _write_doc(doc)
        runner = CliRunner()

        _patch_prompt_runtime(monkeypatch)
        # Same raise-from-lambda idiom as above.
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend",
            lambda *args, **kwargs: (_ for _ in ()).throw(
                UnsupportedBackendError("mlx backend unavailable")
            ),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--backend", "mlx"],
        )

        assert result.exit_code == 2, result.output
        assert "mlx backend unavailable" in _joined_output(result)

    def test_verbose_text_prompt_logs_backend_and_generates(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The happy text path with --verbose logs the backend and generates."""
        doc = tmp_path / "doc.dlm"
        _write_doc(doc)
        runner = CliRunner()
        # Records what the CLI hands to the fake backend for later assertions.
        captured: dict[str, Any] = {}

        class _FakeBackend:
            # Stand-in for a real inference backend; only records its inputs.
            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                captured["adapter_name"] = adapter_name

            def generate(self, query: str, **kwargs: object) -> str:
                captured["query"] = query
                captured["kwargs"] = kwargs
                return "ok"

        _patch_prompt_runtime(monkeypatch)
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend",
            lambda *args, **kwargs: "pytorch",
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend",
            lambda *args, **kwargs: _FakeBackend(),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--verbose"],
        )

        assert result.exit_code == 0, result.output
        assert captured["query"] == "hello"
        assert "backend: pytorch" in _joined_output(result)
        kwargs = captured["kwargs"]
        assert isinstance(kwargs, dict)
        # No --top-p flag was passed, so the CLI should forward top_p=None.
        assert kwargs["top_p"] is None

    def test_missing_adapter_maps_to_exit_1(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """AdapterNotFoundError raised during backend load maps to exit 1."""
        from dlm.inference import AdapterNotFoundError

        doc = tmp_path / "doc.dlm"
        _write_doc(doc)
        runner = CliRunner()

        class _MissingAdapterBackend:
            # Simulates a backend whose adapter weights are absent on disk.
            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                raise AdapterNotFoundError("missing adapter")

        _patch_prompt_runtime(monkeypatch)
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend",
            lambda *args, **kwargs: "pytorch",
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend",
            lambda *args, **kwargs: _MissingAdapterBackend(),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello"],
        )

        assert result.exit_code == 1, result.output
        assert "missing adapter" in _joined_output(result)

    def test_vision_language_dispatch_branch_invokes_helper(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A vision-language spec plus --image routes through _dispatch_vl_prompt."""
        doc = tmp_path / "doc.dlm"
        _write_doc(doc, base_model="vl-demo")
        image = tmp_path / "frame.png"
        # Content is never decoded by the stubbed helper; only the path matters.
        image.write_bytes(b"\x89PNG fake")
        runner = CliRunner()
        captured: dict[str, Any] = {}

        _patch_prompt_runtime(
            monkeypatch,
            spec=_spec(key="vl-demo", modality="vision-language"),
        )
        # Capture the helper's kwargs instead of running real VL inference.
        monkeypatch.setattr(
            "dlm.cli.commands._dispatch_vl_prompt",
            lambda **kwargs: captured.update(kwargs),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--image", str(image)],
        )

        assert result.exit_code == 0, result.output
        assert captured["query"] == "hello"
        assert captured["image_paths"] == [image]
        assert captured["spec"].key == "vl-demo"

    def test_audio_dispatch_branch_invokes_helper(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An audio-language spec plus --audio routes through _dispatch_audio_prompt."""
        doc = tmp_path / "doc.dlm"
        _write_doc(doc, base_model="audio-demo")
        audio = tmp_path / "clip.wav"
        # Content is never decoded by the stubbed helper; only the path matters.
        audio.write_bytes(b"fake wav bytes")
        runner = CliRunner()
        captured: dict[str, Any] = {}

        _patch_prompt_runtime(
            monkeypatch,
            spec=_spec(key="audio-demo", modality="audio-language"),
        )
        # Capture the helper's kwargs instead of running real audio inference.
        monkeypatch.setattr(
            "dlm.cli.commands._dispatch_audio_prompt",
            lambda **kwargs: captured.update(kwargs),
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc), "hello", "--audio", str(audio)],
        )

        assert result.exit_code == 0, result.output
        assert captured["query"] == "hello"
        assert captured["audio_paths"] == [audio]
        assert captured["spec"].key == "audio-demo"