Python · 9253 bytes Raw Blame History
1 """Direct helper coverage for sway-backed preference judge wiring."""
2
3 from __future__ import annotations
4
5 import builtins
6 import importlib
7 import sys
8 from pathlib import Path
9 from types import SimpleNamespace
10 from unittest.mock import patch
11
12 import pytest
13
14 from dlm.preference import JudgeUnavailableError
15 from dlm.preference.judge import (
16 _build_sway_backend,
17 _import_sway_bridge,
18 _resolve_sway_trust_remote_code,
19 )
20
21
class FakeSwayError(Exception):
    """Stand-in for sway's ``SwayError`` exception, returned by the fake bridges."""
24
25
class FakeModelSpec:
    """Stand-in for sway's ``ModelSpec`` that only records its keyword arguments."""

    def __init__(self, **kwargs: object) -> None:
        # Tests compare this mapping against the fields they expect the spec to carry.
        self.kwargs = {**kwargs}
29
30
class FakeSysPath(list[str]):
    """List stand-in for ``sys.path`` that remembers every value passed to ``insert``."""

    def __init__(self) -> None:
        super().__init__()
        # Values in the order ``insert`` received them, regardless of target index.
        self.inserted: list[str] = []

    def insert(self, index: int, value: str) -> None:  # type: ignore[override]
        self.inserted += [value]
        list.insert(self, index, value)
39
40
def test_build_sway_backend_requires_importable_bridge() -> None:
    """An ImportError from the sway bridge is reported as JudgeUnavailableError."""
    bridge_missing = patch(
        "dlm.preference.judge._import_sway_bridge",
        side_effect=ImportError("missing"),
    )
    expect_unavailable = pytest.raises(JudgeUnavailableError, match="requires the sway bridge")
    with bridge_missing, expect_unavailable:
        _build_sway_backend(Path("/tmp/example.dlm"))
47
48
def test_build_sway_backend_wraps_sway_resolution_errors() -> None:
    """A SwayError raised while resolving the .dlm is re-raised as JudgeUnavailableError."""

    def failing_resolver(_path: Path) -> object:
        raise FakeSwayError("no store")

    bridge = (failing_resolver, object(), FakeModelSpec, FakeSwayError)
    with patch("dlm.preference.judge._import_sway_bridge", return_value=bridge):
        with pytest.raises(JudgeUnavailableError, match="could not resolve"):
            _build_sway_backend(Path("/tmp/example.dlm"))
61
62
def test_build_sway_backend_wraps_generic_resolution_errors() -> None:
    """Non-sway exceptions from the resolver are also surfaced as JudgeUnavailableError."""

    def exploding_resolver(_path: Path) -> object:
        raise RuntimeError("boom")

    bridge = (exploding_resolver, object(), FakeModelSpec, FakeSwayError)
    with patch("dlm.preference.judge._import_sway_bridge", return_value=bridge):
        with pytest.raises(JudgeUnavailableError, match="could not resolve"):
            _build_sway_backend(Path("/tmp/example.dlm"))
75
76
def test_build_sway_backend_requires_trained_adapter() -> None:
    """A resolved document without an adapter path cannot back a sway judge."""
    untrained = SimpleNamespace(adapter_path=None, base_model="base/model")

    def resolve_dlm(_path: Path) -> object:
        return untrained

    bridge = (resolve_dlm, object(), FakeModelSpec, FakeSwayError)
    with patch("dlm.preference.judge._import_sway_bridge", return_value=bridge):
        with pytest.raises(JudgeUnavailableError, match="requires a trained adapter"):
            _build_sway_backend(Path("/tmp/example.dlm"))
91
92
def test_build_sway_backend_wraps_backend_load_errors() -> None:
    """A failure inside the sway backend builder surfaces as JudgeUnavailableError.

    The fake builder records the adapter path it receives instead of asserting
    on it inline. ``_build_sway_backend`` turns builder exceptions into
    ``JudgeUnavailableError``. An ``AssertionError`` raised inside the fake could
    therefore be wrapped the same way (if the wrapper catches broadly), hiding
    the mismatch. The check runs after the ``pytest.raises`` block instead.
    """
    handle = SimpleNamespace(adapter_path=Path("/tmp/adapter"), base_model="base/model")
    seen_adapter_paths: list[Path] = []

    def resolve_dlm(_path: Path) -> object:
        return handle

    def build_backend(_spec: FakeModelSpec, *, adapter_path: Path) -> object:
        seen_adapter_paths.append(adapter_path)
        raise RuntimeError("backend blew up")

    with (
        patch(
            "dlm.preference.judge._import_sway_bridge",
            return_value=(resolve_dlm, build_backend, FakeModelSpec, FakeSwayError),
        ),
        patch("dlm.preference.judge._resolve_sway_trust_remote_code", return_value=False),
        pytest.raises(JudgeUnavailableError, match="could not load backend"),
    ):
        _build_sway_backend(Path("/tmp/example.dlm"))

    # Checked outside the wrapped call so a mismatch fails the test loudly.
    assert seen_adapter_paths == [handle.adapter_path]
112
113
def test_build_sway_backend_builds_model_spec_with_trust_remote_code() -> None:
    """The ModelSpec handed to sway carries the resolved base, adapter, and trust flag."""
    adapter = Path("/tmp/adapter")
    handle = SimpleNamespace(adapter_path=adapter, base_model="base/model")
    captured: dict[str, object] = {}

    def resolve_dlm(_path: Path) -> object:
        return handle

    def build_backend(spec: FakeModelSpec, *, adapter_path: Path) -> object:
        captured.update(spec=spec, adapter_path=adapter_path)
        return "backend"

    bridge = (resolve_dlm, build_backend, FakeModelSpec, FakeSwayError)
    with (
        patch("dlm.preference.judge._import_sway_bridge", return_value=bridge),
        patch("dlm.preference.judge._resolve_sway_trust_remote_code", return_value=True),
    ):
        backend = _build_sway_backend(Path("/tmp/example.dlm"))

    assert backend == "backend"
    built_spec = captured["spec"]
    assert isinstance(built_spec, FakeModelSpec)
    expected_fields = {
        "kind": "hf",
        "base": "base/model",
        "adapter": adapter,
        "trust_remote_code": True,
    }
    assert built_spec.kwargs == expected_fields
    assert captured["adapter_path"] == adapter
145
146
def test_import_sway_bridge_loads_modules_directly(monkeypatch: pytest.MonkeyPatch) -> None:
    """When every sway module imports cleanly, the bridge returns their attributes."""
    fake_modules = {
        "dlm_sway.integrations.dlm.resolver": SimpleNamespace(resolve_dlm="resolve-dlm"),
        "dlm_sway.backends": SimpleNamespace(build="build-backend"),
        "dlm_sway.core.model": SimpleNamespace(ModelSpec=FakeModelSpec),
        "dlm_sway.core.errors": SimpleNamespace(SwayError=FakeSwayError),
    }

    def import_from_table(name: str) -> object:
        return fake_modules[name]

    monkeypatch.setattr(importlib, "import_module", import_from_table)
    resolver, builder, spec_cls, error_cls = _import_sway_bridge()

    assert (resolver, builder) == ("resolve-dlm", "build-backend")
    assert spec_cls is FakeModelSpec
    assert error_cls is FakeSwayError
165
166
def test_import_sway_bridge_falls_back_to_local_src_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failed first import makes the bridge prepend a local ``sway/src`` path and retry.

    Only the first ``import_module`` call fails, and every ``Path.exists``
    check is forced to succeed. ``sys.path`` is swapped for a recording list so
    the inserted fallback directory can be inspected.
    """
    modules = {
        "dlm_sway.backends": SimpleNamespace(build="build-backend"),
        "dlm_sway.core.errors": SimpleNamespace(SwayError=FakeSwayError),
        "dlm_sway.core.model": SimpleNamespace(ModelSpec=FakeModelSpec),
        "dlm_sway.integrations.dlm.resolver": SimpleNamespace(resolve_dlm="resolve-dlm"),
    }
    calls = {"count": 0}

    def fake_import_module(name: str) -> object:
        calls["count"] += 1
        if calls["count"] == 1:
            raise ImportError("first import fails")
        return modules[name]

    fake_sys_path = FakeSysPath()

    monkeypatch.setattr(importlib, "import_module", fake_import_module)
    monkeypatch.setattr(Path, "exists", lambda self: True)
    monkeypatch.setattr(sys, "path", fake_sys_path)
    resolve_dlm, build_backend, model_spec, sway_error = _import_sway_bridge()

    assert resolve_dlm == "resolve-dlm"
    assert build_backend == "build-backend"
    assert model_spec is FakeModelSpec
    assert sway_error is FakeSwayError
    assert fake_sys_path.inserted
    # Compare path components rather than a "/"-joined suffix so the check
    # also holds where the separator is "\\" (Windows) or a trailing slash is present.
    assert Path(fake_sys_path.inserted[0]).parts[-2:] == ("sway", "src")
197
198
def test_resolve_sway_trust_remote_code_returns_false_when_imports_are_missing() -> None:
    """If dlm's parser or base-model modules cannot be imported, the flag is False."""
    blocked = frozenset({"dlm.base_models", "dlm.doc.parser"})
    original_import = builtins.__import__

    def guarded_import(name: str, *args: object, **kwargs: object) -> object:
        if name not in blocked:
            return original_import(name, *args, **kwargs)
        raise ImportError("missing")

    with patch("builtins.__import__", side_effect=guarded_import):
        outcome = _resolve_sway_trust_remote_code(Path("/tmp/example.dlm"))
    assert outcome is False
209
210
def test_resolve_sway_trust_remote_code_handles_parse_and_resolve_failures() -> None:
    """Both a document-parse failure and a base-model resolve failure yield False.

    The two scenarios run in sequence. First ``parse_file`` raises. Then
    ``parse_file`` succeeds but ``resolve`` raises for the parsed base model.
    """

    # Named raising fakes replace the ``(_ for _ in ()).throw(...)`` lambda trick.
    def failing_parse_file(_path: Path) -> object:
        raise RuntimeError("bad")

    def failing_resolve(*_args: object, **_kwargs: object) -> object:
        raise RuntimeError("no base")

    fake_doc_parser = SimpleNamespace(parse_file=failing_parse_file)
    fake_base_models = SimpleNamespace(resolve=lambda *_args, **_kwargs: object())

    with patch.dict(
        "sys.modules",
        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
    ):
        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False

    parsed = SimpleNamespace(frontmatter=SimpleNamespace(base_model="custom-base"))
    fake_doc_parser = SimpleNamespace(parse_file=lambda _path: parsed)
    fake_base_models = SimpleNamespace(resolve=failing_resolve)

    with patch.dict(
        "sys.modules",
        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
    ):
        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False
234
235
@pytest.mark.parametrize("base_model", ["", "hf:org/model"])
def test_resolve_sway_trust_remote_code_short_circuits_for_non_registry_models(
    base_model: str,
) -> None:
    """Empty and ``hf:``-prefixed base models return False without consulting the registry.

    The fake ``resolve`` records every call and would report
    ``trust_remote_code=True``. Without the short-circuit, either the result or
    the recorded calls would expose it. The previous ``object()`` return value
    could not distinguish the two paths.
    """
    parsed = SimpleNamespace(frontmatter=SimpleNamespace(base_model=base_model))
    fake_doc_parser = SimpleNamespace(parse_file=lambda _path: parsed)
    resolve_calls: list[tuple[tuple[object, ...], dict[str, object]]] = []

    def recording_resolve(*args: object, **kwargs: object) -> object:
        resolve_calls.append((args, kwargs))
        return SimpleNamespace(trust_remote_code=True)

    fake_base_models = SimpleNamespace(resolve=recording_resolve)

    with patch.dict(
        "sys.modules",
        {"dlm.doc.parser": fake_doc_parser, "dlm.base_models": fake_base_models},
    ):
        assert _resolve_sway_trust_remote_code(Path("/tmp/example.dlm")) is False

    assert resolve_calls == []
249
250
def test_resolve_sway_trust_remote_code_returns_spec_flag() -> None:
    """For a registry base model, the resolved spec's ``trust_remote_code`` is returned."""
    frontmatter = SimpleNamespace(base_model="qwen3-1.7b")
    parsed_doc = SimpleNamespace(frontmatter=frontmatter)
    trusting_spec = SimpleNamespace(trust_remote_code=True)

    stub_modules = {
        "dlm.doc.parser": SimpleNamespace(parse_file=lambda _path: parsed_doc),
        "dlm.base_models": SimpleNamespace(resolve=lambda *_args, **_kwargs: trusting_spec),
    }
    with patch.dict("sys.modules", stub_modules):
        flag = _resolve_sway_trust_remote_code(Path("/tmp/example.dlm"))
        assert flag is True