Python · 7863 bytes Raw Blame History
1 """Unit tests for :class:`dlm_sway.serve.client.ServeClient`.
2
3 The happy path uses an httpx ``MockTransport`` to canned-respond on
4 each route so we exercise the client's request-shaping (URL build,
5 auth header) and response-parsing without binding a port. The error
6 branches drive :meth:`ServeClient._parse_response` directly so we
7 cover transport failures and malformed payloads in isolation.
8 """
9
10 from __future__ import annotations
11
12 import json
13 from typing import Any
14
15 import pytest
16
17 pytest.importorskip("httpx")
18
19 import httpx # noqa: E402
20
21 from dlm_sway.serve.client import ServeClient, ServeClientError # noqa: E402
22 from dlm_sway.suite.spec import SwaySpec # noqa: E402
23
24
25 def _spec_payload() -> dict[str, Any]:
26 return {
27 "version": 1,
28 "models": {
29 "base": {"kind": "dummy", "base": "dummy-base"},
30 "ft": {"kind": "dummy", "base": "dummy-base"},
31 },
32 "defaults": {"seed": 0, "differential": True},
33 "suite": [
34 {"name": "dk", "kind": "delta_kl", "prompts": ["hello"]},
35 ],
36 }
37
38
class _StubClient(ServeClient):
    """A :class:`ServeClient` whose HTTP traffic goes to an httpx ``MockTransport``.

    Only ``_get`` and ``_post`` are overridden. Each opens a short-lived
    ``httpx.Client`` bound to the mock transport, the client's base URL
    (``self._url``) and its headers (``self._headers``). It then hands the
    response to the inherited ``_parse_response``. The public methods
    (``health``/``run``/``score``) still shape the path and body themselves.
    The production ``_get``/``_post`` (lazy httpx import, transport-error
    wrapping) are bypassed here and are exercised on an unstubbed client.

    Args:
        handler: callable given to ``httpx.MockTransport``; it receives each
            ``httpx.Request`` and returns the canned ``httpx.Response``.
        api_key: forwarded unchanged to :class:`ServeClient`.
    """

    def __init__(self, handler: Any, *, api_key: str | None = None) -> None:
        super().__init__("http://testserver", api_key=api_key)
        self._transport = httpx.MockTransport(handler)

    def _open_mock_client(self) -> httpx.Client:
        """Build an httpx client that routes every request to the mock handler."""
        return httpx.Client(transport=self._transport, base_url=self._url)

    def _get(self, path: str) -> dict[str, Any]:
        with self._open_mock_client() as session:
            response = session.get(path, headers=self._headers)
        return self._parse_response(response, path=path)

    def _post(self, path: str, body: dict[str, Any]) -> dict[str, Any]:
        with self._open_mock_client() as session:
            response = session.post(path, headers=self._headers, json=body)
        return self._parse_response(response, path=path)
61
62
63 class TestServeClientHappy:
64 def test_health_returns_dict(self) -> None:
65 seen: dict[str, Any] = {}
66
67 def handler(request: httpx.Request) -> httpx.Response:
68 seen["path"] = request.url.path
69 return httpx.Response(200, json={"status": "ok", "uptime_seconds": 1.0})
70
71 client = _StubClient(handler)
72 body = client.health()
73 assert body["status"] == "ok"
74 assert seen["path"] == "/health"
75
76 def test_run_sends_spec_and_path(self) -> None:
77 captured: dict[str, Any] = {}
78
79 def handler(request: httpx.Request) -> httpx.Response:
80 captured["body"] = json.loads(request.content)
81 return httpx.Response(200, json={"probes": [], "request_seconds": 0.1})
82
83 client = _StubClient(handler)
84 spec = SwaySpec.model_validate(_spec_payload())
85 body = client.run(spec, spec_path="custom/path.yaml")
86 assert "probes" in body
87 assert captured["body"]["spec_path"] == "custom/path.yaml"
88 assert captured["body"]["spec"]["version"] == 1
89
90 def test_score_passes_probe_names(self) -> None:
91 captured: dict[str, Any] = {}
92
93 def handler(request: httpx.Request) -> httpx.Response:
94 captured["body"] = json.loads(request.content)
95 return httpx.Response(200, json={"probes": [], "request_seconds": 0.1})
96
97 client = _StubClient(handler)
98 spec = SwaySpec.model_validate(_spec_payload())
99 client.score(spec, probe_names=["dk"])
100 assert captured["body"]["probe_names"] == ["dk"]
101
102 def test_score_omits_probe_names_when_none(self) -> None:
103 captured: dict[str, Any] = {}
104
105 def handler(request: httpx.Request) -> httpx.Response:
106 captured["body"] = json.loads(request.content)
107 return httpx.Response(200, json={"probes": [], "request_seconds": 0.1})
108
109 client = _StubClient(handler)
110 spec = SwaySpec.model_validate(_spec_payload())
111 client.score(spec)
112 # Default behavior: omit the field rather than send null.
113 assert "probe_names" not in captured["body"]
114
115
116 class TestServeClientAuth:
117 def test_run_attaches_bearer_when_api_key_set(self) -> None:
118 captured: dict[str, Any] = {}
119
120 def handler(request: httpx.Request) -> httpx.Response:
121 captured["auth"] = request.headers.get("Authorization")
122 return httpx.Response(200, json={"probes": []})
123
124 client = _StubClient(handler, api_key="abc")
125 spec = SwaySpec.model_validate(_spec_payload())
126 client.run(spec)
127 assert captured["auth"] == "Bearer abc"
128
129 def test_no_auth_header_when_api_key_unset(self) -> None:
130 captured: dict[str, Any] = {}
131
132 def handler(request: httpx.Request) -> httpx.Response:
133 captured["auth"] = request.headers.get("Authorization")
134 return httpx.Response(200, json={"status": "ok"})
135
136 client = _StubClient(handler)
137 client.health()
138 assert captured["auth"] is None
139
140 def test_401_raises_serve_client_error(self) -> None:
141 def handler(request: httpx.Request) -> httpx.Response:
142 del request
143 return httpx.Response(401, json={"detail": "missing or invalid"})
144
145 client = _StubClient(handler)
146 spec = SwaySpec.model_validate(_spec_payload())
147 with pytest.raises(ServeClientError, match="missing or invalid"):
148 client.run(spec)
149
150
151 class TestServeClientErrorPaths:
152 def test_missing_httpx_dependency_raises_clean_error(
153 self, monkeypatch: pytest.MonkeyPatch
154 ) -> None:
155 """If httpx isn't installed, the client surfaces a SwayError-shaped
156 message pointing at the [serve] extra."""
157 client = ServeClient("http://nope")
158 # Force the httpx import inside _get to fail.
159 import builtins
160
161 real_import = builtins.__import__
162
163 def _fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
164 if name == "httpx":
165 raise ImportError("nope")
166 return real_import(name, *args, **kwargs)
167
168 monkeypatch.setattr(builtins, "__import__", _fake_import)
169 with pytest.raises(ServeClientError, match=r"\[serve\]"):
170 client.health()
171
172 def test_transport_failure_wraps_to_serve_client_error(self) -> None:
173 """An httpx connection error becomes a ServeClientError with the
174 request method/path in the message."""
175 client = ServeClient("http://127.0.0.1:1") # unused port
176 with pytest.raises(ServeClientError, match="GET /health failed"):
177 client.health()
178
179 def test_non_dict_response_raises(self) -> None:
180 """When the daemon ever returns a JSON list at the root, the
181 client refuses it instead of silently mis-typing."""
182
183 def _handler(request: httpx.Request) -> httpx.Response:
184 del request
185 return httpx.Response(200, json=[1, 2, 3])
186
187 # Drive _parse_response directly: build a fake response object.
188 resp = httpx.Response(200, json=[1, 2, 3])
189 with pytest.raises(ServeClientError, match="non-object JSON"):
190 ServeClient._parse_response(resp, path="/x")
191
192 def test_4xx_pulls_detail_from_payload(self) -> None:
193 resp = httpx.Response(400, json={"detail": "bad spec"})
194 with pytest.raises(ServeClientError, match="bad spec"):
195 ServeClient._parse_response(resp, path="/run")
196
197 def test_4xx_with_non_json_falls_back_to_text(self) -> None:
198 resp = httpx.Response(500, text="upstream broke")
199 with pytest.raises(ServeClientError, match="upstream broke"):
200 ServeClient._parse_response(resp, path="/run")
201
202 def test_url_property_strips_trailing_slash(self) -> None:
203 client = ServeClient("http://localhost:8787/")
204 assert client.url == "http://localhost:8787"