Python · 9288 bytes Raw Blame History
1 """Unit tests for the probe-RPC server.
2
3 Exercises the handler directly via `urllib` against a real bound port
4 rather than mocking `BaseHTTPRequestHandler` internals — the surface is
5 narrow enough that actual-socket tests stay fast and have no flake
6 potential beyond "picked a used port," mitigated by letting the OS
7 assign (`port=0`).
8 """
9
10 from __future__ import annotations
11
12 import json
13 import socket
14 import urllib.error
15 import urllib.request
16 from collections.abc import Iterator
17 from typing import Any
18
19 import pytest
20
21 from dlm.train.inject import InjectedProbeQueue
22 from dlm.train.rpc import ProbeRpcServer, _check_bearer
23
24 _TOKEN = "test-token-123"
25
26
@pytest.fixture
def server() -> Iterator[ProbeRpcServer]:
    """Yield a started probe-RPC server on an OS-assigned loopback port.

    The server is backed by a 4-slot probe queue and reports a fixed
    ``next_cycle_eta_s`` of 42. It is always stopped on teardown. The test is
    skipped, not failed, when the host refuses to bind a loopback socket.
    """
    probe_queue = InjectedProbeQueue(capacity=4)
    try:
        rpc = ProbeRpcServer(
            host="127.0.0.1",
            port=0,  # let the OS pick a free port
            token=_TOKEN,
            queue=probe_queue,
            next_cycle_eta_s=lambda: 42,
        )
    except PermissionError as exc:
        pytest.skip(f"loopback bind blocked on this host: {exc}")
    rpc.start()
    try:
        yield rpc
    finally:
        rpc.stop()
41
42
def _post(
    server: ProbeRpcServer,
    *,
    body: dict[str, Any] | str,
    token: str | None = _TOKEN,
    path: str = "/rpc",
) -> tuple[int, dict[str, Any]]:
    """POST ``body`` to ``server`` via urllib and return ``(status, json_body)``.

    A ``str`` body is sent verbatim, so tests can submit malformed JSON. Any
    other body is JSON-encoded first. ``token=None`` omits the
    ``Authorization`` header entirely.

    urllib raises ``HTTPError`` for error statuses. This helper converts it
    into a normal ``(code, json_body)`` return so tests can assert on 4xx
    replies directly. The response is closed on both paths so each call
    releases its connection.
    """
    host, port = server.address
    url = f"http://{host}:{port}{path}"
    raw = body if isinstance(body, str) else json.dumps(body)
    headers = {"Content-Type": "application/json"}
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    req = urllib.request.Request(url, data=raw.encode("utf-8"), headers=headers, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=5.0) as resp:  # noqa: S310
            return resp.status, json.loads(resp.read())
    except urllib.error.HTTPError as exc:
        # HTTPError wraps the still-open response; close it once the body is read.
        try:
            return exc.code, json.loads(exc.read())
        finally:
            exc.close()
62
63
64 def _raw_post(
65 server: ProbeRpcServer,
66 *,
67 headers: dict[str, str],
68 body: bytes = b"",
69 path: str = "/rpc",
70 ) -> tuple[int, dict[str, Any]]:
71 host, port = server.address
72 lines = [
73 f"POST {path} HTTP/1.1",
74 f"Host: {host}:{port}",
75 *[f"{key}: {value}" for key, value in headers.items()],
76 "",
77 "",
78 ]
79 request = "\r\n".join(lines).encode("utf-8") + body
80 with socket.create_connection((host, port), timeout=5.0) as sock:
81 sock.sendall(request)
82 response = b""
83 while b"\r\n\r\n" not in response:
84 response += sock.recv(4096)
85 head, rest = response.split(b"\r\n\r\n", 1)
86 header_lines = head.decode("iso-8859-1").split("\r\n")
87 status = int(header_lines[0].split()[1])
88 parsed_headers: dict[str, str] = {}
89 for line in header_lines[1:]:
90 if ":" not in line:
91 continue
92 key, value = line.split(":", 1)
93 parsed_headers[key.lower()] = value.strip()
94 content_length = int(parsed_headers.get("content-length", "0"))
95 while len(rest) < content_length:
96 rest += sock.recv(4096)
97 return status, json.loads(rest[:content_length].decode("utf-8"))
98
99
class TestHappyPath:
    """A well-formed, authenticated ``inject_probe`` call is accepted and queued."""

    def test_inject_probe_accepted(self, server: ProbeRpcServer) -> None:
        request = {
            "method": "inject_probe",
            "params": {"prompt": "what is X?", "reference": "Y.", "tags": ["sway"]},
        }
        status, reply = _post(server, body=request)

        assert status == 200
        # 42 comes from the fixture's next_cycle_eta_s; depth counts this probe.
        assert reply == {"accepted": True, "next_cycle_eta_s": 42, "queue_depth": 1}

        queued = server.queue.drain()
        assert len(queued) == 1
        probe = queued[0]
        assert probe.prompt == "what is X?"
        # The JSON list of tags is stored as a tuple.
        assert probe.tags == ("sway",)
115
116
class TestAuth:
    """Requests without a valid bearer token get 401 and never reach the queue."""

    def test_missing_token_401(self, server: ProbeRpcServer) -> None:
        status, body = _post(
            server,
            body={"method": "inject_probe", "params": {"prompt": "q", "reference": "a"}},
            token=None,
        )
        assert status == 401
        assert "bearer" in body["error"].lower()
        # A rejected request must not enqueue a probe.
        assert len(server.queue.drain()) == 0

    def test_wrong_token_401(self, server: ProbeRpcServer) -> None:
        status, body = _post(
            server,
            body={"method": "inject_probe", "params": {"prompt": "q", "reference": "a"}},
            token="wrong-token",
        )
        assert status == 401
        # Exact wording is not pinned. The rejection must still be reported
        # through the same JSON `error` field as the missing-token case.
        assert body["error"]
        # A rejected request must not enqueue a probe.
        assert len(server.queue.drain()) == 0
134
135
class TestMalformedPayload:
    """Structurally invalid requests are rejected with 400 and a descriptive error."""

    @staticmethod
    def _authed_headers(content_length: str) -> dict[str, str]:
        # Raw-socket headers with valid auth but a caller-chosen Content-Length.
        return {
            "Authorization": f"Bearer {_TOKEN}",
            "Content-Type": "application/json",
            "Content-Length": content_length,
        }

    @staticmethod
    def _inject(params: object) -> dict[str, Any]:
        # Wrap `params` in an `inject_probe` call envelope.
        return {"method": "inject_probe", "params": params}

    def test_bad_json_400(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body="not json {")
        assert status == 400
        assert "malformed" in reply["error"].lower()

    def test_invalid_content_length_400(self, server: ProbeRpcServer) -> None:
        status, reply = _raw_post(server, headers=self._authed_headers("nope"))
        assert status == 400
        assert "content-length" in reply["error"].lower()

    def test_empty_body_400(self, server: ProbeRpcServer) -> None:
        status, reply = _raw_post(server, headers=self._authed_headers("0"))
        assert status == 400
        assert "empty body" in reply["error"].lower()

    def test_oversized_body_400(self, server: ProbeRpcServer) -> None:
        # 70 KiB is declared but no body bytes follow, so the server has to
        # reject it from the header alone.
        status, reply = _raw_post(server, headers=self._authed_headers(str(70 * 1024)))
        assert status == 400
        assert "exceeds" in reply["error"].lower()

    def test_payload_must_be_object(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body="[]")
        assert status == 400
        assert "json object" in reply["error"].lower()

    def test_missing_prompt_400(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body=self._inject({"reference": "a"}))
        assert status == 400
        assert "prompt" in reply["error"].lower()

    def test_params_must_be_object(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body=self._inject("bad"))
        assert status == 400
        assert "`params`" in reply["error"]

    def test_empty_reference_400(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body=self._inject({"prompt": "q", "reference": " "}))
        assert status == 400
        assert "reference" in reply["error"].lower()

    def test_non_string_tags_400(self, server: ProbeRpcServer) -> None:
        status, reply = _post(
            server,
            body=self._inject({"prompt": "q", "reference": "a", "tags": [1, 2]}),
        )
        assert status == 400
        assert "tags" in reply["error"].lower()
211
212
class TestMethodDispatch:
    """Unknown RPC method names and unknown URL paths both map to 404."""

    def test_unknown_method_404(self, server: ProbeRpcServer) -> None:
        status, reply = _post(server, body={"method": "explode", "params": {}})
        assert status == 404
        # The error should name the method that was not found.
        assert "explode" in reply["error"]

    def test_unknown_path_404(self, server: ProbeRpcServer) -> None:
        status, _reply = _post(server, body={"method": "inject_probe"}, path="/other")
        assert status == 404
222
223
class TestCapacity:
    """Once the bounded queue (capacity 4 in the fixture) is full, injections get 429."""

    def test_full_queue_429(self, server: ProbeRpcServer) -> None:
        payload = {"method": "inject_probe", "params": {"prompt": "q", "reference": "a"}}

        # Fill every slot; each of these must still be accepted.
        for attempt in range(4):
            accepted_status, _ = _post(server, body=payload)
            assert accepted_status == 200, attempt

        overflow_status, overflow_reply = _post(server, body=payload)
        assert overflow_status == 429
        assert overflow_reply["queue_depth"] == 4
236
237
class TestAuthHelper:
    """Direct checks of ``_check_bearer(authorization_header, expected_token)``."""

    _expected = "abc"

    def test_correct_token_matches(self) -> None:
        assert _check_bearer("Bearer abc", self._expected) is True

    def test_wrong_token_fails(self) -> None:
        assert _check_bearer("Bearer xyz", self._expected) is False

    def test_missing_prefix_fails(self) -> None:
        # The bare token without the "Bearer " scheme must not be accepted.
        assert _check_bearer("abc", self._expected) is False

    def test_empty_presented_fails(self) -> None:
        assert _check_bearer("Bearer ", self._expected) is False

    def test_length_mismatch_fails(self) -> None:
        # The presented token has the expected one as a prefix but is longer.
        # NOTE(review): presumably guards the compare path in rpc.py — confirm.
        assert _check_bearer("Bearer abcd", self._expected) is False
253
254
class TestConstruction:
    """Constructor validation and ``start()`` lifecycle guards."""

    def test_empty_token_rejected(self) -> None:
        with pytest.raises(ValueError, match="bearer token"):
            ProbeRpcServer(host="127.0.0.1", port=0, token="", queue=InjectedProbeQueue())

    def test_start_twice_rejected(self) -> None:
        try:
            rpc = ProbeRpcServer(
                host="127.0.0.1",
                port=0,
                token=_TOKEN,
                queue=InjectedProbeQueue(),
            )
        except PermissionError as exc:
            pytest.skip(f"loopback bind blocked on this host: {exc}")

        rpc.start()
        try:
            # A second start() on a running server must be refused.
            with pytest.raises(RuntimeError, match="already started"):
                rpc.start()
        finally:
            rpc.stop()