# NOTE: "Python · 20480 bytes Raw Blame History" was a GitHub blob-view header
# accidentally captured in this paste; it is not part of the source.
1 """Coverage-oriented tests for train/prompt/repl command bodies."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6 from types import SimpleNamespace
7 from typing import Any
8
9 from typer.testing import CliRunner
10
11 from dlm.base_models import BaseModelSpec
12 from dlm.cli.app import app
13 from dlm.doc.schema import TrainingConfig
14
15
def _init_doc(tmp_path: Path, *, base: str = "smollm2-135m") -> Path:
    """Create a fresh ``.dlm`` document via the CLI ``init`` command.

    Uses an isolated ``--home`` under *tmp_path* so tests never touch the
    real user home. Fails loudly (with CLI output) if init does not succeed.
    Returns the path of the newly created document.
    """
    doc_path = tmp_path / "doc.dlm"
    argv = [
        "--home",
        str(tmp_path / "home"),
        "init",
        str(doc_path),
        "--base",
        base,
    ]
    outcome = CliRunner().invoke(app, argv)
    assert outcome.exit_code == 0, outcome.output
    return doc_path
32
33
34 def _fake_doctor_result() -> object:
35 return SimpleNamespace(plan=object(), capabilities=object())
36
37
def _spec(**overrides: object) -> BaseModelSpec:
    """Build a validated :class:`BaseModelSpec` fixture.

    Starts from a complete set of plausible defaults for a fictional 1B
    model; any keyword argument replaces the matching default before
    validation.
    """
    base_fields: dict[str, object] = dict(
        key="demo-1b",
        hf_id="org/demo-1b",
        revision="0123456789abcdef0123456789abcdef01234567",
        architecture="DemoForCausalLM",
        params=1_000_000_000,
        target_modules=["q_proj", "v_proj"],
        template="chatml",
        gguf_arch="demo",
        tokenizer_pre="demo",
        license_spdx="Apache-2.0",
        license_url=None,
        requires_acceptance=False,
        redistributable=True,
        size_gb_fp16=2.0,
        context_length=4096,
        recommended_seq_len=2048,
    )
    merged = {**base_fields, **overrides}
    return BaseModelSpec.model_validate(merged)
59
60
class TestTrainCommandCoverage:
    """Exercise the `train` command body: doctor planning, phase summaries,
    watch/RPC wiring, multi-GPU dispatch helpers, and error-to-exit-code
    mappings. All heavy collaborators are monkeypatched at their dotted
    import paths so no real training runs."""

    def test_train_uses_resolved_base_for_doctor_plan(
        self, tmp_path: Path, monkeypatch: Any
    ) -> None:
        """The doctor call must use the *resolved* base spec (params and the
        effective context length), and the resulting plan/capabilities must be
        threaded through to the phase orchestrator unchanged."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()
        captured_doctor_calls: list[dict[str, object]] = []
        # Identity sentinels: we assert these exact objects flow to run_phases.
        capabilities = object()
        plan = object()

        fake_result = SimpleNamespace(
            adapter_version=1,
            steps=2,
            seed=42,
            determinism=SimpleNamespace(class_="strict"),
            adapter_path=tmp_path / "adapter",
            log_path=tmp_path / "train.jsonl",
            final_train_loss=0.25,
        )
        fake_phase = SimpleNamespace(phase="sft", result=fake_result)

        def _doctor(**kwargs: object) -> object:
            # Record the kwargs so we can assert on them after the CLI run.
            captured_doctor_calls.append(kwargs)
            return SimpleNamespace(plan=plan, capabilities=capabilities)

        # Resolved spec deliberately differs from the doc's base defaults so
        # we can tell which source the doctor call used.
        resolved_spec = _spec(
            params=4_000_000_000,
            context_length=8192,
            context_length_effective=1024,
        )

        monkeypatch.setattr("dlm.base_models.resolve", lambda *args, **kwargs: resolved_spec)
        monkeypatch.setattr("dlm.hardware.doctor", _doctor)
        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)

        def _run_phases(*args: object, **kwargs: object) -> list[object]:
            # Positional slot 3 is the plan; capabilities arrive by keyword.
            # NOTE(review): slot index assumed from the command's call site —
            # confirm against dlm.cli.commands if the signature changes.
            assert args[3] is plan
            assert kwargs["capabilities"] is capabilities
            return [fake_phase]

        monkeypatch.setattr("dlm.train.preference.phase_orchestrator.run_phases", _run_phases)

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "train", str(doc), "--max-steps", "2"],
        )
        assert result.exit_code == 0, result.output
        assert len(captured_doctor_calls) == 1
        call = captured_doctor_calls[0]
        # Doctor sizing comes from the resolved spec, not the doc defaults.
        assert call["base_params"] == 4_000_000_000
        assert call["seq_len"] == 1024
        assert call["world_size"] == 1
        training_config = call["training_config"]
        assert isinstance(training_config, TrainingConfig)
        # The training config keeps the spec's recommended sequence length.
        assert training_config.sequence_len == 2048

    def test_train_success_prints_phase_summary(self, tmp_path: Path, monkeypatch: Any) -> None:
        """A successful run prints a per-phase summary line including the
        adapter path and the final training loss."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        fake_result = SimpleNamespace(
            adapter_version=1,
            steps=3,
            seed=42,
            determinism=SimpleNamespace(class_="strict"),
            adapter_path=tmp_path / "adapter",
            log_path=tmp_path / "train.jsonl",
            final_train_loss=0.125,
        )
        fake_phase = SimpleNamespace(phase="sft", result=fake_result)

        monkeypatch.setattr("dlm.hardware.doctor", lambda **kwargs: _fake_doctor_result())
        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)
        monkeypatch.setattr(
            "dlm.train.preference.phase_orchestrator.run_phases",
            lambda *args, **kwargs: [fake_phase],
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "train", str(doc), "--max-steps", "3"],
        )
        assert result.exit_code == 0, result.output
        assert "sft:" in result.output
        assert "adapter:" in result.output
        assert "0.125" in result.output

    def test_train_watch_with_rpc_starts_server(self, tmp_path: Path, monkeypatch: Any) -> None:
        """--watch together with --listen-rpc should start the probe RPC
        server (token taken from DLM_PROBE_TOKEN) and enter the watch loop."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        fake_result = SimpleNamespace(
            adapter_version=1,
            steps=1,
            seed=7,
            determinism=SimpleNamespace(class_="strict"),
            adapter_path=tmp_path / "adapter",
            log_path=tmp_path / "train.jsonl",
            final_train_loss=None,
        )
        fake_phase = SimpleNamespace(phase="sft", result=fake_result)

        class _FakeQueue:
            # Arbitrary capacity; only needs to exist for the summary output.
            capacity = 123

            def drain(self) -> list[object]:
                return []

        class _FakeServer:
            # Keyword-only signature mirrors the real ProbeRpcServer contract.
            def __init__(self, *, host: str, port: int, token: str, queue: object) -> None:
                self.address = (host, port)

            def start(self) -> None:
                return None

            def stop(self) -> None:
                return None

        monkeypatch.setenv("DLM_PROBE_TOKEN", "secret")
        monkeypatch.setattr("dlm.hardware.doctor", lambda **kwargs: _fake_doctor_result())
        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)
        monkeypatch.setattr(
            "dlm.train.preference.phase_orchestrator.run_phases",
            lambda *args, **kwargs: [fake_phase],
        )
        monkeypatch.setattr("dlm.train.inject.InjectedProbeQueue", _FakeQueue)
        monkeypatch.setattr("dlm.train.rpc.ProbeRpcServer", _FakeServer)
        # Watch loop is stubbed to exit immediately with code 0.
        monkeypatch.setattr("dlm.watch.loop.run_watch", lambda *args, **kwargs: 0)

        result = runner.invoke(
            app,
            [
                "--home",
                str(tmp_path / "home"),
                "train",
                str(doc),
                "--watch",
                "--listen-rpc",
                "127.0.0.1:7777",
            ],
        )
        assert result.exit_code == 0, result.output
        assert "rpc:" in result.output
        assert "watch:" in result.output

    def test_train_noop_watch_repl_and_bounded_rpc_refusals(
        self,
        tmp_path: Path,
        monkeypatch: Any,
    ) -> None:
        """Three refusal paths: no phases -> friendly no-op; --watch with
        --repl -> usage error (exit 2); --listen-rpc with --max-cycles ->
        usage error (exit 2)."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        monkeypatch.setattr("dlm.hardware.doctor", lambda **kwargs: _fake_doctor_result())
        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)
        # Empty phase list => nothing to train.
        monkeypatch.setattr(
            "dlm.train.preference.phase_orchestrator.run_phases",
            lambda *args, **kwargs: [],
        )

        no_op = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "train", str(doc)],
        )
        assert no_op.exit_code == 0, no_op.output
        assert "nothing to train" in no_op.output

        fake_result = SimpleNamespace(
            adapter_version=1,
            steps=1,
            seed=42,
            determinism=SimpleNamespace(class_="strict"),
            adapter_path=tmp_path / "adapter",
            log_path=tmp_path / "train.jsonl",
            final_train_loss=None,
        )
        fake_phase = SimpleNamespace(phase="sft", result=fake_result)
        monkeypatch.setattr(
            "dlm.train.preference.phase_orchestrator.run_phases",
            lambda *args, **kwargs: [fake_phase],
        )

        watch_repl = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "train", str(doc), "--watch", "--repl"],
        )
        assert watch_repl.exit_code == 2, watch_repl.output
        assert "not yet implemented" in watch_repl.output

        monkeypatch.setenv("DLM_PROBE_TOKEN", "secret")
        bounded_rpc = runner.invoke(
            app,
            [
                "--home",
                str(tmp_path / "home"),
                "train",
                str(doc),
                "--listen-rpc",
                "127.0.0.1:7777",
                "--max-cycles",
                "1",
            ],
        )
        assert bounded_rpc.exit_code == 2, bounded_rpc.output
        assert "--watch for now" in bounded_rpc.output

    def test_multi_gpu_helper_and_strip(self, monkeypatch: Any) -> None:
        """Cover _maybe_dispatch_multi_gpu (bad spec, bad resolve, single-GPU
        passthrough, multi-GPU launch) and _strip_gpus_from_argv."""
        from rich.console import Console

        from dlm.cli.commands import _maybe_dispatch_multi_gpu, _strip_gpus_from_argv
        from dlm.train.distributed import UnsupportedGpuSpecError

        class _GpuSpec:
            # Minimal stand-in: resolve() ignores device_count and returns
            # the preconfigured ids.
            def __init__(self, device_ids: tuple[int, ...]) -> None:
                self._device_ids = device_ids

            def resolve(self, device_count: int) -> tuple[int, ...]:
                return self._device_ids

        console = Console(stderr=True)
        # `(_ for _ in ()).throw(exc)` raises from inside a lambda expression.
        monkeypatch.setattr(
            "dlm.train.distributed.parse_gpus",
            lambda raw: (_ for _ in ()).throw(UnsupportedGpuSpecError("bad gpus")),
        )
        # Unparseable --gpus spec -> usage error (2).
        assert _maybe_dispatch_multi_gpu("bogus", ["dlm", "train"], console) == 2

        class _BadResolveGpuSpec:
            def resolve(self, device_count: int) -> tuple[int, ...]:
                raise UnsupportedGpuSpecError("gpu index 7 is unavailable")

        monkeypatch.setattr(
            "dlm.train.distributed.parse_gpus",
            lambda raw: _BadResolveGpuSpec(),
        )
        # Spec parses but resolves to an unavailable device -> usage error (2).
        assert _maybe_dispatch_multi_gpu("7", ["dlm", "train"], console) == 2

        monkeypatch.setattr("dlm.train.distributed.parse_gpus", lambda raw: _GpuSpec((0,)))
        import torch

        monkeypatch.setattr(torch.cuda, "device_count", lambda: 2)
        # Single resolved device -> no dispatch; caller continues in-process.
        assert _maybe_dispatch_multi_gpu("1", ["dlm", "train"], console) is None

        launched: dict[str, object] = {}
        monkeypatch.setattr("dlm.train.distributed.parse_gpus", lambda raw: _GpuSpec((1, 3)))
        # `update(...) or 17` records the call and returns exit code 17.
        monkeypatch.setattr(
            "dlm.train.distributed.launch_multi_gpu",
            lambda device_ids, cli_args, mixed_precision="bf16": (
                launched.update(
                    {
                        "device_ids": device_ids,
                        "cli_args": cli_args,
                        "mixed_precision": mixed_precision,
                    }
                )
                or 17
            ),
        )
        exit_code = _maybe_dispatch_multi_gpu(
            "1,3",
            ["dlm", "train", "doc.dlm", "--gpus", "1,3"],
            console,
        )
        assert exit_code == 17
        assert launched["device_ids"] == (1, 3)
        # The program name and the --gpus pair are stripped before relaunch.
        assert launched["cli_args"] == ["train", "doc.dlm"]
        # `--gpus=0,1` (single-token form) is stripped too.
        assert _strip_gpus_from_argv(["dlm", "train", "--gpus=0,1", "doc.dlm"]) == [
            "train",
            "doc.dlm",
        ]

    def test_train_error_mappings(self, tmp_path: Path, monkeypatch: Any) -> None:
        """Each known training-domain exception maps to exit code 1 with a
        distinctive prefix in the output; OOM gets its own formatter."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        from dlm.lock.errors import LockValidationError
        from dlm.train.errors import DiskSpaceError, OOMError, ResumeIntegrityError, TrainingError
        from dlm.train.preference.errors import (
            DpoPhaseError,
            NoPreferenceContentError,
            PriorAdapterRequiredError,
        )

        monkeypatch.setattr("dlm.hardware.doctor", lambda **kwargs: _fake_doctor_result())
        monkeypatch.setattr("dlm.train.distributed.detect_world_size", lambda: 1)

        # (exception instance, substring expected in CLI output)
        cases = [
            (
                LockValidationError(path=tmp_path / "dlm.lock", reasons=["torch drift"]),
                "Re-run with",
            ),
            (DiskSpaceError(required_bytes=2_000_000_000, free_bytes=1_000_000_000), "disk:"),
            (ResumeIntegrityError("resume mismatch"), "resume:"),
            (NoPreferenceContentError("no preferences"), "dpo:"),
            (PriorAdapterRequiredError("need prior adapter"), "dpo:"),
            (DpoPhaseError("dpo failed"), "dpo:"),
            (TrainingError("trainer failed"), "training:"),
        ]
        for error, needle in cases:
            # `_error=error` binds the current exception at definition time,
            # avoiding the classic late-binding-closure-in-a-loop bug.
            monkeypatch.setattr(
                "dlm.train.preference.phase_orchestrator.run_phases",
                lambda *args, _error=error, **kwargs: (_ for _ in ()).throw(_error),
            )
            result = runner.invoke(
                app,
                ["--home", str(tmp_path / "home"), "train", str(doc)],
            )
            assert result.exit_code == 1, result.output
            assert needle in result.output

        monkeypatch.setattr(
            "dlm.train.preference.phase_orchestrator.run_phases",
            lambda *args, **kwargs: (_ for _ in ()).throw(
                OOMError(
                    step=5,
                    peak_bytes=2_000,
                    free_at_start_bytes=4_000,
                    current_grad_accum=1,
                    recommended_grad_accum=4,
                )
            ),
        )
        # The OOM path delegates message construction to format_oom_message.
        monkeypatch.setattr("dlm.train.format_oom_message", lambda **kwargs: "OOM advice")
        oom = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "train", str(doc)],
        )
        assert oom.exit_code == 1, oom.output
        assert "OOM advice" in oom.output
389
390
class TestPromptAndReplCoverage:
    """Exercise the `prompt` and `repl` command bodies: stdin handling,
    adapter-name validation, backend selection, and error-to-exit-code
    mappings. Inference backends are stubbed so nothing is loaded."""

    def test_prompt_text_backend_reads_stdin_and_generates(
        self, tmp_path: Path, monkeypatch: Any
    ) -> None:
        """`prompt` with no query argument reads the query from stdin and
        prints the backend's generated reply."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        class _FakeBackend:
            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                return None

            def generate(self, query: str, **kwargs: object) -> str:
                # Echo the query so the test can verify stdin plumbing.
                return f"reply:{query}"

        monkeypatch.setattr("dlm.hardware.doctor", lambda: SimpleNamespace(capabilities=object()))
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend", lambda *args, **kwargs: "pytorch"
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend", lambda *args, **kwargs: _FakeBackend()
        )

        result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc)],
            input="hello from stdin\n",
        )
        assert result.exit_code == 0, result.output
        assert "reply:hello from stdin" in result.output

    def test_repl_success_and_adapter_validation(self, tmp_path: Path, monkeypatch: Any) -> None:
        """--adapter on a single-adapter doc is a usage error (2); otherwise
        the repl exit code is whatever run_repl returns."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        # Single-adapter doc: naming an adapter must be rejected up front.
        adapter_bad = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--adapter", "knowledge"],
        )
        assert adapter_bad.exit_code == 2, adapter_bad.output
        assert "only valid on multi-adapter" in adapter_bad.output

        class _FakeBackend:
            def __init__(self) -> None:
                # The repl session reads the loaded tokenizer off the backend.
                self._loaded = SimpleNamespace(tokenizer="tok")

            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                return None

        monkeypatch.setattr("dlm.hardware.doctor", lambda: SimpleNamespace(capabilities=object()))
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend", lambda *args, **kwargs: "pytorch"
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend", lambda *args, **kwargs: _FakeBackend()
        )
        # run_repl's return value (5) must propagate as the CLI exit code.
        monkeypatch.setattr("dlm.repl.app.run_repl", lambda session, console: 5)

        repl_ok = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--backend", "pytorch"],
        )
        assert repl_ok.exit_code == 5, repl_ok.output

    def test_repl_error_mappings(self, tmp_path: Path, monkeypatch: Any) -> None:
        """Map repl failure modes: unknown adapter (2), gated base model (1),
        unsupported backend (2), missing adapter at load time (1)."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        from dlm.base_models.errors import GatedModelError
        from dlm.inference import AdapterNotFoundError
        from dlm.inference.backends.select import UnsupportedBackendError

        # Splice an adapters table into the doc's front matter to make it
        # multi-adapter. fm_end locates the closing "---" of the YAML block.
        # NOTE(review): the inner YAML indentation below looks single-space
        # ("\n adapters:\n knowledge:") — possibly whitespace-mangled in this
        # paste; confirm nesting against the original file.
        original = doc.read_text(encoding="utf-8")
        fm_end = original.find("\n---\n", original.find("---") + 3)
        multi = tmp_path / "multi.dlm"
        multi.write_text(
            original[:fm_end] + "\ntraining:\n adapters:\n knowledge: {}\n" + original[fm_end:],
            encoding="utf-8",
        )
        unknown = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(multi), "--adapter", "ghost"],
        )
        assert unknown.exit_code == 2, unknown.output
        assert "not declared" in unknown.output

        # Gated base model: surfaced as a runtime error with guidance.
        monkeypatch.setattr(
            "dlm.base_models.resolve",
            lambda *args, **kwargs: (_ for _ in ()).throw(
                GatedModelError("hf/model", "https://license")
            ),
        )
        gated = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--backend", "pytorch"],
        )
        assert gated.exit_code == 1, gated.output
        assert "run `dlm train --i-accept-license` first" in gated.output

        monkeypatch.setattr("dlm.base_models.resolve", lambda *args, **kwargs: SimpleNamespace())
        monkeypatch.setattr("dlm.hardware.doctor", lambda: SimpleNamespace(capabilities=object()))
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend",
            lambda *args, **kwargs: (_ for _ in ()).throw(
                UnsupportedBackendError("backend not available")
            ),
        )
        unsupported = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--backend", "pytorch"],
        )
        assert unsupported.exit_code == 2, unsupported.output
        assert "backend not available" in unsupported.output

        class _MissingAdapterBackend:
            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                raise AdapterNotFoundError("missing adapter")

        monkeypatch.setattr(
            "dlm.inference.backends.select_backend", lambda *args, **kwargs: "pytorch"
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend",
            lambda *args, **kwargs: _MissingAdapterBackend(),
        )
        missing = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--backend", "pytorch"],
        )
        assert missing.exit_code == 1, missing.output
        assert "missing adapter" in missing.output

    def test_prompt_empty_query_and_repl_invalid_backend(
        self,
        tmp_path: Path,
        monkeypatch: Any,
    ) -> None:
        """Empty stdin query and an unrecognized --backend name are both
        usage errors (exit 2)."""
        doc = _init_doc(tmp_path)
        runner = CliRunner()

        class _FakeBackend:
            def load(self, spec: object, store: object, adapter_name: str | None = None) -> None:
                return None

        monkeypatch.setattr("dlm.hardware.doctor", lambda: SimpleNamespace(capabilities=object()))
        monkeypatch.setattr(
            "dlm.inference.backends.select_backend", lambda *args, **kwargs: "pytorch"
        )
        monkeypatch.setattr(
            "dlm.inference.backends.build_backend", lambda *args, **kwargs: _FakeBackend()
        )

        prompt_result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "prompt", str(doc)],
            input="",
        )
        assert prompt_result.exit_code == 2, prompt_result.output
        assert "empty query" in prompt_result.output

        repl_result = runner.invoke(
            app,
            ["--home", str(tmp_path / "home"), "repl", str(doc), "--backend", "bogus"],
        )
        assert repl_result.exit_code == 2, repl_result.output
        assert "--backend must be" in repl_result.output