Python · 9586 bytes Raw Blame History
1 """Unit tests for :mod:`dlm_sway.serve.app`.
2
3 These exercise the FastAPI surface end-to-end via
4 ``fastapi.testclient.TestClient`` — no uvicorn, no real network — and
5 ride a pre-seeded :class:`BackendCache` so the dummy backend never
6 goes through the rejected build path.
7 """
8
9 from __future__ import annotations
10
11 from pathlib import Path
12 from typing import Any
13
14 import pytest
15
16 # Skip the whole module cleanly if the [serve] extra isn't installed —
17 # avoids forcing every test contributor to ``pip install fastapi``.
18 pytest.importorskip("fastapi")
19 pytest.importorskip("httpx")
20
21 from fastapi.testclient import TestClient # noqa: E402
22
23 from dlm_sway.backends.dummy import ( # noqa: E402
24 DummyDifferentialBackend,
25 DummyResponses,
26 )
27 from dlm_sway.core.model import ModelSpec # noqa: E402
28 from dlm_sway.serve.app import create_app, parse_host_port # noqa: E402
29 from dlm_sway.serve.cache import BackendCache, CachedBackend, cache_key_for # noqa: E402
30 from dlm_sway.suite.spec import SwaySpec # noqa: E402
31
32
33 def _spec_payload(*, base: str = "dummy-base") -> dict[str, Any]:
34 """A minimal valid SwaySpec dict with one delta_kl probe."""
35 return {
36 "version": 1,
37 "models": {
38 "base": {"kind": "dummy", "base": base},
39 "ft": {"kind": "dummy", "base": base},
40 },
41 "defaults": {"seed": 0, "differential": True},
42 "suite": [
43 {"name": "dk", "kind": "delta_kl", "prompts": ["hello world"]},
44 ],
45 }
46
47
def _seed_dummy(cache: BackendCache, model_spec: ModelSpec) -> DummyDifferentialBackend:
    """Place a fresh dummy backend in ``cache`` under the key for ``model_spec``.

    The entry is written directly into the cache's private entry mapping
    while holding the cache's lock.

    Args:
        cache: Cache that receives the entry.
        model_spec: Model spec used both to derive the cache key and as the
            entry's ``model_spec``.

    Returns:
        The dummy backend that was stored.
    """
    dummy = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
    cache_key = cache_key_for(model_spec)
    seeded = CachedBackend(
        key=cache_key,
        backend=dummy,
        model_spec=model_spec,
        load_seconds=0.1,
    )
    with cache._lock:  # noqa: SLF001
        cache._entries[cache_key] = seeded  # noqa: SLF001
    return dummy
56
57
def _make_seeded_app(*, api_key: str | None = None) -> tuple[Any, BackendCache]:
    """Create the app around a two-slot cache that already holds a dummy backend.

    The backend is seeded under the ``ft`` model of the spec returned by
    ``_spec_payload()``, i.e. the same spec the tests POST.

    Args:
        api_key: Forwarded unchanged to ``create_app``.

    Returns:
        ``(app, cache)`` — the created app and the cache it was built with.
    """
    seeded_cache = BackendCache(max_size=2)
    application = create_app(cache=seeded_cache, api_key=api_key)
    ft_model = SwaySpec.model_validate(_spec_payload()).models.ft
    _seed_dummy(seeded_cache, ft_model)
    return application, seeded_cache
67
68
class TestHealth:
    """Checks for ``GET /health``."""

    def test_health_returns_uptime_and_loaded_models(self) -> None:
        """The health body reports status, uptime, the cache cap and the
        seeded model."""
        app, _ = _make_seeded_app()
        with TestClient(app) as client:
            response = client.get("/health")
        assert response.status_code == 200

        health = response.json()
        assert health["status"] == "ok"
        assert health["uptime_seconds"] >= 0.0
        assert health["max_loaded_models"] == 2

        loaded = health["loaded_models"]
        assert isinstance(loaded, list)
        # The seeded dummy backend must be listed by its base name.
        assert "dummy-base" in (model["base"] for model in loaded)

    def test_health_unauthenticated_when_api_key_set(self) -> None:
        """``/health`` stays reachable without the bearer token, so that a
        k8s liveness probe can call it."""
        app, _ = _make_seeded_app(api_key="secret")
        with TestClient(app) as client:
            # Deliberately sent without an Authorization header.
            response = client.get("/health")
        assert response.status_code == 200
89
90
class TestStats:
    """Checks for ``GET /stats``."""

    def test_stats_zero_when_no_runs(self) -> None:
        """Before any run: zero requests, no mean duration, and one cached
        backend (the seeded one)."""
        app, _ = _make_seeded_app()
        with TestClient(app) as client:
            response = client.get("/stats")
        assert response.status_code == 200

        stats = response.json()
        assert stats["request_count"] == 0
        assert stats["mean_run_seconds"] is None
        assert stats["cached_backends"] == 1
101
102
class TestRun:
    """Checks for ``POST /run``."""

    def test_run_returns_full_report_shape(self) -> None:
        """A successful run returns the score, the probe list and the
        request timing."""
        app, _cache = _make_seeded_app()
        with TestClient(app) as client:
            resp = client.post("/run", json={"spec": _spec_payload()})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        # Round-trips the same JSON shape sway run --json-out emits.
        assert "score" in body
        assert "probes" in body
        assert isinstance(body["probes"], list)
        assert any(p["name"] == "dk" for p in body["probes"])
        assert "request_seconds" in body
        assert body["request_seconds"] >= 0.0

    def test_run_increments_stats(self) -> None:
        """Two successful runs are reflected in ``/stats``."""
        app, _cache = _make_seeded_app()
        with TestClient(app) as client:
            for _ in range(2):
                resp = client.post("/run", json={"spec": _spec_payload()})
                # Fail here, with the response body, if a run itself breaks
                # instead of surfacing it only as a wrong counter below.
                assert resp.status_code == 200, resp.text
            stats = client.get("/stats").json()
        assert stats["request_count"] == 2
        assert stats["mean_run_seconds"] is not None
        assert stats["mean_run_seconds"] >= 0.0

    def test_run_400_on_invalid_spec(self) -> None:
        """A spec missing its required fields is rejected with a client error."""
        app, _cache = _make_seeded_app()
        with TestClient(app) as client:
            # Only "version" is supplied, so the body fails validation.
            resp = client.post("/run", json={"spec": {"version": 1}})
        # FastAPI reports request-body validation failures as 422 by default;
        # 400 is accepted as well so the test does not depend on which of the
        # two client-error codes the app uses.
        assert resp.status_code in (400, 422)
135
136
class TestScore:
    """Checks for ``POST /score``."""

    @staticmethod
    def _two_probe_spec(second_prompt: str) -> dict[str, Any]:
        """Return the base payload with a second ``delta_kl`` probe ``dk2``."""
        payload = _spec_payload()
        payload["suite"].append(
            {"name": "dk2", "kind": "delta_kl", "prompts": [second_prompt]}
        )
        return payload

    def test_score_filters_by_probe_names(self) -> None:
        """With ``probe_names`` given, only the named probe is reported."""
        app, _ = _make_seeded_app()
        # Two probes are defined in the spec, but only one is requested.
        request_body = {
            "spec": self._two_probe_spec("different"),
            "probe_names": ["dk2"],
        }
        with TestClient(app) as client:
            response = client.post("/score", json=request_body)
        assert response.status_code == 200, response.text
        reported = {probe["name"] for probe in response.json()["probes"]}
        assert reported == {"dk2"}

    def test_score_runs_all_when_probe_names_none(self) -> None:
        """With ``probe_names`` omitted, every probe in the spec is reported."""
        app, _ = _make_seeded_app()
        request_body = {"spec": self._two_probe_spec("other")}
        with TestClient(app) as client:
            response = client.post("/score", json=request_body)
        assert response.status_code == 200, response.text
        reported = {probe["name"] for probe in response.json()["probes"]}
        assert reported == {"dk", "dk2"}
163
164
class TestAuth:
    """Bearer-token checks on ``POST /run`` when the app has an API key."""

    def test_run_requires_bearer_when_api_key_set(self) -> None:
        """A request without credentials gets 401 and a detail message that
        mentions the ``Authorization`` header."""
        app, _ = _make_seeded_app(api_key="topsecret")
        with TestClient(app) as client:
            response = client.post("/run", json={"spec": _spec_payload()})
        assert response.status_code == 401
        detail = response.json()["detail"]
        assert "Authorization" in detail

    def test_run_succeeds_with_valid_bearer(self) -> None:
        """The matching bearer token is accepted."""
        app, _ = _make_seeded_app(api_key="topsecret")
        good_auth = {"Authorization": "Bearer topsecret"}
        with TestClient(app) as client:
            response = client.post(
                "/run", json={"spec": _spec_payload()}, headers=good_auth
            )
        assert response.status_code == 200, response.text

    def test_run_rejects_wrong_bearer(self) -> None:
        """A bearer token that does not match the API key gets 401."""
        app, _ = _make_seeded_app(api_key="topsecret")
        bad_auth = {"Authorization": "Bearer otherkey"}
        with TestClient(app) as client:
            response = client.post(
                "/run", json={"spec": _spec_payload()}, headers=bad_auth
            )
        assert response.status_code == 401
192
193
class TestCacheRoundTripThroughApp:
    """Cache state as observed around a request served by the app."""

    def test_run_does_not_evict_below_cap(self) -> None:
        """One entry under a cap of two survives a run and is gone after
        the client context exits."""
        app, seeded_cache = _make_seeded_app()
        with TestClient(app) as client:
            client.post("/run", json={"spec": _spec_payload()})
            # This has to be checked before leaving the with-block:
            # TestClient's __exit__ runs the shutdown handler, which evicts
            # everything. Only one entry was ever loaded against a cap of 2,
            # so it must still be present here.
            still_loaded = len(seeded_cache.loaded_keys())
            assert still_loaded == 1
        # The lifespan has exited, so evict_all() has run.
        after_shutdown = len(seeded_cache.loaded_keys())
        assert after_shutdown == 0
205
206
class TestParseHostPort:
    """Checks for ``parse_host_port``."""

    def test_rejects_out_of_range_port(self) -> None:
        """Ports 0 and 70000 are each rejected with ``typer.BadParameter``."""
        import typer

        for bad_port in (0, 70_000):
            with pytest.raises(typer.BadParameter):
                parse_host_port("127.0.0.1", bad_port)

    def test_accepts_valid_host_port(self) -> None:
        """A valid host/port pair is returned unchanged."""
        host, port = parse_host_port("127.0.0.1", 8787)
        assert (host, port) == ("127.0.0.1", 8787)
220
221
class TestSpecRoundTrip:
    """JSON round-trip checks for ``RunRequest``."""

    def test_spec_dump_load_preserves_fields(self, tmp_path: Path) -> None:
        """A SwaySpec wrapped in a RunRequest must come back from JSON equal
        to the original — risk #4 in the sprint plan.

        ``tmp_path`` is requested as a fixture but not referenced in the body.
        """
        from dlm_sway.serve.app import RunRequest

        source_spec = SwaySpec.model_validate(_spec_payload())
        # Serialise to JSON and parse it back, like FastAPI's wire format does.
        wire_json = RunRequest(spec=source_spec).model_dump_json()
        restored = RunRequest.model_validate_json(wire_json)
        assert restored.spec == source_spec
        assert restored.spec_path == "<serve>"

    def test_spec_with_dlm_source_roundtrip(self) -> None:
        """The optional ``dlm_source`` field must survive the JSON round-trip."""
        from dlm_sway.serve.app import RunRequest

        raw = _spec_payload()
        raw["dlm_source"] = "/path/to/foo.dlm"
        request = RunRequest(spec=SwaySpec.model_validate(raw))
        wire_json = request.model_dump_json()
        restored = RunRequest.model_validate_json(wire_json)
        assert restored.spec.dlm_source == "/path/to/foo.dlm"