Python · 5649 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.core.model`."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6
7 import pytest
8 from pydantic import ValidationError
9
10 from dlm_sway.core.model import LoadedModel, Model, ModelSpec
11
12
class TestModelSpec:
    """Tests for ``ModelSpec`` defaults, field validation and adapter-path normalisation."""

    def test_defaults(self) -> None:
        """Only ``base`` is supplied; every other field takes its default."""
        spec = ModelSpec(base="HuggingFaceTB/SmolLM2-135M-Instruct")
        assert spec.kind == "hf"
        assert spec.adapter is None
        assert spec.dtype == "auto"
        assert spec.device == "auto"
        assert spec.trust_remote_code is False
        assert spec.entry_point is None

    def test_frozen(self) -> None:
        """Assigning to a field of an existing spec raises ``ValidationError``."""
        spec = ModelSpec(base="x")
        with pytest.raises(ValidationError):
            spec.base = "y"  # type: ignore[misc]

    def test_extra_fields_forbidden(self) -> None:
        """Unknown keyword arguments are rejected and named in the error."""
        with pytest.raises(ValidationError) as exc_info:
            ModelSpec(base="x", bogus="y")  # type: ignore[call-arg]
        # Must sit outside the ``with`` block: anything after the raising
        # call inside the block would never execute.
        assert "bogus" in str(exc_info.value).lower()

    def test_kind_enum(self) -> None:
        """Every supported ``kind`` parses; an unknown one is rejected."""
        ModelSpec(base="x", kind="hf")
        ModelSpec(base="x", kind="mlx")
        ModelSpec(base="x", kind="dummy")
        ModelSpec(base="x", kind="custom", entry_point="pkg.mod:Backend")
        # ``api`` is exercised by ``test_endpoint_can_be_set`` below; cover it
        # here too so this test lists every kind the suite relies on.
        ModelSpec(base="x", kind="api", endpoint="http://localhost:11434")
        with pytest.raises(ValidationError):
            ModelSpec(base="x", kind="ollama")  # type: ignore[arg-type]

    def test_adapter_coerced_to_path(self) -> None:
        """A string ``adapter`` is coerced to a :class:`~pathlib.Path`."""
        spec = ModelSpec(base="x", adapter="/tmp/adapter")  # type: ignore[arg-type]
        assert isinstance(spec.adapter, Path)

    def test_adapter_tilde_expanded(self) -> None:
        """B22: ``~`` is expanded at spec-load time so backends see absolute paths."""
        spec = ModelSpec(base="x", adapter="~/some/adapter")  # type: ignore[arg-type]
        assert spec.adapter is not None
        # Compare path components rather than the raw string: an expanded
        # home directory may itself legitimately contain ``~`` (for example
        # Windows 8.3 short names such as ``RUNNER~1``), which would make a
        # plain substring check fail even though expansion worked.
        assert "~" not in spec.adapter.parts
        assert spec.adapter.is_absolute()

    def test_adapter_relative_resolved(self) -> None:
        """B22: relative paths resolve against the current cwd."""
        spec = ModelSpec(base="x", adapter="adapter/v1")  # type: ignore[arg-type]
        assert spec.adapter is not None
        assert spec.adapter.is_absolute()

    def test_adapter_none_passthrough(self) -> None:
        """B22 normalizer doesn't blow up on the default ``None``."""
        spec = ModelSpec(base="x")
        assert spec.adapter is None

    def test_dtype_enum_accepts_known_values(self) -> None:
        """DC5 — every ``dtype`` branch parses cleanly. Backends rely
        on this validation, not their own — passing ``fp12`` here would
        otherwise crash deep inside the HF loader."""
        for dtype in ("auto", "fp16", "bf16", "fp32"):
            spec = ModelSpec(base="x", dtype=dtype)  # type: ignore[arg-type]
            assert spec.dtype == dtype

    def test_dtype_enum_rejects_unknown(self) -> None:
        """An unrecognised ``dtype`` string fails validation."""
        with pytest.raises(ValidationError):
            ModelSpec(base="x", dtype="fp12")  # type: ignore[arg-type]

    def test_trust_remote_code_default_false(self) -> None:
        """DC5 — default is False (safe posture); user must opt in."""
        assert ModelSpec(base="x").trust_remote_code is False

    def test_trust_remote_code_accepts_true(self) -> None:
        """An explicit ``trust_remote_code=True`` is stored as given."""
        assert ModelSpec(base="x", trust_remote_code=True).trust_remote_code is True

    def test_endpoint_default_none(self) -> None:
        """DC5 — ``endpoint`` is an ``api``-backend-only field. Default
        ``None`` when not specified."""
        assert ModelSpec(base="x").endpoint is None

    def test_endpoint_can_be_set(self) -> None:
        """DC5 — the ``api`` kind expects an endpoint URL."""
        spec = ModelSpec(
            base="gpt-3.5-turbo-instruct",
            kind="api",
            endpoint="http://localhost:11434",
        )
        assert spec.endpoint == "http://localhost:11434"
        assert spec.kind == "api"

    def test_entry_point_optional_without_custom(self) -> None:
        """DC5 — non-custom kinds don't need an entry_point."""
        spec = ModelSpec(base="x", kind="hf")
        assert spec.entry_point is None

    def test_custom_kind_accepts_entry_point(self) -> None:
        """DC5 — ``custom`` kind stores the entry_point verbatim (the
        runner imports it)."""
        spec = ModelSpec(base="x", kind="custom", entry_point="mypkg.backend:MyBackend")
        assert spec.entry_point == "mypkg.backend:MyBackend"

    def test_device_accepts_explicit_cpu(self) -> None:
        """DC5 — ``device`` is a free-form str; ``"auto"`` default
        resolves at backend-load time."""
        assert ModelSpec(base="x", device="cpu").device == "cpu"
        assert ModelSpec(base="x", device="cuda:0").device == "cuda:0"
        assert ModelSpec(base="x").device == "auto"
114
115
class TestLoadedModel:
    """Tests for the ``LoadedModel`` container."""

    def test_frozen_dataclass(self) -> None:
        """Fields are readable after construction and cannot be reassigned."""
        loaded = LoadedModel(
            id="base",
            spec=ModelSpec(base="x"),
            model=object(),
            tokenizer=object(),
            meta={"device": "cpu"},
        )
        assert loaded.id == "base"
        assert loaded.meta["device"] == "cpu"
        # The test name promises frozenness, so actually exercise it.
        # NOTE(review): assumes ``LoadedModel`` is declared frozen, as the
        # test name states — confirm against ``dlm_sway.core.model``.
        # ``dataclasses.FrozenInstanceError`` subclasses ``AttributeError``,
        # so catching the base class needs no extra import.
        with pytest.raises(AttributeError):
            loaded.id = "other"  # type: ignore[misc]
127
128
class TestModelProtocol:
    """Tests for structural ``isinstance`` checks against ``Model``."""

    def test_runtime_checkable(self) -> None:
        """``isinstance`` works against ``Model`` and discriminates by shape."""

        class FakeModel:
            """Minimal stand-in exposing ``id``, ``generate`` and ``close``."""

            id = "x"

            def generate(
                self,
                prompt: str,
                *,
                max_new_tokens: int,
                temperature: float = 0.0,
                top_p: float = 1.0,
                seed: int = 0,
            ) -> str:
                return f"{prompt}|{max_new_tokens}"

            def close(self) -> None:
                return None

        assert isinstance(FakeModel(), Model)
        # Negative case: a bare object has none of the members FakeModel
        # provides, so it must not satisfy the protocol.  Without this the
        # test would also pass for a protocol that declares no members.
        assert not isinstance(object(), Model)