Python · 7406 bytes Raw Blame History
1 """Tests for the backend registry in ``dlm_sway.backends``.
2
The registry is the single place that maps a ModelSpec to a concrete
backend. These tests cover the validation and error paths, plus dispatch
to custom entry-point backends (with the dummy backend as a stand-in).
Actually materializing an HF backend requires model weights and is
covered by the integration suite.
7 """
8
9 from __future__ import annotations
10
11 from pathlib import Path
12
13 import pytest
14
15 from dlm_sway.backends import build
16 from dlm_sway.core.errors import BackendNotAvailableError, SpecValidationError
17 from dlm_sway.core.model import ModelSpec
18
19
class TestRegistry:
    """Validation and dispatch behaviour of :func:`dlm_sway.backends.build`."""

    @staticmethod
    def _install_module(
        monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object
    ) -> str:
        """Register a throwaway module named ``name`` for the current test only.

        The module is placed in ``sys.modules`` via ``monkeypatch.setitem``,
        so pytest removes it again on teardown. Without this, the fake module
        would leak into every later test in the session.

        Args:
            monkeypatch: The per-test pytest ``monkeypatch`` fixture.
            name: Importable module name to register.
            **attrs: Attributes (e.g. backend classes) to set on the module.

        Returns:
            ``name``, so callers can build an entry-point string from it.
        """
        import sys
        import types

        module = types.ModuleType(name)
        for attr_name, value in attrs.items():
            setattr(module, attr_name, value)
        monkeypatch.setitem(sys.modules, name, module)
        return name

    def test_dummy_rejected_via_build(self) -> None:
        """``kind='dummy'`` must not be constructible through ``build``."""
        with pytest.raises(SpecValidationError, match="kind='dummy'"):
            build(ModelSpec(base="x", kind="dummy"))

    def test_hf_requires_adapter(self) -> None:
        with pytest.raises(SpecValidationError, match="adapter"):
            build(ModelSpec(base="x", kind="hf"))

    def test_mlx_requires_adapter(self) -> None:
        with pytest.raises(SpecValidationError, match="adapter"):
            build(ModelSpec(base="x", kind="mlx"))

    def test_mlx_dispatch_raises_when_mlx_missing(self) -> None:
        """Without ``mlx`` installed, building the MLX backend raises.

        The error is expected to be ``BackendNotAvailableError`` with
        ``backend == "mlx"``. The assertion is skipped when ``mlx`` is
        importable, because the error path cannot be exercised then.
        """
        import importlib.util

        if importlib.util.find_spec("mlx") is not None:
            pytest.skip("mlx is installed; error path not exercised")
        with pytest.raises(BackendNotAvailableError) as exc_info:
            build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a")))
        assert exc_info.value.backend == "mlx"

    def test_custom_requires_entry_point(self) -> None:
        with pytest.raises(SpecValidationError, match="entry_point"):
            build(ModelSpec(base="x", kind="custom", adapter=Path("/tmp/a")))

    def test_custom_validates_entry_point_shape(self) -> None:
        with pytest.raises(SpecValidationError, match="pkg.module:ClassName"):
            build(
                ModelSpec(
                    base="x",
                    kind="custom",
                    entry_point="not_a_valid_entry_point",
                    adapter=Path("/tmp/a"),
                )
            )

    def test_custom_rejects_unimportable_module(self) -> None:
        with pytest.raises(SpecValidationError, match="cannot import"):
            build(
                ModelSpec(
                    base="x",
                    kind="custom",
                    entry_point="nonexistent_pkg_xyz:Backend",
                    adapter=Path("/tmp/a"),
                )
            )

    def test_custom_rejects_missing_class(self) -> None:
        with pytest.raises(SpecValidationError, match="has no attribute"):
            build(
                ModelSpec(
                    base="x",
                    kind="custom",
                    entry_point="dlm_sway:NoSuchClass",
                    adapter=Path("/tmp/a"),
                )
            )

    def test_custom_rejects_non_differential_class(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A class with the canonical constructor but no protocol is rejected."""

        class _Bad:
            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
                del base_spec, adapter_path

        self._install_module(monkeypatch, "_sway_bad_mod", Bad=_Bad)

        with pytest.raises(SpecValidationError, match="DifferentialBackend"):
            build(
                ModelSpec(
                    base="x",
                    kind="custom",
                    entry_point="_sway_bad_mod:Bad",
                    adapter=Path("/tmp/a"),
                )
            )

    def test_custom_dispatches_to_valid_backend(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A valid custom entry point yields a ``DifferentialBackend``."""
        from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
        from dlm_sway.core.scoring import DifferentialBackend

        # The dummy backend's __init__ takes different arguments than the
        # canonical (base_spec, adapter_path), so adapt it with a thin subclass.
        class _AdapterBackend(DummyDifferentialBackend):
            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
                del base_spec, adapter_path
                super().__init__(base=DummyResponses(), ft=DummyResponses())

        self._install_module(
            monkeypatch, "_sway_custom_test_mod", AdapterBackend=_AdapterBackend
        )

        backend = build(
            ModelSpec(
                base="x",
                kind="custom",
                entry_point="_sway_custom_test_mod:AdapterBackend",
                adapter=Path("/tmp/a"),
            )
        )
        assert isinstance(backend, DifferentialBackend)
134
135
class TestCustomProtocolStamp:
    """B20: ``__sway_protocols__`` records the optional protocols satisfied."""

    @staticmethod
    def _install_module(
        monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object
    ) -> str:
        """Register a throwaway module named ``name`` for the current test only.

        ``monkeypatch.setitem`` restores ``sys.modules`` on teardown, so the
        fake module does not leak into later tests.

        Args:
            monkeypatch: The per-test pytest ``monkeypatch`` fixture.
            name: Importable module name to register.
            **attrs: Attributes (e.g. backend classes) to set on the module.

        Returns:
            ``name``, so callers can build an entry-point string from it.
        """
        import sys
        import types

        module = types.ModuleType(name)
        for attr_name, value in attrs.items():
            setattr(module, attr_name, value)
        monkeypatch.setitem(sys.modules, name, module)
        return name

    def test_stamps_null_calibrated_when_satisfied(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A dummy-derived backend is stamped with every protocol the test expects."""
        from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses

        # Adapt the dummy backend to the canonical (base_spec, adapter_path) ctor.
        class _AdapterBackend(DummyDifferentialBackend):
            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
                del base_spec, adapter_path
                super().__init__(base=DummyResponses(), ft=DummyResponses())

        self._install_module(
            monkeypatch, "_sway_protostamp_mod", AdapterBackend=_AdapterBackend
        )

        backend = build(
            ModelSpec(
                base="x",
                kind="custom",
                entry_point="_sway_protostamp_mod:AdapterBackend",
                adapter=Path("/tmp/a"),
            )
        )
        protocols = getattr(backend, "__sway_protocols__", ())
        assert "DifferentialBackend" in protocols
        # DummyDifferentialBackend satisfies both optional protocols.
        assert "NullCalibratedBackend" in protocols
        assert "ScalableDifferentialBackend" in protocols

    def test_minimal_diff_backend_only_records_diff(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A backend that only satisfies DifferentialBackend gets stamped accordingly."""
        from contextlib import contextmanager

        # Implements only as_base / as_finetuned, so no optional protocol applies.
        class _MinimalBackend:
            def __init__(self, base_spec, adapter_path):  # type: ignore[no-untyped-def]
                del base_spec, adapter_path

            @contextmanager
            def as_base(self):  # noqa: ANN201
                yield self

            @contextmanager
            def as_finetuned(self):  # noqa: ANN201
                yield self

        self._install_module(monkeypatch, "_sway_minimal_mod", Minimal=_MinimalBackend)

        backend = build(
            ModelSpec(
                base="x",
                kind="custom",
                entry_point="_sway_minimal_mod:Minimal",
                adapter=Path("/tmp/a"),
            )
        )
        protocols = getattr(backend, "__sway_protocols__", ())
        assert protocols == ("DifferentialBackend",)