tenseleyflow/documentlanguagemodel / 61af0d3

Browse files

Prove external synth teachers

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
61af0d334f2b67cb28a90cc09c09e3892fb79abf
Parents
1556d38
Tree
12cf152

2 changed files

Status | File | + | -
A tests/integration/synth/test_vllm_teacher.py 151 0
M tests/unit/cli/test_synth_cmd.py 58 0
tests/integration/synth/test_vllm_teacher.py (added)
@@ -0,0 +1,151 @@
1
+"""Integration proof for the OpenAI-compatible `vllm-server` synth teacher."""
2
+
3
+from __future__ import annotations
4
+
5
+import json
6
+import threading
7
+from collections.abc import Iterator
8
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
9
+from pathlib import Path
10
+
11
+import pytest
12
+from typer.testing import CliRunner
13
+
14
+from dlm.cli.app import app
15
+from dlm.doc.parser import parse_file
16
+from dlm.doc.sections import SectionType
17
+
18
_DLM_ID = "01KSYNTHVLLM000000000000000"


def _write_doc(path: Path) -> None:
    """Create a minimal .dlm document (front matter + one body line) at *path*."""
    front_matter = [
        "---",
        f"dlm_id: {_DLM_ID}",
        "dlm_version: 15",
        "base_model: smollm2-135m",
        "---",
        "DGEMM multiplies two dense matrices and can accumulate the result.",
    ]
    # Trailing "\n" keeps the file newline-terminated, matching parser expectations.
    path.write_text("\n".join(front_matter) + "\n", encoding="utf-8")
31
+
32
+
33
class _CompatHandler(BaseHTTPRequestHandler):
    """Stub OpenAI-compatible endpoint: model listing plus canned chat completions."""

    # Shared, class-level request log of (path, parsed JSON body or None).
    # Reset by the `compat_server` fixture before each test so assertions
    # only see the current test's traffic.
    requests_seen: list[tuple[str, dict[str, object] | None]] = []

    def log_message(self, format: str, *args: object) -> None:  # noqa: A003
        """Silence BaseHTTPRequestHandler's default stderr access logging."""
        _ = format, args
        return

    def do_GET(self) -> None:  # noqa: N802
        """Serve the model-discovery probe; any other GET path is a 404."""
        if self.path != "/v1/models":
            self.send_error(404)
            return
        # Single advertised model id; the client is expected to adopt it
        # as the `model` field of subsequent chat requests.
        self._write_json(200, {"data": [{"id": "stub-vllm-teacher"}]})
        self.requests_seen.append((self.path, None))

    def do_POST(self) -> None:  # noqa: N802
        """Serve chat completions with a fixed one-item Q/A JSON payload."""
        if self.path != "/v1/chat/completions":
            self.send_error(404)
            return
        # Content-Length defaults to "0" so an empty body reads as b"".
        raw = self.rfile.read(int(self.headers.get("Content-Length", "0")))
        payload = json.loads(raw.decode("utf-8"))
        # Record before responding so the request is logged even if the
        # client disconnects mid-response.
        self.requests_seen.append((self.path, payload))
        self._write_json(
            200,
            {
                "choices": [
                    {
                        "message": {
                            # The assistant content is itself a JSON array string,
                            # mimicking an extraction-style teacher reply.
                            "content": (
                                '[{"question":"What does DGEMM do?",'
                                '"answer":"It multiplies dense matrices."}]'
                            )
                        }
                    }
                ]
            },
        )

    def _write_json(self, status: int, payload: object) -> None:
        """Serialize *payload* as JSON and send it with *status* and correct headers."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        # Explicit Content-Length lets keep-alive clients know the body end.
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
77
+
78
+
79
@pytest.fixture
def compat_server() -> Iterator[str]:
    """Run the stub OpenAI-compatible server on an ephemeral loopback port.

    Yields the base URL (http://host:port) and guarantees the server thread
    is shut down and the socket closed afterwards.
    """
    try:
        httpd = ThreadingHTTPServer(("127.0.0.1", 0), _CompatHandler)
    except PermissionError as exc:
        pytest.skip(f"loopback bind blocked on this host: {exc}")
    # Clear the shared request log so assertions only see this test's traffic.
    _CompatHandler.requests_seen = []
    worker = threading.Thread(target=httpd.serve_forever, daemon=True)
    worker.start()
    try:
        bound = httpd.server_address
        yield f"http://{str(bound[0])}:{int(bound[1])}"
    finally:
        httpd.shutdown()
        worker.join(timeout=5)
        httpd.server_close()
97
+
98
+
99
def test_synth_instructions_vllm_server_teacher_applies_sections(
    tmp_path: Path,
    compat_server: str,
) -> None:
    """End-to-end: `synth instructions --apply` against a vllm-server teacher.

    Verifies the CLI hits both compat endpoints, adopts the advertised model,
    and writes exactly one auto-synth instruction section back into the doc.
    """
    dlm_home = tmp_path / "home"
    doc_path = tmp_path / "doc.dlm"
    _write_doc(doc_path)

    teacher_spec = f"vllm-server:{compat_server}"
    outcome = CliRunner().invoke(
        app,
        [
            "--home",
            str(dlm_home),
            "synth",
            "instructions",
            str(doc_path),
            "--teacher",
            teacher_spec,
            "--filter",
            "dedup-only",
            "--per-section",
            "1",
            "--apply",
        ],
    )

    assert outcome.exit_code == 0, outcome.output

    # The applied document must contain exactly one auto-synthesized
    # instruction section attributed to this teacher.
    parsed = parse_file(doc_path)
    applied = [
        sec
        for sec in parsed.sections
        if sec.type is SectionType.INSTRUCTION and sec.auto_synth
    ]
    assert len(applied) == 1
    assert applied[0].synth_teacher == teacher_spec
    assert applied[0].synth_strategy == "extraction"

    # Both compat endpoints must have been exercised.
    seen_paths = [seen_path for seen_path, _body in _CompatHandler.requests_seen]
    assert "/v1/models" in seen_paths
    assert "/v1/chat/completions" in seen_paths

    # The chat request must use the advertised model id and the
    # system-then-user message layout.
    chat_payload = next(
        body
        for seen_path, body in _CompatHandler.requests_seen
        if seen_path == "/v1/chat/completions" and body is not None
    )
    assert isinstance(chat_payload, dict)
    assert chat_payload["model"] == "stub-vllm-teacher"
    messages = chat_payload["messages"]
    assert isinstance(messages, list)
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
tests/unit/cli/test_synth_cmd.py (modified)
@@ -370,3 +370,61 @@ class TestSynthCmd:
370370
         assert pending is not None
371371
         assert len(pending.sections) == 1
372372
         assert pending.sections[0].auto_mined is True
373
+
374
+    def test_openai_teacher_without_api_key_fails_cleanly(
375
+        self,
376
+        tmp_path: Path,
377
+        monkeypatch: pytest.MonkeyPatch,
378
+    ) -> None:
379
+        home = tmp_path / "home"
380
+        doc = tmp_path / "doc.dlm"
381
+        _write_synth_doc(doc)
382
+        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
383
+
384
+        runner = CliRunner()
385
+        result = runner.invoke(
386
+            app,
387
+            [
388
+                "--home",
389
+                str(home),
390
+                "synth",
391
+                "instructions",
392
+                str(doc),
393
+                "--teacher",
394
+                "openai:gpt-4o-mini",
395
+                "--filter",
396
+                "none",
397
+            ],
398
+        )
399
+
400
+        assert result.exit_code == 1, result.output
401
+        assert "requires $OPENAI_API_KEY to be set" in _normalized_output(result)
402
+
403
+    def test_anthropic_teacher_without_api_key_fails_cleanly(
404
+        self,
405
+        tmp_path: Path,
406
+        monkeypatch: pytest.MonkeyPatch,
407
+    ) -> None:
408
+        home = tmp_path / "home"
409
+        doc = tmp_path / "doc.dlm"
410
+        _write_synth_doc(doc)
411
+        monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
412
+
413
+        runner = CliRunner()
414
+        result = runner.invoke(
415
+            app,
416
+            [
417
+                "--home",
418
+                str(home),
419
+                "synth",
420
+                "instructions",
421
+                str(doc),
422
+                "--teacher",
423
+                "anthropic:claude-3-5-haiku-latest",
424
+                "--filter",
425
+                "none",
426
+            ],
427
+        )
428
+
429
+        assert result.exit_code == 1, result.output
430
+        assert "requires $ANTHROPIC_API_KEY to be set" in _normalized_output(result)