Python · 12262 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.integrations.dlm`.
2
3 The bridge imports ``dlm.*`` modules lazily. We mock those via
4 ``sys.modules`` injection so the tests run without the ``dlm-sway[dlm]``
5 extra installed. A full end-to-end integration test against a real
6 ``.dlm`` lives under ``tests/integration/``.
7 """
8
9 from __future__ import annotations
10
11 import sys
12 import types
13 from dataclasses import dataclass
14 from pathlib import Path
15
16 import pytest
17 import yaml
18
19
@pytest.fixture
def fake_dlm(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
    """Register a stand-in ``dlm`` package and return a placeholder ``.dlm`` path.

    Each fake module is a bare :class:`types.ModuleType` that only carries
    the attributes assigned below. The modules are placed in
    ``sys.modules`` through ``monkeypatch``, so they are removed again when
    the test finishes. The fake parser ignores the file it is handed and
    always returns the same three sections (prose, instruction, preference).
    """

    # --- Record types returned by the fake dlm functions. ---------------
    @dataclass
    class _Frontmatter:
        dlm_id: str = "01TESTULID"
        base_model: str = "smollm2-135m"

    @dataclass
    class _Section:
        section_id: str
        type: str
        content: str
        tag: str | None = None

    @dataclass
    class _Parsed:
        frontmatter: _Frontmatter
        sections: tuple[_Section, ...]

    @dataclass
    class _BaseSpec:
        hf_id: str
        key: str

    @dataclass
    class _QAPair:
        question: str
        answer: str

    @dataclass
    class _Triple:
        prompt: str
        chosen: str
        rejected: str

    # --- On-disk adapter directory handed out by the fake store. --------
    adapter_root = tmp_path / "adapter_v1"
    adapter_root.mkdir()
    adapter_config = adapter_root / "adapter_config.json"
    adapter_config.write_text("{}", encoding="utf-8")

    class _StorePath:
        def __init__(self, path: Path) -> None:
            self._p = path

        def resolve_current_adapter(self) -> Path:
            return self._p

    # --- Fake callables. ------------------------------------------------
    def _parse_file(_path: Path):  # type: ignore[no-untyped-def]
        # Fresh objects on every call; the path argument is never read.
        prose = _Section(
            section_id="prose-1",
            type="PROSE",
            content="This is a prose section with some information. Further detail follows.",
        )
        instruction = _Section(
            section_id="instr-1",
            type="INSTRUCTION",
            content="### Q\nWhat is X?\n\n### A\nX is a concept\n",
        )
        preference = _Section(
            section_id="pref-1",
            type="PREFERENCE",
            content="chosen/rejected triple",
        )
        return _Parsed(
            frontmatter=_Frontmatter(),
            sections=(prose, instruction, preference),
        )

    def _for_dlm(_dlm_id: str) -> _StorePath:
        # Every document id maps to the same adapter directory.
        return _StorePath(adapter_root)

    def _resolve(key: str) -> _BaseSpec:
        return _BaseSpec(hf_id="HuggingFaceTB/SmolLM2-135M-Instruct", key=key)

    def _parse_instr(body: str, *, section_id: str) -> list[_QAPair]:
        del section_id
        pairs: list[_QAPair] = []
        # Anything before the first "### Q" marker is not a question; skip it.
        for chunk in body.split("### Q")[1:]:
            question, _, answer = (piece.strip() for piece in chunk.partition("### A"))
            if not (question and answer):
                continue
            pairs.append(_QAPair(question=question, answer=answer))
        return pairs

    def _parse_pref(body: str, *, section_id: str) -> list[_Triple]:
        del body, section_id
        return [_Triple(prompt="Which?", chosen="good answer", rejected="bad answer")]

    # --- Fake modules, keyed by dotted name in registration order. ------
    fakes = {
        dotted: types.ModuleType(dotted)
        for dotted in (
            "dlm",
            "dlm.doc",
            "dlm.doc.parser",
            "dlm.store",
            "dlm.store.paths",
            "dlm.base_models",
            "dlm.data",
            "dlm.data.instruction_parser",
            "dlm.data.preference_parser",
        )
    }

    fakes["dlm.doc.parser"].parse_file = _parse_file  # type: ignore[attr-defined]

    store_paths = fakes["dlm.store.paths"]
    store_paths.StorePath = _StorePath  # type: ignore[attr-defined]
    store_paths.for_dlm = _for_dlm  # type: ignore[attr-defined]

    fakes["dlm.base_models"].resolve = _resolve  # type: ignore[attr-defined]

    instr_mod = fakes["dlm.data.instruction_parser"]
    instr_mod.parse_instruction_body = _parse_instr  # type: ignore[attr-defined]
    pref_mod = fakes["dlm.data.preference_parser"]
    pref_mod.parse_preference_body = _parse_pref  # type: ignore[attr-defined]

    for dotted, module in fakes.items():
        monkeypatch.setitem(sys.modules, dotted, module)

    # The returned file only needs to exist; the fake parser never reads it.
    placeholder = tmp_path / "doc.dlm"
    placeholder.write_text("---\ndlm_id: 01TEST\n---\n\nbody\n", encoding="utf-8")
    return placeholder
153
154
def test_resolve_dlm_maps_sections(fake_dlm: Path) -> None:
    """Resolve the fake document and check ids, adapter path and sections."""
    from dlm_sway.integrations.dlm.resolver import resolve_dlm

    resolved = resolve_dlm(fake_dlm)

    assert resolved.dlm_id == "01TESTULID"
    assert resolved.base_model == "HuggingFaceTB/SmolLM2-135M-Instruct"
    adapter = resolved.adapter_path
    assert adapter is not None
    assert adapter.exists()

    assert len(resolved.sections) == 3
    # The fake parser emits uppercase type names; the handle exposes them
    # as lowercase kinds.
    assert {section.kind for section in resolved.sections} == {
        "prose",
        "instruction",
        "preference",
    }
    # Three sections with three distinct kinds: each kind maps to one section.
    by_kind = {section.kind: section for section in resolved.sections}

    # The instruction Q/A pair came through as a probe.
    instruction = by_kind["instruction"]
    assert instruction.probes
    assert instruction.probes[0].prompt == "What is X?"

    # The preference triple came through as well.
    preference = by_kind["preference"]
    assert preference.preferences
    assert preference.preferences[0].chosen == "good answer"
174
175
def test_resolve_raises_dlm_compat_error_on_missing_hf_id(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """F06 — a base spec without ``hf_id`` must raise DlmCompatError.

    Falling back silently to the registry key would push the failure into
    the backend load, where the message is less helpful.
    """

    @dataclass
    class _Frontmatter:
        dlm_id: str = "01TEST"
        base_model: str = "smollm2-135m"

    @dataclass
    class _Parsed:
        frontmatter: _Frontmatter
        sections: tuple[object, ...]

    # Stand-in for a post-rename dlm where ``hf_id`` became ``repo_id``.
    @dataclass
    class _RenamedSpec:
        repo_id: str = "HuggingFaceTB/SmolLM2-135M-Instruct"

    def _parse_file(_p):  # type: ignore[no-untyped-def]
        return _Parsed(_Frontmatter(), sections=())

    def _resolve(_k):  # type: ignore[no-untyped-def]
        return _RenamedSpec()

    parser_mod = types.ModuleType("dlm.doc.parser")
    parser_mod.parse_file = _parse_file  # type: ignore[attr-defined]
    base_mod = types.ModuleType("dlm.base_models")
    base_mod.resolve = _resolve  # type: ignore[attr-defined]

    fakes = {
        "dlm": types.ModuleType("dlm"),
        "dlm.doc": types.ModuleType("dlm.doc"),
        "dlm.doc.parser": parser_mod,
        "dlm.base_models": base_mod,
    }
    for dotted, module in fakes.items():
        monkeypatch.setitem(sys.modules, dotted, module)

    source = tmp_path / "doc.dlm"
    source.write_text("---\ndlm_id: 01TEST\n---\n\nbody\n", encoding="utf-8")

    from dlm_sway.core.errors import DlmCompatError
    from dlm_sway.integrations.dlm.resolver import resolve_dlm

    with pytest.raises(DlmCompatError, match="hf_id"):
        resolve_dlm(source)
222
223
def test_resolve_raises_dlm_compat_error_on_resolve_exception(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """F06 — an exception from ``dlm.base_models.resolve`` is wrapped.

    The underlying error must surface as DlmCompatError so callers get a
    typed sway error (with ``__cause__`` preserved for debugging).
    """

    @dataclass
    class _Frontmatter:
        dlm_id: str = "01TEST"
        base_model: str = "smollm2-135m"

    @dataclass
    class _Parsed:
        frontmatter: _Frontmatter
        sections: tuple[object, ...]

    # The class name is part of the contract: the ``match=`` pattern at the
    # bottom of this test looks for it in the raised error.
    class _RegistryDriftError(RuntimeError):
        pass

    def _parse_file(_p):  # type: ignore[no-untyped-def]
        return _Parsed(_Frontmatter(), sections=())

    def _raise(_k: str) -> object:
        raise _RegistryDriftError("unknown base key after rename")

    parser_mod = types.ModuleType("dlm.doc.parser")
    parser_mod.parse_file = _parse_file  # type: ignore[attr-defined]
    base_mod = types.ModuleType("dlm.base_models")
    base_mod.resolve = _raise  # type: ignore[attr-defined]

    fakes = {
        "dlm": types.ModuleType("dlm"),
        "dlm.doc": types.ModuleType("dlm.doc"),
        "dlm.doc.parser": parser_mod,
        "dlm.base_models": base_mod,
    }
    for dotted, module in fakes.items():
        monkeypatch.setitem(sys.modules, dotted, module)

    source = tmp_path / "doc.dlm"
    source.write_text("---\ndlm_id: 01TEST\n---\n\nbody\n", encoding="utf-8")

    from dlm_sway.core.errors import DlmCompatError
    from dlm_sway.integrations.dlm.resolver import resolve_dlm

    with pytest.raises(DlmCompatError, match="_RegistryDriftError"):
        resolve_dlm(source)
270
271
def test_resolve_without_dlm_installed(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """resolve_dlm surfaces a SwayError when the dlm package is missing.

    A missing package is simulated in two steps: cached ``dlm`` modules are
    dropped from ``sys.modules``, then ``builtins.__import__`` is replaced
    with a hook that raises ``ImportError`` for the ``dlm`` package. Both
    steps use the same predicate, so a bare ``import dlm`` is blocked just
    like ``import dlm.something``.
    """
    import builtins

    def _is_dlm(name: str) -> bool:
        # Exact package or a dotted submodule. A plain prefix test would
        # also match ``dlm_sway``.
        return name == "dlm" or name.startswith("dlm.")

    # Wipe any cached dlm modules so the lazy import fails.
    for mod in [m for m in sys.modules if _is_dlm(m)]:
        monkeypatch.delitem(sys.modules, mod, raising=False)

    real_import = builtins.__import__

    def fake_import(name: str, *args, **kwargs):  # type: ignore[no-untyped-def]
        # ``__import__(name, globals, locals, fromlist, level)``: ``level``
        # is the fourth item after ``name``. Only absolute imports
        # (level 0) can refer to the top-level ``dlm`` package; a relative
        # ``from .dlm import ...`` names a sibling subpackage and must pass.
        level = kwargs.get("level", args[3] if len(args) > 3 else 0)
        if level == 0 and _is_dlm(name):
            raise ImportError("missing extra")
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", fake_import)

    from dlm_sway.core.errors import SwayError
    from dlm_sway.integrations.dlm.resolver import resolve_dlm

    with pytest.raises(SwayError, match="dlm package not installed"):
        resolve_dlm(tmp_path / "doc.dlm")
295
296
def test_autogen_writes_complete_suite(fake_dlm: Path, tmp_path: Path) -> None:
    """write_sway_yaml emits a loadable spec containing every primitive."""
    from dlm_sway.integrations.dlm.autogen import write_sway_yaml

    target = tmp_path / "sway.yaml"
    write_sway_yaml(fake_dlm, target)
    spec = yaml.safe_load(target.read_text(encoding="utf-8"))

    assert spec["version"] == 1
    models = spec["models"]
    assert models["base"]["base"] == "HuggingFaceTB/SmolLM2-135M-Instruct"
    assert models["ft"]["adapter"] is not None
    assert spec["dlm_source"] == str(fake_dlm.resolve())

    # The fixture supplies one section of every type, so none of the
    # eleven primitives should have been skipped for lack of data.
    expected = {
        "null_adapter",
        "delta_kl",
        "adapter_revert",
        "prompt_collapse",
        "section_internalization",
        "paraphrase_invariance",
        "preference_flip",
        "style_fingerprint",
        "calibration_drift",
        "leakage",
        "adapter_ablation",
    }
    present = {item["kind"] for item in spec["suite"]}
    missing = expected - present
    assert not missing, f"missing: {missing}"
327
328
def test_build_spec_dict_skips_preference_when_absent() -> None:
    """A prose-only handle yields no ``preference_flip`` entry in the suite."""
    from dlm_sway.core.sections import Section
    from dlm_sway.integrations.dlm.autogen import build_spec_dict
    from dlm_sway.integrations.dlm.resolver import DlmHandle

    first = Section(id="a", kind="prose", content="A prose section. Second sentence.")
    second = Section(id="b", kind="prose", content="Another prose section.")
    prose_only = DlmHandle(
        dlm_id="x",
        base_model="base",
        adapter_path=Path("/tmp/adapter"),
        sections=(first, second),
        doc_text="whole document",
    )

    suite = build_spec_dict(prose_only)["suite"]
    suite_kinds = {item["kind"] for item in suite}

    assert "preference_flip" not in suite_kinds
    assert "section_internalization" in suite_kinds