Python · 3244 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.core.model`."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6
7 import pytest
8 from pydantic import ValidationError
9
10 from dlm_sway.core.model import LoadedModel, Model, ModelSpec
11
12
class TestModelSpec:
    """Checks on :class:`ModelSpec` defaults, validation, and adapter paths."""

    def test_defaults(self) -> None:
        """A spec built from ``base`` alone carries the expected defaults."""
        model_spec = ModelSpec(base="HuggingFaceTB/SmolLM2-135M-Instruct")
        # Optional fields start out unset / disabled.
        assert model_spec.adapter is None
        assert model_spec.entry_point is None
        assert model_spec.trust_remote_code is False
        # String-valued fields fall back to their documented defaults.
        assert model_spec.kind == "hf"
        assert model_spec.dtype == "auto"
        assert model_spec.device == "auto"

    def test_frozen(self) -> None:
        """Reassigning a field on an existing spec raises ``ValidationError``."""
        model_spec = ModelSpec(base="x")
        with pytest.raises(ValidationError):
            model_spec.base = "y"  # type: ignore[misc]

    def test_extra_fields_forbidden(self) -> None:
        """An unknown keyword is rejected and named in the error message."""
        with pytest.raises(ValidationError) as caught:
            ModelSpec(base="x", bogus="y")  # type: ignore[call-arg]
        message = str(caught.value).lower()
        assert "bogus" in message

    def test_kind_enum(self) -> None:
        """The four known kinds construct cleanly; an unknown kind is rejected."""
        # Accepted values: construction alone is the assertion.
        ModelSpec(base="x", kind="hf")
        ModelSpec(base="x", kind="mlx")
        ModelSpec(base="x", kind="dummy")
        ModelSpec(base="x", kind="custom", entry_point="pkg.mod:Backend")
        # Anything outside that set must fail validation.
        with pytest.raises(ValidationError):
            ModelSpec(base="x", kind="ollama")  # type: ignore[arg-type]

    def test_adapter_coerced_to_path(self) -> None:
        """A string ``adapter`` ends up stored as a :class:`~pathlib.Path`."""
        model_spec = ModelSpec(base="x", adapter="/tmp/adapter")  # type: ignore[arg-type]
        assert isinstance(model_spec.adapter, Path)

    def test_adapter_tilde_expanded(self) -> None:
        """B22: a leading ``~`` is gone after construction and the path is absolute."""
        model_spec = ModelSpec(base="x", adapter="~/some/adapter")  # type: ignore[arg-type]
        adapter_path = model_spec.adapter
        assert adapter_path is not None
        assert "~" not in str(adapter_path)
        assert adapter_path.is_absolute()

    def test_adapter_relative_resolved(self) -> None:
        """B22: a relative adapter path comes back absolute."""
        model_spec = ModelSpec(base="x", adapter="adapter/v1")  # type: ignore[arg-type]
        adapter_path = model_spec.adapter
        assert adapter_path is not None
        assert adapter_path.is_absolute()

    def test_adapter_none_passthrough(self) -> None:
        """B22: omitting ``adapter`` leaves it as ``None`` without raising."""
        model_spec = ModelSpec(base="x")
        assert model_spec.adapter is None
62
63
class TestLoadedModel:
    """Checks on constructing :class:`LoadedModel` and on its immutability."""

    def test_frozen_dataclass(self) -> None:
        """Fields are readable after construction and cannot be reassigned."""
        loaded = LoadedModel(
            id="base",
            spec=ModelSpec(base="x"),
            model=object(),
            tokenizer=object(),
            meta={"device": "cpu"},
        )
        assert loaded.id == "base"
        assert loaded.meta["device"] == "cpu"

        # The test name promises frozenness, so actually exercise it.
        # NOTE(review): assumes LoadedModel is declared with ``frozen=True``
        # (as the test name states) -- confirm against the definition.
        # ``dataclasses.FrozenInstanceError`` subclasses ``AttributeError``,
        # so catching the base class needs no extra import.
        with pytest.raises(AttributeError):
            loaded.id = "other"  # type: ignore[misc]
75
76
class TestModelProtocol:
    """Checks that :class:`Model` can be used with ``isinstance``."""

    def test_runtime_checkable(self) -> None:
        """A structurally matching class passes ``isinstance``; a bare object does not."""

        class FakeModel:
            """Minimal stand-in exposing ``id``, ``generate`` and ``close``."""

            id = "x"

            def generate(
                self,
                prompt: str,
                *,
                max_new_tokens: int,
                temperature: float = 0.0,
                top_p: float = 1.0,
                seed: int = 0,
            ) -> str:
                return f"{prompt}|{max_new_tokens}"

            def close(self) -> None:
                return None

        assert isinstance(FakeModel(), Model)
        # Negative control: without it the test would also pass for a
        # protocol that declares no members and therefore matches anything.
        assert not isinstance(object(), Model)