Python · 39531 bytes Raw Blame History
1 """Teacher selector parsing and runtime wrappers for Sprint 43."""
2
3 from __future__ import annotations
4
5 import builtins
6 import json
7 import sys
8 import urllib.error
9 from pathlib import Path
10 from types import ModuleType, SimpleNamespace
11 from typing import Any, Literal
12
13 import pytest
14
15 import dlm.synth.teachers as teachers_mod
16 from dlm.synth import (
17 AnthropicTeacher,
18 HfTeacher,
19 InvalidTeacherSpecError,
20 OpenAiTeacher,
21 SelfTeacher,
22 TeacherInvocationError,
23 TeacherUnavailableError,
24 VllmServerTeacher,
25 build_teacher,
26 parse_teacher_ref,
27 )
28
29
def _module(name: str, **attrs: object) -> ModuleType:
    """Create a bare module object called *name* and attach *attrs* to it.

    Used by the tests below to plant lightweight stand-ins in ``sys.modules``.
    """
    stub = ModuleType(name)
    for attr_name in attrs:
        setattr(stub, attr_name, attrs[attr_name])
    return stub
35
36
class TestTeacherSelectorParsing:
    """`parse_teacher_ref` turns selector strings into refs or refuses them."""

    @pytest.mark.parametrize(
        ("raw", "kind", "target"),
        [
            ("self", "self", None),
            ("hf:Qwen/Qwen2.5-1.5B-Instruct", "hf", "Qwen/Qwen2.5-1.5B-Instruct"),
            ("openai:gpt-4o-mini", "openai", "gpt-4o-mini"),
            ("anthropic:claude-3-5-haiku-latest", "anthropic", "claude-3-5-haiku-latest"),
            ("vllm-server:http://127.0.0.1:8000", "vllm-server", "http://127.0.0.1:8000"),
        ],
    )
    def test_parse_teacher_ref(self, raw: str, kind: str, target: str | None) -> None:
        """Every supported selector yields the expected kind and target."""
        parsed = parse_teacher_ref(raw)

        assert parsed.kind == kind
        assert parsed.target == target

    def test_empty_selector_refused(self) -> None:
        """A whitespace-only selector is rejected as empty."""
        blank_selector = " "
        with pytest.raises(InvalidTeacherSpecError, match="must not be empty"):
            parse_teacher_ref(blank_selector)

    def test_unknown_selector_refused(self) -> None:
        """A selector with an unrecognised prefix is rejected."""
        unknown_selector = "mystery:thing"
        with pytest.raises(InvalidTeacherSpecError, match="unknown teacher selector"):
            parse_teacher_ref(unknown_selector)

    @pytest.mark.parametrize(
        ("raw", "message"),
        [
            ("hf: ", "hf teacher selector must include a model id"),
            ("openai: ", "openai teacher selector must include a model id"),
            ("anthropic: ", "anthropic teacher selector must include a model id"),
            ("vllm-server: ", "vllm-server teacher selector must include a URL"),
        ],
    )
    def test_missing_selector_targets_are_refused(self, raw: str, message: str) -> None:
        """A known prefix followed by a blank target is rejected with a kind-specific message."""
        expectation = pytest.raises(InvalidTeacherSpecError, match=message)
        with expectation:
            parse_teacher_ref(raw)
73
74
class TestBuildTeacher:
    """`build_teacher` maps each selector kind onto its concrete teacher class."""

    def test_self_requires_dlm_path(self) -> None:
        """The ``self`` selector is refused when no ``dlm_path`` is supplied."""
        # ``match`` is applied as a regular expression, so the dot is escaped
        # to require a literal ".dlm" rather than any character before "dlm".
        with pytest.raises(TeacherUnavailableError, match=r"requires the \.dlm path context"):
            build_teacher("self")

    def test_build_teacher_dispatches(self, tmp_path: Path) -> None:
        """Each selector prefix produces an instance of the matching teacher type."""
        self_teacher = build_teacher("self", dlm_path=tmp_path / "doc.dlm")
        assert isinstance(self_teacher, SelfTeacher)
        assert self_teacher.backend == "pytorch"

        # The remaining selectors need no extra context; checked in the original order.
        expected_types = {
            "hf:foo/bar": HfTeacher,
            "openai:gpt-4o-mini": OpenAiTeacher,
            "anthropic:claude": AnthropicTeacher,
            "vllm-server:http://127.0.0.1:8000": VllmServerTeacher,
        }
        for selector, teacher_type in expected_types.items():
            assert isinstance(build_teacher(selector), teacher_type), selector
91
92
class TestSelfTeacher:
    """`SelfTeacher` obtains its backend from the injected loader and delegates to it."""

    def test_self_teacher_uses_loader_once_and_forwards_kwargs(self, tmp_path: Path) -> None:
        """Two generations cause one load; sampling kwargs reach the backend without the seed."""
        dlm_path = tmp_path / "doc.dlm"
        generate_calls: list[tuple[str, dict[str, Any]]] = []
        load_calls: list[tuple[Path, str]] = []

        class _RecordingBackend:
            def generate(self, prompt: str, **gen_kwargs: Any) -> str:
                generate_calls.append((prompt, gen_kwargs))
                # Padded on purpose: the teacher is expected to strip the text.
                return " synthesized answer "

        def _loader(path: Path, backend: str) -> _RecordingBackend:
            load_calls.append((path, backend))
            return _RecordingBackend()

        teacher = SelfTeacher(dlm_path, loader=_loader)
        first = teacher.generate(
            "system text",
            "user text",
            max_new_tokens=33,
            temperature=0.7,
            top_p=0.9,
            seed=7,
        )
        second = teacher.generate("system text", "user text")

        assert first == "synthesized answer"
        assert second == "synthesized answer"
        assert load_calls == [(dlm_path, "auto")]

        first_prompt, first_kwargs = generate_calls[0]
        assert "system text" in first_prompt
        assert "user text" in first_prompt
        assert first_kwargs == {
            "max_new_tokens": 33,
            "temperature": 0.7,
            "top_p": 0.9,
        }
128
129
class TestHfTeacher:
    """`HfTeacher` validates its model id and drives the injected loader/runner pair."""

    def test_blank_hf_id_refused(self) -> None:
        """A whitespace-only model id is rejected at construction time."""
        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
            HfTeacher(" ")

    def test_hf_teacher_uses_loader_and_runner(self) -> None:
        """The loader receives the id and resolved device; the runner receives the sampling args."""
        model_id = "Qwen/Qwen2.5-1.5B-Instruct"
        recorded: dict[str, Any] = {}

        def _loader(hf_id: str, device: str) -> teachers_mod._LoadedHfTeacher:
            recorded["loader"] = (hf_id, device)
            return teachers_mod._LoadedHfTeacher(model="model", tokenizer="tok", device=device)

        def _runner(
            model: Any,
            tokenizer: Any,
            prompt: str,
            *,
            max_new_tokens: int,
            temperature: float,
            top_p: float | None,
            seed: int | None,
        ) -> str:
            recorded["runner"] = (
                model,
                tokenizer,
                prompt,
                max_new_tokens,
                temperature,
                top_p,
                seed,
            )
            return " hf output "

        teacher = HfTeacher(model_id, loader=_loader, runner=_runner)
        result = teacher.generate(
            "system",
            "user",
            max_new_tokens=21,
            temperature=0.5,
            top_p=0.8,
            seed=11,
        )

        assert result == "hf output"
        expected_loader_call = (model_id, teachers_mod._resolve_generation_device("auto"))
        assert recorded["loader"] == expected_loader_call
        # Only the sampling arguments are pinned; model/tokenizer/prompt are not checked here.
        sampling_args = recorded["runner"][3:]
        assert sampling_args == (21, 0.5, 0.8, 11)

    def test_hf_teacher_reuses_loaded_bundle(self) -> None:
        """Generating twice invokes the loader only once."""
        model_id = "Qwen/Qwen2.5-1.5B-Instruct"
        load_calls: list[tuple[str, str]] = []

        def _loader(hf_id: str, device: str) -> teachers_mod._LoadedHfTeacher:
            load_calls.append((hf_id, device))
            return teachers_mod._LoadedHfTeacher(model="model", tokenizer="tok", device=device)

        teacher = HfTeacher(
            model_id,
            loader=_loader,
            runner=lambda *_args, **_kwargs: "ok",
        )

        first = teacher.generate("system", "user")
        assert first == "ok"
        second = teacher.generate("system", "user")
        assert second == "ok"

        expected_device = teachers_mod._resolve_generation_device("auto")
        assert load_calls == [(model_id, expected_device)]
184
185
class TestOpenAiTeacher:
    """`OpenAiTeacher` validation, API-key handling, response parsing and error wrapping."""

    def test_blank_model_refused(self) -> None:
        """A whitespace-only model id is rejected at construction time."""
        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
            OpenAiTeacher(" ")

    def test_missing_api_key_refused(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Generating without OPENAI_API_KEY in the environment is refused."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        teacher = OpenAiTeacher("gpt-4o-mini")

        with pytest.raises(TeacherUnavailableError, match="OPENAI_API_KEY"):
            teacher.generate("system", "user")

    def test_openai_teacher_extracts_message_text(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """The stripped message content is returned and the client factory runs once."""
        monkeypatch.setenv("OPENAI_API_KEY", "secret")

        requests: list[dict[str, Any]] = []
        keys_seen: list[str] = []

        def _create(**kwargs: Any) -> Any:
            requests.append(kwargs)
            message = SimpleNamespace(content=" generated ")
            return SimpleNamespace(choices=[SimpleNamespace(message=message)])

        completions = SimpleNamespace(create=_create)
        client = SimpleNamespace(chat=SimpleNamespace(completions=completions))

        def _factory(api_key: str) -> Any:
            keys_seen.append(api_key)
            return client

        teacher = OpenAiTeacher("gpt-4o-mini", client_factory=_factory)
        first = teacher.generate(
            "sys",
            "usr",
            max_new_tokens=17,
            temperature=0.3,
            top_p=0.7,
            seed=5,
        )
        second = teacher.generate("sys", "usr")

        assert first == "generated"
        assert second == "generated"

        first_request = requests[0]
        assert first_request["model"] == "gpt-4o-mini"
        assert first_request["seed"] == 5
        # One entry despite two generations: the factory was only called once.
        assert keys_seen == ["secret"]

    def test_openai_teacher_wraps_request_failures(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An exception raised by the client surfaces as TeacherInvocationError."""
        monkeypatch.setenv("OPENAI_API_KEY", "secret")

        def _create(**_kwargs: Any) -> Any:
            raise RuntimeError("boom")

        completions = SimpleNamespace(create=_create)
        client = SimpleNamespace(chat=SimpleNamespace(completions=completions))
        teacher = OpenAiTeacher("gpt-4o-mini", client_factory=lambda _api_key: client)

        with pytest.raises(
            TeacherInvocationError,
            match="openai:gpt-4o-mini request failed: boom",
        ):
            teacher.generate("sys", "usr")
246
247
class TestAnthropicTeacher:
    """`AnthropicTeacher` validation, API-key handling, response parsing and error wrapping."""

    def test_blank_model_refused(self) -> None:
        """A whitespace-only model id is rejected at construction time."""
        with pytest.raises(InvalidTeacherSpecError, match="must include a model id"):
            AnthropicTeacher(" ")

    def test_missing_api_key_refused(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Generating without ANTHROPIC_API_KEY in the environment is refused."""
        monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
        teacher = AnthropicTeacher("claude-3-5-haiku-latest")

        with pytest.raises(TeacherUnavailableError, match="ANTHROPIC_API_KEY"):
            teacher.generate("system", "user")

    def test_anthropic_teacher_extracts_text_blocks(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Text blocks are stripped and newline-joined; the non-text block is skipped."""
        monkeypatch.setenv("ANTHROPIC_API_KEY", "secret")
        recorded: dict[str, Any] = {}
        keys_seen: list[str] = []

        class _Messages:
            @staticmethod
            def create(**kwargs: Any) -> Any:
                recorded["payload"] = kwargs
                blocks = [
                    SimpleNamespace(type="image", text="ignored"),
                    SimpleNamespace(type="text", text=" first "),
                    SimpleNamespace(type="text", text=" second "),
                ]
                return SimpleNamespace(content=blocks)

        class _Client:
            messages = _Messages()

        def _factory(api_key: str) -> _Client:
            keys_seen.append(api_key)
            return _Client()

        teacher = AnthropicTeacher("claude-3-5-haiku-latest", client_factory=_factory)
        first = teacher.generate("sys", "usr", max_new_tokens=19, temperature=0.2, top_p=0.6)
        second = teacher.generate("sys", "usr")

        expected_text = "first\nsecond"
        assert first == expected_text
        assert second == expected_text
        assert recorded["payload"]["model"] == "claude-3-5-haiku-latest"
        # One entry despite two generations: the factory was only called once.
        assert keys_seen == ["secret"]

    def test_anthropic_teacher_wraps_request_failures(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An exception raised by the client surfaces as TeacherInvocationError."""
        monkeypatch.setenv("ANTHROPIC_API_KEY", "secret")

        class _Messages:
            @staticmethod
            def create(**_kwargs: Any) -> Any:
                raise RuntimeError("boom")

        class _Client:
            messages = _Messages()

        teacher = AnthropicTeacher(
            "claude-3-5-haiku-latest",
            client_factory=lambda _api_key: _Client(),
        )

        expected_failure = pytest.raises(
            TeacherInvocationError,
            match="anthropic:claude-3-5-haiku-latest request failed: boom",
        )
        with expected_failure:
            teacher.generate("sys", "usr")
317
318
class TestVllmServerTeacher:
    """`VllmServerTeacher` URL validation and delegation to the OpenAI-compatible helpers."""

    def test_blank_url_refused(self) -> None:
        """A whitespace-only URL is rejected at construction time."""
        with pytest.raises(InvalidTeacherSpecError, match="must include a URL"):
            VllmServerTeacher(" ")

    def test_invalid_url_refused(self) -> None:
        """A target without an http(s) scheme is rejected."""
        schemeless_target = "localhost:8000"
        with pytest.raises(InvalidTeacherSpecError, match="http\\(s\\)"):
            VllmServerTeacher(schemeless_target)

    def test_vllm_teacher_queries_model_and_completion(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The model id is looked up once and forwarded with the sampling args."""
        server_url = "http://127.0.0.1:8000"
        model_lookups: list[tuple[str, float]] = []
        completions: list[tuple[Any, ...]] = []

        def _fake_models(base_url: str, *, request_timeout: float) -> str | None:
            model_lookups.append((base_url, request_timeout))
            return "demo-model"

        def _fake_completion(
            base_url: str,
            *,
            model_id: str | None,
            messages: list[dict[str, str]],
            max_new_tokens: int,
            temperature: float,
            top_p: float | None,
            seed: int | None,
            request_timeout: float,
        ) -> str:
            record = (
                base_url,
                model_id,
                messages,
                max_new_tokens,
                temperature,
                top_p,
                seed,
                request_timeout,
            )
            completions.append(record)
            return " served "

        monkeypatch.setattr(teachers_mod, "_fetch_openai_compat_model_id", _fake_models)
        monkeypatch.setattr(teachers_mod, "_request_openai_compat_completion", _fake_completion)

        teacher = VllmServerTeacher(server_url)
        first = teacher.generate(
            "sys",
            "usr",
            max_new_tokens=29,
            temperature=0.4,
            top_p=0.75,
            seed=9,
        )
        second = teacher.generate("sys", "usr")

        assert first == "served"
        assert second == "served"
        # A single lookup even though two completions were requested.
        assert model_lookups == [(server_url, 30.0)]

        first_completion = completions[0]
        assert first_completion[1] == "demo-model"
        assert first_completion[3:] == (29, 0.4, 0.75, 9, 30.0)
375
376
class TestTeacherHelpers:
    """Pure helper functions in the teachers module: prompt, response and URL handling."""

    def test_flatten_teacher_prompt_handles_partial_inputs(self) -> None:
        """With both parts a "System:" header leads; a lone part is returned unchanged."""
        flatten = teachers_mod._flatten_teacher_prompt

        combined = flatten("system", "user")
        assert combined.startswith("System:\n")
        assert flatten("", "user") == "user"
        assert flatten("system", "") == "system"

    def test_require_non_empty_teacher_output_refuses_blank_text(self) -> None:
        """Whitespace-only output is refused with a message naming the teacher."""
        with pytest.raises(TeacherInvocationError, match="self returned empty output"):
            teachers_mod._require_non_empty_teacher_output(" ", teacher="self")

    def test_extract_openai_message_text_handles_list_content_and_errors(self) -> None:
        """List content is joined by newlines; malformed responses raise specific errors."""
        extract = teachers_mod._extract_openai_message_text

        parts = [{"text": " first "}, {"text": " second "}]
        response = {"choices": [{"message": {"content": parts}}]}
        assert extract(response) == "first\nsecond"

        with pytest.raises(TeacherInvocationError, match="missing choices"):
            extract({})

        with pytest.raises(TeacherInvocationError, match="missing choices\\[0\\]\\.message"):
            extract({"choices": [{}]})

        null_content = {"choices": [{"message": {"content": None}}]}
        with pytest.raises(TeacherInvocationError, match="missing non-empty message content"):
            extract(null_content)

    def test_extract_anthropic_text_handles_errors(self) -> None:
        """Responses without content, or with only blank text, are refused."""
        extract = teachers_mod._extract_anthropic_text

        with pytest.raises(TeacherInvocationError, match="missing content blocks"):
            extract({})

        blank_text_only = {
            "content": [
                {"type": "image", "text": "ignored"},
                {"type": "text", "text": " "},
            ]
        }
        with pytest.raises(TeacherInvocationError, match="missing non-empty text blocks"):
            extract(blank_text_only)

    def test_normalize_chat_content_and_obj_get_helpers(self) -> None:
        """Content normalisation strips/joins text; `_obj_get` reads dicts and attributes."""
        normalize = teachers_mod._normalize_chat_content

        assert normalize(" hello ") == "hello"
        joined = normalize([{"text": " one "}, {"text": " two "}])
        assert joined == "one\ntwo"
        assert normalize([{"text": " "}]) is None
        assert normalize(123) is None

        from_mapping = teachers_mod._obj_get({"name": "value"}, "name")
        assert from_mapping == "value"
        from_namespace = teachers_mod._obj_get(SimpleNamespace(name="value"), "name")
        assert from_namespace == "value"

    def test_openai_compat_url_helpers_normalize_suffixes(self) -> None:
        """Endpoint suffixes are removed from base URLs and /v1 routes are built either way."""
        host = "http://127.0.0.1:8000"

        for endpoint_url in (
            "http://127.0.0.1:8000/v1/chat/completions",
            "http://127.0.0.1:8000/chat/completions",
        ):
            assert teachers_mod._normalize_openai_compat_base_url(endpoint_url) == host

        models_url = "http://127.0.0.1:8000/v1/models"
        assert teachers_mod._openai_compat_models_url("http://127.0.0.1:8000/v1") == models_url
        assert teachers_mod._openai_compat_models_url(host) == models_url

        chat_url = "http://127.0.0.1:8000/v1/chat/completions"
        assert teachers_mod._openai_compat_chat_url("http://127.0.0.1:8000/v1") == chat_url
        assert teachers_mod._openai_compat_chat_url(host) == chat_url
454
455
class _StubJsonResponse:
    """Context-manager stand-in for a ``urlopen`` result whose body is *payload* as JSON."""

    def __init__(self, payload: object) -> None:
        self._payload = payload

    def __enter__(self) -> _StubJsonResponse:
        return self

    def __exit__(self, *_args: object) -> Literal[False]:
        # Returning False lets exceptions raised inside the ``with`` body propagate.
        return False

    def read(self) -> bytes:
        return json.dumps(self._payload).encode("utf-8")


class TestTeacherRuntimeHelpers:
    """Runtime helpers: device resolution, default SDK clients and OpenAI-compatible HTTP calls."""

    def test_resolve_generation_device_prefers_requested_or_detected_backends(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An explicit device is returned as given; "auto" follows what the torch stub reports."""
        assert teachers_mod._resolve_generation_device("mps") == "mps"

        # Phase 1: torch cannot be imported at all -> "auto" resolves to cpu.
        monkeypatch.delitem(sys.modules, "torch", raising=False)
        real_import = builtins.__import__

        def _missing_torch(
            name: str,
            globals: dict[str, object] | None = None,
            locals: dict[str, object] | None = None,
            fromlist: tuple[str, ...] = (),
            level: int = 0,
        ) -> object:
            if name == "torch":
                raise ImportError("no torch")
            return real_import(name, globals, locals, fromlist, level)

        monkeypatch.setattr(builtins, "__import__", _missing_torch)
        assert teachers_mod._resolve_generation_device("auto") == "cpu"

        # Phase 2: restore the real importer and plant torch stubs in sys.modules.
        monkeypatch.setattr(builtins, "__import__", real_import)

        def _fake_torch(*, cuda: bool, mps: bool) -> SimpleNamespace:
            return SimpleNamespace(
                cuda=SimpleNamespace(is_available=lambda: cuda),
                backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: mps)),
            )

        availability_cases = (
            (True, False, "cuda"),
            (False, True, "mps"),
            (False, False, "cpu"),
        )
        for cuda_available, mps_available, expected_device in availability_cases:
            monkeypatch.setitem(
                sys.modules,
                "torch",
                _fake_torch(cuda=cuda_available, mps=mps_available),
            )
            assert teachers_mod._resolve_generation_device("auto") == expected_device

    def test_default_openai_client_validates_import_surface(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Missing package, missing ``OpenAI`` attribute, then a successful construction."""
        import_module_target = "dlm.synth.teachers.importlib.import_module"

        def _raise_import(name: str) -> object:
            raise ImportError(name)

        monkeypatch.setattr(import_module_target, _raise_import)
        with pytest.raises(TeacherUnavailableError, match="requires the openai package"):
            teachers_mod._default_openai_client("secret")

        # Package imports, but exposes no client class.
        monkeypatch.setattr(import_module_target, lambda _name: SimpleNamespace())
        with pytest.raises(TeacherUnavailableError, match="does not expose OpenAI client"):
            teachers_mod._default_openai_client("secret")

        captured: list[str] = []

        class _OpenAI:
            def __init__(self, *, api_key: str) -> None:
                captured.append(api_key)

        monkeypatch.setattr(import_module_target, lambda _name: SimpleNamespace(OpenAI=_OpenAI))
        client = teachers_mod._default_openai_client("secret")
        assert isinstance(client, _OpenAI)
        assert captured == ["secret"]

    def test_default_anthropic_client_validates_import_surface(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Missing package, missing ``Anthropic`` attribute, then a successful construction."""
        import_module_target = "dlm.synth.teachers.importlib.import_module"

        def _raise_import(name: str) -> object:
            raise ImportError(name)

        monkeypatch.setattr(import_module_target, _raise_import)
        with pytest.raises(TeacherUnavailableError, match="requires the anthropic package"):
            teachers_mod._default_anthropic_client("secret")

        # Package imports, but exposes no client class.
        monkeypatch.setattr(import_module_target, lambda _name: SimpleNamespace())
        with pytest.raises(TeacherUnavailableError, match="does not expose Anthropic client"):
            teachers_mod._default_anthropic_client("secret")

        captured: list[str] = []

        class _Anthropic:
            def __init__(self, *, api_key: str) -> None:
                captured.append(api_key)

        monkeypatch.setattr(
            import_module_target,
            lambda _name: SimpleNamespace(Anthropic=_Anthropic),
        )
        client = teachers_mod._default_anthropic_client("secret")
        assert isinstance(client, _Anthropic)
        assert captured == ["secret"]

    def test_fetch_openai_compat_model_id_handles_success_empty_and_errors(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """First model id is returned; empty or blank listings give None; URL errors are wrapped."""
        urlopen_target = "dlm.synth.teachers.urllib.request.urlopen"

        def _serve(payload: object) -> None:
            # Every urlopen call gets a fresh stub response carrying *payload*.
            monkeypatch.setattr(
                urlopen_target,
                lambda *_args, **_kwargs: _StubJsonResponse(payload),
            )

        def _fetch() -> str | None:
            return teachers_mod._fetch_openai_compat_model_id(
                "http://127.0.0.1:8000",
                request_timeout=1.0,
            )

        _serve({"data": [{"id": "demo-model"}]})
        assert _fetch() == "demo-model"

        _serve({"data": []})
        assert _fetch() is None

        _serve({"data": [{"id": " "}]})
        assert _fetch() is None

        def _raise_url_error(*_args: object, **_kwargs: object) -> object:
            raise urllib.error.URLError("boom")

        monkeypatch.setattr(urlopen_target, _raise_url_error)
        with pytest.raises(TeacherUnavailableError, match="could not query models"):
            _fetch()

    def test_request_openai_compat_completion_handles_success_and_failures(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Content is extracted on success; malformed payloads and URL errors are wrapped."""
        urlopen_target = "dlm.synth.teachers.urllib.request.urlopen"

        def _serve(payload: object) -> None:
            # Every urlopen call gets a fresh stub response carrying *payload*.
            monkeypatch.setattr(
                urlopen_target,
                lambda *_args, **_kwargs: _StubJsonResponse(payload),
            )

        def _request(*, model_id: str | None, top_p: float | None, seed: int | None) -> str:
            # A fresh messages list per call, as in separate hand-written invocations.
            return teachers_mod._request_openai_compat_completion(
                "http://127.0.0.1:8000",
                model_id=model_id,
                messages=[{"role": "user", "content": "hello"}],
                max_new_tokens=11,
                temperature=0.2,
                top_p=top_p,
                seed=seed,
                request_timeout=1.0,
            )

        _serve({"choices": [{"message": {"content": [{"text": " served "}]}}]})
        assert _request(model_id="demo-model", top_p=0.8, seed=5) == "served"

        malformed_cases = (
            ({"choices": []}, r"response missing choices"),
            ({"choices": [{}]}, r"response missing choices\[0\]\.message"),
            (
                {"choices": [{"message": {"content": [{"text": " "}]}}]},
                r"missing non-empty message content",
            ),
        )
        for payload, pattern in malformed_cases:
            _serve(payload)
            with pytest.raises(TeacherInvocationError, match=pattern):
                _request(model_id=None, top_p=None, seed=None)

        def _raise_url_error(*_args: object, **_kwargs: object) -> object:
            raise urllib.error.URLError("boom")

        monkeypatch.setattr(urlopen_target, _raise_url_error)
        # ``match`` is a regex: escape the dots so only the literal address matches.
        with pytest.raises(
            TeacherInvocationError,
            match=r"request to http://127\.0\.0\.1:8000 failed",
        ):
            _request(model_id=None, top_p=None, seed=None)
740
741
def _install_self_loader_modules(
    monkeypatch: pytest.MonkeyPatch,
    *,
    manifest_exists: bool = True,
    license_acceptance: object | None = "accepted",
    load_manifest_error: str | None = None,
    resolve_error: str | None = None,
    select_error: str | None = None,
    backend_load_error: str | None = None,
) -> dict[str, object]:
    """Plant stub ``dlm.*`` modules in ``sys.modules`` and return a call-recording dict.

    Each ``*_error`` argument, when not ``None``, makes the corresponding stub
    raise its stub exception with that message: ``load_manifest`` raises
    ``ManifestCorruptError``, ``resolve`` raises ``GatedModelError``,
    ``select_backend`` raises ``UnsupportedBackendError`` and the backend's
    ``load`` raises ``AdapterNotFoundError``. ``manifest_exists`` is what the
    stub store's ``manifest.exists()`` returns, and ``license_acceptance`` is
    the attribute carried by the stub manifest that ``load_manifest`` returns.

    The returned dict starts with the sentinels ``"caps"``, ``"store"`` and
    ``"spec"`` plus an ``"errors"`` mapping of the stub exception classes; the
    stubs add ``"resolve"``, ``"load_manifest"``, ``"select_backend"``,
    ``"build_backend"`` and ``"load"`` entries when they are invoked.
    """

    class GatedModelError(Exception):
        pass

    class AdapterNotFoundError(Exception):
        pass

    class UnsupportedBackendError(Exception):
        pass

    class ManifestCorruptError(Exception):
        pass

    # Opaque sentinels: the tests only check that they are passed along unchanged.
    spec = object()
    caps = object()
    frontmatter = SimpleNamespace(
        dlm_id="01KPQ9X1000000000000000000",
        base_model="smollm2-135m",
    )
    parsed = SimpleNamespace(frontmatter=frontmatter)
    store = SimpleNamespace(manifest=SimpleNamespace(exists=lambda: manifest_exists))

    recorded: dict[str, object] = {
        "caps": caps,
        "store": store,
        "spec": spec,
        "errors": {
            "gated": GatedModelError,
            "adapter": AdapterNotFoundError,
            "unsupported": UnsupportedBackendError,
            "manifest": ManifestCorruptError,
        },
    }

    class _Backend:
        def load(self, spec_arg: object, store_arg: object) -> None:
            recorded["load"] = (spec_arg, store_arg)
            if backend_load_error is None:
                return
            raise AdapterNotFoundError(backend_load_error)

    backend = _Backend()

    def _resolve(base_model: str, *, accept_license: bool) -> object:
        recorded["resolve"] = (base_model, accept_license)
        if resolve_error is None:
            return spec
        raise GatedModelError(resolve_error)

    def _load_manifest(_path: object) -> object:
        recorded["load_manifest"] = True
        if load_manifest_error is None:
            return SimpleNamespace(license_acceptance=license_acceptance)
        raise ManifestCorruptError(load_manifest_error)

    def _select_backend(backend_name: str, capabilities: object) -> str:
        recorded["select_backend"] = (backend_name, capabilities)
        if select_error is None:
            return "stub-backend"
        raise UnsupportedBackendError(select_error)

    def _build_backend(name: str, capabilities: object) -> object:
        recorded["build_backend"] = (name, capabilities)
        return backend

    # Module name -> attributes of its stub; installed in this order.
    stub_modules: dict[str, dict[str, object]] = {
        "dlm.base_models": {"resolve": _resolve},
        "dlm.base_models.errors": {"GatedModelError": GatedModelError},
        "dlm.doc.parser": {"parse_file": lambda _path: parsed},
        "dlm.hardware": {"doctor": lambda: SimpleNamespace(capabilities=caps)},
        "dlm.inference": {"AdapterNotFoundError": AdapterNotFoundError},
        "dlm.inference.backends": {
            "build_backend": _build_backend,
            "select_backend": _select_backend,
        },
        "dlm.inference.backends.select": {"UnsupportedBackendError": UnsupportedBackendError},
        "dlm.store.errors": {"ManifestCorruptError": ManifestCorruptError},
        "dlm.store.manifest": {"load_manifest": _load_manifest},
        "dlm.store.paths": {"for_dlm": lambda _dlm_id: store},
    }
    for module_name, module_attrs in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, _module(module_name, **module_attrs))

    return recorded
867
868
class TestTeacherLoaderHelpers:
    """Default loader and generation helpers: dependency wiring and error wrapping."""

    def test_load_self_backend_wraps_import_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An ImportError for ``dlm.base_models`` surfaces as TeacherUnavailableError."""
        original_import = builtins.__import__

        def _raise_on_base_models(
            name: str,
            globals: dict[str, object] | None = None,
            locals: dict[str, object] | None = None,
            fromlist: tuple[str, ...] = (),
            level: int = 0,
        ) -> object:
            if not name.startswith("dlm.base_models"):
                return original_import(name, globals, locals, fromlist, level)
            raise ImportError("boom")

        monkeypatch.setattr(builtins, "__import__", _raise_on_base_models)

        with pytest.raises(TeacherUnavailableError, match="requires the local inference stack"):
            teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")

    def test_load_self_backend_uses_recorded_license_acceptance(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """With a recorded acceptance, the base is resolved with ``accept_license=True``."""
        calls = _install_self_loader_modules(monkeypatch, license_acceptance="accepted")

        backend = teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")

        assert backend is not None
        caps = calls["caps"]
        assert calls["resolve"] == ("smollm2-135m", True)
        assert calls["select_backend"] == ("auto", caps)
        assert calls["build_backend"] == ("stub-backend", caps)
        assert calls["load"] == (calls["spec"], calls["store"])

    def test_load_self_backend_tolerates_manifest_read_failure(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A manifest read error is not fatal; the base is resolved with ``accept_license=False``."""
        calls = _install_self_loader_modules(monkeypatch, load_manifest_error="bad manifest")

        teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")

        assert calls["resolve"] == ("smollm2-135m", False)

    def test_load_self_backend_wraps_gated_backend_and_adapter_failures(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Gated-base, unsupported-backend and missing-adapter errors are all wrapped."""
        scenarios: list[tuple[dict[str, str], str]] = [
            ({"resolve_error": "gated"}, "cannot resolve gated base"),
            ({"select_error": "unsupported backend"}, "unsupported backend"),
            ({"backend_load_error": "missing adapter"}, "requires a trained adapter"),
        ]
        for stub_options, expected_message in scenarios:
            # Each scenario re-installs the stubs with exactly one failure switched on.
            _install_self_loader_modules(monkeypatch, **stub_options)
            with pytest.raises(TeacherUnavailableError, match=expected_message):
                teachers_mod._load_self_backend(Path("/tmp/doc.dlm"), "auto")

    def test_default_hf_loader_wraps_import_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An ImportError for ``transformers`` surfaces as TeacherUnavailableError."""
        original_import = builtins.__import__

        def _raise_transformers(
            name: str,
            globals: dict[str, object] | None = None,
            locals: dict[str, object] | None = None,
            fromlist: tuple[str, ...] = (),
            level: int = 0,
        ) -> object:
            if name != "transformers":
                return original_import(name, globals, locals, fromlist, level)
            raise ImportError("boom")

        monkeypatch.setattr(builtins, "__import__", _raise_transformers)

        with pytest.raises(TeacherUnavailableError, match="requires transformers"):
            teachers_mod._default_hf_loader("hf/model", "cpu")

    def test_default_hf_loader_moves_model_and_sets_eval(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The model is moved to the device and put in eval mode; the tokenizer uses the same id."""
        observed: dict[str, object] = {}

        class _Model:
            def to(self, device: str) -> _Model:
                observed["device"] = device
                return self

            def eval(self) -> None:
                observed["eval"] = True

        model = _Model()

        class AutoModelForCausalLM:
            @staticmethod
            def from_pretrained(hf_id: str) -> _Model:
                observed["model_id"] = hf_id
                return model

        class AutoTokenizer:
            @staticmethod
            def from_pretrained(hf_id: str) -> str:
                observed["tokenizer_id"] = hf_id
                return "tok"

        fake_transformers = _module(
            "transformers",
            AutoModelForCausalLM=AutoModelForCausalLM,
            AutoTokenizer=AutoTokenizer,
        )
        monkeypatch.setitem(sys.modules, "transformers", fake_transformers)

        loaded = teachers_mod._default_hf_loader("hf/model", "cuda")

        assert loaded.model is model
        assert loaded.tokenizer == "tok"
        assert loaded.device == "cuda"
        assert observed == {
            "model_id": "hf/model",
            "tokenizer_id": "hf/model",
            "device": "cuda",
            "eval": True,
        }

    def test_default_hf_generate_seeds_torch_and_calls_runner(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """The seed goes to both torch seeding hooks; the rest is forwarded to ``generate``."""
        cpu_seeds: list[int] = []
        cuda_seeds: list[int] = []
        recorded: dict[str, object] = {}

        def _generate(
            model: object,
            tokenizer: object,
            prompt: str,
            *,
            max_new_tokens: int,
            temperature: float,
            top_p: float | None,
        ) -> str:
            recorded["args"] = (model, tokenizer, prompt, max_new_tokens, temperature, top_p)
            return "ok"

        fake_generate_module = _module("dlm.inference.generate", generate=_generate)
        fake_cuda = SimpleNamespace(
            is_available=lambda: True,
            manual_seed_all=lambda seed: cuda_seeds.append(seed),
        )
        fake_torch = SimpleNamespace(
            manual_seed=lambda seed: cpu_seeds.append(seed),
            cuda=fake_cuda,
        )
        monkeypatch.setitem(sys.modules, "dlm.inference.generate", fake_generate_module)
        monkeypatch.setitem(sys.modules, "torch", fake_torch)

        result = teachers_mod._default_hf_generate(
            "model",
            "tokenizer",
            "prompt",
            max_new_tokens=17,
            temperature=0.3,
            top_p=0.8,
            seed=7,
        )

        assert result == "ok"
        assert cpu_seeds == [7]
        assert cuda_seeds == [7]
        assert recorded["args"] == ("model", "tokenizer", "prompt", 17, 0.3, 0.8)

    def test_default_hf_generate_tolerates_missing_torch_when_seeding(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Generation still succeeds when torch cannot be imported for seeding."""
        original_import = builtins.__import__

        def _generate(
            model: object,
            tokenizer: object,
            prompt: str,
            *,
            max_new_tokens: int,
            temperature: float,
            top_p: float | None,
        ) -> str:
            _ = model, tokenizer, prompt, max_new_tokens, temperature, top_p
            return "ok"

        def _raise_torch(
            name: str,
            globals: dict[str, object] | None = None,
            locals: dict[str, object] | None = None,
            fromlist: tuple[str, ...] = (),
            level: int = 0,
        ) -> object:
            if name != "torch":
                return original_import(name, globals, locals, fromlist, level)
            raise ImportError("no torch")

        fake_generate_module = _module("dlm.inference.generate", generate=_generate)
        monkeypatch.setitem(sys.modules, "dlm.inference.generate", fake_generate_module)
        # Drop any cached torch so the import statement reaches the patched hook.
        monkeypatch.delitem(sys.modules, "torch", raising=False)
        monkeypatch.setattr(builtins, "__import__", _raise_torch)

        result = teachers_mod._default_hf_generate(
            "model",
            "tokenizer",
            "prompt",
            max_new_tokens=17,
            temperature=0.3,
            top_p=0.8,
            seed=7,
        )

        assert result == "ok"