| 1 |
"""Unit tests for the probe-RPC server. |
| 2 |
|
| 3 |
Exercises the handler directly via `urllib` against a real bound port |
| 4 |
rather than mocking `BaseHTTPRequestHandler` internals — the surface is |
| 5 |
narrow enough that actual-socket tests stay fast and have no flake |
| 6 |
potential beyond "picked a used port," mitigated by letting the OS |
| 7 |
assign (`port=0`). |
| 8 |
""" |
| 9 |
|
| 10 |
from __future__ import annotations |
| 11 |
|
| 12 |
import json |
| 13 |
import socket |
| 14 |
import urllib.error |
| 15 |
import urllib.request |
| 16 |
from collections.abc import Iterator |
| 17 |
from typing import Any |
| 18 |
|
| 19 |
import pytest |
| 20 |
|
| 21 |
from dlm.train.inject import InjectedProbeQueue |
| 22 |
from dlm.train.rpc import ProbeRpcServer, _check_bearer |
| 23 |
|
| 24 |
# Bearer token the `server` fixture is configured with. It is also `_post`'s
# default token and is embedded in hand-built raw requests. Auth-failure tests
# deliberately pass `token=None` or a different string instead.
_TOKEN = "test-token-123"
| 25 |
|
| 26 |
|
| 27 |
@pytest.fixture
def server() -> Iterator[ProbeRpcServer]:
    """Yield a started ``ProbeRpcServer`` bound to an OS-chosen loopback port.

    The server is backed by a 4-slot ``InjectedProbeQueue`` and reports a fixed
    ``next_cycle_eta_s`` of 42. If constructing the server raises
    ``PermissionError`` (loopback bind blocked), the test is skipped rather than
    failed. Once started, the server is always stopped on teardown.
    """
    probe_queue = InjectedProbeQueue(capacity=4)
    try:
        rpc = ProbeRpcServer(
            host="127.0.0.1",
            port=0,  # let the OS pick a free port to avoid collisions
            token=_TOKEN,
            queue=probe_queue,
            next_cycle_eta_s=lambda: 42,
        )
    except PermissionError as err:
        pytest.skip(f"loopback bind blocked on this host: {err}")
    rpc.start()
    try:
        yield rpc
    finally:
        rpc.stop()
| 41 |
|
| 42 |
|
| 43 |
def _post(
    server: ProbeRpcServer,
    *,
    body: dict[str, Any] | str,
    token: str | None = _TOKEN,
    path: str = "/rpc",
) -> tuple[int, dict[str, Any]]:
    """POST a JSON request to ``server`` through ``urllib`` and parse the reply.

    Args:
        server: The server under test; only its ``address`` (host, port) is read.
        body: A dict, which is JSON-encoded here, or a raw string sent verbatim.
            The string form lets tests submit deliberately malformed JSON.
        token: Value for the ``Authorization: Bearer`` header. ``None`` omits
            the header entirely.
        path: Request path; defaults to the RPC endpoint.

    Returns:
        ``(status, parsed_json_body)`` for success and HTTP-error responses
        alike. ``urlopen`` raises ``HTTPError`` for 4xx/5xx, so that branch is
        unwrapped into a return value instead of propagating.
    """
    host, port = server.address
    url = f"http://{host}:{port}{path}"
    raw = body if isinstance(body, str) else json.dumps(body)
    headers = {"Content-Type": "application/json"}
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    req = urllib.request.Request(url, data=raw.encode("utf-8"), headers=headers, method="POST")
    try:
        # The context manager closes the response's socket after reading, so
        # no connection is leaked (avoids ResourceWarning under strict runs).
        with urllib.request.urlopen(req, timeout=5.0) as resp:  # noqa: S310
            return resp.status, json.loads(resp.read())
    except urllib.error.HTTPError as exc:
        # HTTPError also wraps an open response body; close it once consumed.
        try:
            return exc.code, json.loads(exc.read())
        finally:
            exc.close()
| 62 |
|
| 63 |
|
| 64 |
def _raw_post( |
| 65 |
server: ProbeRpcServer, |
| 66 |
*, |
| 67 |
headers: dict[str, str], |
| 68 |
body: bytes = b"", |
| 69 |
path: str = "/rpc", |
| 70 |
) -> tuple[int, dict[str, Any]]: |
| 71 |
host, port = server.address |
| 72 |
lines = [ |
| 73 |
f"POST {path} HTTP/1.1", |
| 74 |
f"Host: {host}:{port}", |
| 75 |
*[f"{key}: {value}" for key, value in headers.items()], |
| 76 |
"", |
| 77 |
"", |
| 78 |
] |
| 79 |
request = "\r\n".join(lines).encode("utf-8") + body |
| 80 |
with socket.create_connection((host, port), timeout=5.0) as sock: |
| 81 |
sock.sendall(request) |
| 82 |
response = b"" |
| 83 |
while b"\r\n\r\n" not in response: |
| 84 |
response += sock.recv(4096) |
| 85 |
head, rest = response.split(b"\r\n\r\n", 1) |
| 86 |
header_lines = head.decode("iso-8859-1").split("\r\n") |
| 87 |
status = int(header_lines[0].split()[1]) |
| 88 |
parsed_headers: dict[str, str] = {} |
| 89 |
for line in header_lines[1:]: |
| 90 |
if ":" not in line: |
| 91 |
continue |
| 92 |
key, value = line.split(":", 1) |
| 93 |
parsed_headers[key.lower()] = value.strip() |
| 94 |
content_length = int(parsed_headers.get("content-length", "0")) |
| 95 |
while len(rest) < content_length: |
| 96 |
rest += sock.recv(4096) |
| 97 |
return status, json.loads(rest[:content_length].decode("utf-8")) |
| 98 |
|
| 99 |
|
| 100 |
class TestHappyPath:
    """A well-formed, authenticated ``inject_probe`` call is accepted and queued."""

    def test_inject_probe_accepted(self, server: ProbeRpcServer) -> None:
        request = {
            "method": "inject_probe",
            "params": {"prompt": "what is X?", "reference": "Y.", "tags": ["sway"]},
        }
        status, body = _post(server, body=request)

        assert status == 200
        assert body == {"accepted": True, "next_cycle_eta_s": 42, "queue_depth": 1}

        # The probe must actually land in the queue. Tags sent as a JSON list
        # come back as a tuple.
        drained = server.queue.drain()
        assert len(drained) == 1
        probe = drained[0]
        assert probe.prompt == "what is X?"
        assert probe.tags == ("sway",)
| 115 |
|
| 116 |
|
| 117 |
class TestAuth:
    """Requests without the configured bearer token are rejected with 401."""

    @staticmethod
    def _valid_payload() -> dict[str, Any]:
        # A request that would be accepted if it were authenticated, so any
        # rejection is attributable to auth alone.
        return {"method": "inject_probe", "params": {"prompt": "q", "reference": "a"}}

    def test_missing_token_401(self, server: ProbeRpcServer) -> None:
        status, body = _post(server, body=self._valid_payload(), token=None)
        assert status == 401
        assert "bearer" in body["error"].lower()

    def test_wrong_token_401(self, server: ProbeRpcServer) -> None:
        status, _body = _post(server, body=self._valid_payload(), token="wrong-token")
        assert status == 401
| 134 |
|
| 135 |
|
| 136 |
class TestMalformedPayload:
    """Malformed or invalid requests are answered with 400 and a telling error."""

    @staticmethod
    def _authed_headers(content_length: str) -> dict[str, str]:
        """Return valid auth/JSON headers around an arbitrary Content-Length value."""
        return {
            "Authorization": f"Bearer {_TOKEN}",
            "Content-Type": "application/json",
            "Content-Length": content_length,
        }

    def test_bad_json_400(self, server: ProbeRpcServer) -> None:
        status, body = _post(server, body="not json {")
        assert status == 400
        assert "malformed" in body["error"].lower()

    def test_invalid_content_length_400(self, server: ProbeRpcServer) -> None:
        # A non-numeric length is hand-written over a raw socket.
        status, body = _raw_post(server, headers=self._authed_headers("nope"))
        assert status == 400
        assert "content-length" in body["error"].lower()

    def test_empty_body_400(self, server: ProbeRpcServer) -> None:
        status, body = _raw_post(server, headers=self._authed_headers("0"))
        assert status == 400
        assert "empty body" in body["error"].lower()

    def test_oversized_body_400(self, server: ProbeRpcServer) -> None:
        # Only the header advertises 70 KiB; no body bytes are actually sent.
        status, body = _raw_post(server, headers=self._authed_headers(str(70 * 1024)))
        assert status == 400
        assert "exceeds" in body["error"].lower()

    def test_payload_must_be_object(self, server: ProbeRpcServer) -> None:
        status, body = _post(server, body="[]")
        assert status == 400
        assert "json object" in body["error"].lower()

    def test_missing_prompt_400(self, server: ProbeRpcServer) -> None:
        request = {"method": "inject_probe", "params": {"reference": "a"}}
        status, body = _post(server, body=request)
        assert status == 400
        assert "prompt" in body["error"].lower()

    def test_params_must_be_object(self, server: ProbeRpcServer) -> None:
        request = {"method": "inject_probe", "params": "bad"}
        status, body = _post(server, body=request)
        assert status == 400
        assert "`params`" in body["error"]

    def test_empty_reference_400(self, server: ProbeRpcServer) -> None:
        # A whitespace-only reference is treated as empty.
        request = {"method": "inject_probe", "params": {"prompt": "q", "reference": " "}}
        status, body = _post(server, body=request)
        assert status == 400
        assert "reference" in body["error"].lower()

    def test_non_string_tags_400(self, server: ProbeRpcServer) -> None:
        request = {
            "method": "inject_probe",
            "params": {"prompt": "q", "reference": "a", "tags": [1, 2]},
        }
        status, body = _post(server, body=request)
        assert status == 400
        assert "tags" in body["error"].lower()
| 211 |
|
| 212 |
|
| 213 |
class TestMethodDispatch:
    """Unknown RPC method names and unknown URL paths both yield 404."""

    def test_unknown_method_404(self, server: ProbeRpcServer) -> None:
        request = {"method": "explode", "params": {}}
        status, body = _post(server, body=request)
        assert status == 404
        # The error message should name the offending method.
        assert "explode" in body["error"]

    def test_unknown_path_404(self, server: ProbeRpcServer) -> None:
        status, _body = _post(server, body={"method": "inject_probe"}, path="/other")
        assert status == 404
| 222 |
|
| 223 |
|
| 224 |
class TestCapacity:
    """Once the queue is full, further injections are refused with 429."""

    def test_full_queue_429(self, server: ProbeRpcServer) -> None:
        payload = {
            "method": "inject_probe",
            "params": {"prompt": "q", "reference": "a"},
        }
        # The `server` fixture builds its queue with capacity=4.
        for attempt in range(4):
            status, _body = _post(server, body=payload)
            assert status == 200, f"request #{attempt} was rejected before the queue filled"

        status, body = _post(server, body=payload)
        assert status == 429
        assert body["queue_depth"] == 4
| 236 |
|
| 237 |
|
| 238 |
class TestAuthHelper:
    """Direct checks of ``_check_bearer(presented_header, expected_token)``."""

    def test_correct_token_matches(self) -> None:
        header, expected = "Bearer abc", "abc"
        assert _check_bearer(header, expected) is True

    def test_wrong_token_fails(self) -> None:
        header, expected = "Bearer xyz", "abc"
        assert _check_bearer(header, expected) is False

    def test_missing_prefix_fails(self) -> None:
        # The bare token without the "Bearer " scheme must not be accepted.
        header, expected = "abc", "abc"
        assert _check_bearer(header, expected) is False

    def test_empty_presented_fails(self) -> None:
        header, expected = "Bearer ", "abc"
        assert _check_bearer(header, expected) is False

    def test_length_mismatch_fails(self) -> None:
        # The presented token starts with the expected one but is longer.
        header, expected = "Bearer abcd", "abc"
        assert _check_bearer(header, expected) is False
| 253 |
|
| 254 |
|
| 255 |
class TestConstruction:
    """Constructor validation and the start() double-call guard."""

    def test_empty_token_rejected(self) -> None:
        with pytest.raises(ValueError, match="bearer token"):
            ProbeRpcServer(host="127.0.0.1", port=0, token="", queue=InjectedProbeQueue())

    def test_start_twice_rejected(self) -> None:
        # Built by hand rather than via the `server` fixture, which has
        # already called start().
        try:
            rpc = ProbeRpcServer(
                host="127.0.0.1",
                port=0,
                token=_TOKEN,
                queue=InjectedProbeQueue(),
            )
        except PermissionError as err:
            pytest.skip(f"loopback bind blocked on this host: {err}")
        rpc.start()
        try:
            with pytest.raises(RuntimeError, match="already started"):
                rpc.start()
        finally:
            rpc.stop()