tenseleyflow/sway / 3329900

Browse files

Add FastAPI app unit tests via TestClient

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
3329900069a72b670066857a8c0eaa6e727a9e20
Parents
0d28d58
Tree
d0ad63b

1 changed file

StatusFile+-
A tests/unit/test_serve_app.py 249 0
tests/unit/test_serve_app.pyadded
@@ -0,0 +1,249 @@
1
+"""Unit tests for :mod:`dlm_sway.serve.app`.
2
+
3
+These exercise the FastAPI surface end-to-end via
4
+``fastapi.testclient.TestClient`` — no uvicorn, no real network — and
5
+ride a pre-seeded :class:`BackendCache` so the dummy backend never
6
+goes through the rejected build path.
7
+"""
8
+
9
+from __future__ import annotations
10
+
11
+from pathlib import Path
12
+from typing import Any
13
+
14
+import pytest
15
+
16
+# Skip the whole module cleanly if the [serve] extra isn't installed —
17
+# avoids forcing every test contributor to ``pip install fastapi``.
18
+pytest.importorskip("fastapi")
19
+pytest.importorskip("httpx")
20
+
21
+from fastapi.testclient import TestClient  # noqa: E402
22
+
23
+from dlm_sway.backends.dummy import (  # noqa: E402
24
+    DummyDifferentialBackend,
25
+    DummyResponses,
26
+)
27
+from dlm_sway.core.model import ModelSpec  # noqa: E402
28
+from dlm_sway.serve.app import create_app, parse_host_port  # noqa: E402
29
+from dlm_sway.serve.cache import BackendCache, CachedBackend, cache_key_for  # noqa: E402
30
+from dlm_sway.suite.spec import SwaySpec  # noqa: E402
31
+
32
+
33
+def _spec_payload(*, base: str = "dummy-base") -> dict[str, Any]:
34
+    """A minimal valid SwaySpec dict with one delta_kl probe."""
35
+    return {
36
+        "version": 1,
37
+        "models": {
38
+            "base": {"kind": "dummy", "base": base},
39
+            "ft": {"kind": "dummy", "base": base},
40
+        },
41
+        "defaults": {"seed": 0, "differential": True},
42
+        "suite": [
43
+            {"name": "dk", "kind": "delta_kl", "prompts": ["hello world"]},
44
+        ],
45
+    }
46
+
47
+
48
+def _seed_dummy(cache: BackendCache, model_spec: ModelSpec) -> DummyDifferentialBackend:
49
+    """Pre-load a dummy backend into the cache under ``model_spec``'s key."""
50
+    backend = DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())
51
+    key = cache_key_for(model_spec)
52
+    entry = CachedBackend(key=key, backend=backend, model_spec=model_spec, load_seconds=0.1)
53
+    with cache._lock:  # noqa: SLF001
54
+        cache._entries[key] = entry  # noqa: SLF001
55
+    return backend
56
+
57
+
58
+def _make_seeded_app(*, api_key: str | None = None) -> tuple[Any, BackendCache]:
59
+    """Build the app and pre-seed its cache with a dummy backend
60
+    under the spec the tests POST."""
61
+    cache = BackendCache(max_size=2)
62
+    app = create_app(cache=cache, api_key=api_key)
63
+    payload = _spec_payload()
64
+    spec = SwaySpec.model_validate(payload)
65
+    _seed_dummy(cache, spec.models.ft)
66
+    return app, cache
67
+
68
+
69
+class TestHealth:
70
+    def test_health_returns_uptime_and_loaded_models(self) -> None:
71
+        app, _cache = _make_seeded_app()
72
+        with TestClient(app) as client:
73
+            resp = client.get("/health")
74
+        assert resp.status_code == 200
75
+        body = resp.json()
76
+        assert body["status"] == "ok"
77
+        assert body["uptime_seconds"] >= 0.0
78
+        assert body["max_loaded_models"] == 2
79
+        assert isinstance(body["loaded_models"], list)
80
+        assert any(m["base"] == "dummy-base" for m in body["loaded_models"])
81
+
82
+    def test_health_unauthenticated_when_api_key_set(self) -> None:
83
+        """A k8s liveness probe must be able to hit /health without
84
+        the bearer token."""
85
+        app, _cache = _make_seeded_app(api_key="secret")
86
+        with TestClient(app) as client:
87
+            resp = client.get("/health")  # No Authorization header.
88
+        assert resp.status_code == 200
89
+
90
+
91
+class TestStats:
92
+    def test_stats_zero_when_no_runs(self) -> None:
93
+        app, _cache = _make_seeded_app()
94
+        with TestClient(app) as client:
95
+            resp = client.get("/stats")
96
+        assert resp.status_code == 200
97
+        body = resp.json()
98
+        assert body["request_count"] == 0
99
+        assert body["mean_run_seconds"] is None
100
+        assert body["cached_backends"] == 1
101
+
102
+
103
+class TestRun:
104
+    def test_run_returns_full_report_shape(self) -> None:
105
+        app, _cache = _make_seeded_app()
106
+        with TestClient(app) as client:
107
+            resp = client.post("/run", json={"spec": _spec_payload()})
108
+        assert resp.status_code == 200, resp.text
109
+        body = resp.json()
110
+        # Round-trips the same JSON shape sway run --json-out emits.
111
+        assert "score" in body
112
+        assert "probes" in body
113
+        assert isinstance(body["probes"], list)
114
+        assert any(p["name"] == "dk" for p in body["probes"])
115
+        assert "request_seconds" in body
116
+        assert body["request_seconds"] >= 0.0
117
+
118
+    def test_run_increments_stats(self) -> None:
119
+        app, _cache = _make_seeded_app()
120
+        with TestClient(app) as client:
121
+            client.post("/run", json={"spec": _spec_payload()})
122
+            client.post("/run", json={"spec": _spec_payload()})
123
+            stats = client.get("/stats").json()
124
+        assert stats["request_count"] == 2
125
+        assert stats["mean_run_seconds"] is not None
126
+        assert stats["mean_run_seconds"] >= 0.0
127
+
128
+    def test_run_400_on_invalid_spec(self) -> None:
129
+        app, _cache = _make_seeded_app()
130
+        with TestClient(app) as client:
131
+            # Missing required fields → pydantic validation error → 422.
132
+            resp = client.post("/run", json={"spec": {"version": 1}})
133
+        # FastAPI emits 422 for body validation failures.
134
+        assert resp.status_code in (400, 422)
135
+
136
+
137
+class TestScore:
138
+    def test_score_filters_by_probe_names(self) -> None:
139
+        app, _cache = _make_seeded_app()
140
+        # Build a spec with two probes, then ask /score for one.
141
+        spec = _spec_payload()
142
+        spec["suite"].append(
143
+            {"name": "dk2", "kind": "delta_kl", "prompts": ["different"]}
144
+        )
145
+        with TestClient(app) as client:
146
+            resp = client.post(
147
+                "/score",
148
+                json={"spec": spec, "probe_names": ["dk2"]},
149
+            )
150
+        assert resp.status_code == 200, resp.text
151
+        body = resp.json()
152
+        names = {p["name"] for p in body["probes"]}
153
+        assert names == {"dk2"}
154
+
155
+    def test_score_runs_all_when_probe_names_none(self) -> None:
156
+        app, _cache = _make_seeded_app()
157
+        spec = _spec_payload()
158
+        spec["suite"].append(
159
+            {"name": "dk2", "kind": "delta_kl", "prompts": ["other"]}
160
+        )
161
+        with TestClient(app) as client:
162
+            resp = client.post("/score", json={"spec": spec})
163
+        assert resp.status_code == 200, resp.text
164
+        body = resp.json()
165
+        names = {p["name"] for p in body["probes"]}
166
+        assert names == {"dk", "dk2"}
167
+
168
+
169
+class TestAuth:
170
+    def test_run_requires_bearer_when_api_key_set(self) -> None:
171
+        app, _cache = _make_seeded_app(api_key="topsecret")
172
+        with TestClient(app) as client:
173
+            resp = client.post("/run", json={"spec": _spec_payload()})
174
+        assert resp.status_code == 401
175
+        assert "Authorization" in resp.json()["detail"]
176
+
177
+    def test_run_succeeds_with_valid_bearer(self) -> None:
178
+        app, _cache = _make_seeded_app(api_key="topsecret")
179
+        with TestClient(app) as client:
180
+            resp = client.post(
181
+                "/run",
182
+                json={"spec": _spec_payload()},
183
+                headers={"Authorization": "Bearer topsecret"},
184
+            )
185
+        assert resp.status_code == 200, resp.text
186
+
187
+    def test_run_rejects_wrong_bearer(self) -> None:
188
+        app, _cache = _make_seeded_app(api_key="topsecret")
189
+        with TestClient(app) as client:
190
+            resp = client.post(
191
+                "/run",
192
+                json={"spec": _spec_payload()},
193
+                headers={"Authorization": "Bearer otherkey"},
194
+            )
195
+        assert resp.status_code == 401
196
+
197
+
198
+class TestCacheRoundTripThroughApp:
199
+    def test_run_does_not_evict_below_cap(self) -> None:
200
+        app, cache = _make_seeded_app()
201
+        with TestClient(app) as client:
202
+            client.post("/run", json={"spec": _spec_payload()})
203
+            # Cap is 2; we only ever loaded 1 entry, so it must still
204
+            # be there. Check inside the with-block — TestClient's
205
+            # __exit__ runs the shutdown handler which evicts everything.
206
+            assert len(cache.loaded_keys()) == 1
207
+        # After the lifespan exits, evict_all() has run.
208
+        assert len(cache.loaded_keys()) == 0
209
+
210
+
211
+class TestParseHostPort:
212
+    def test_rejects_out_of_range_port(self) -> None:
213
+        import typer
214
+
215
+        with pytest.raises(typer.BadParameter):
216
+            parse_host_port("127.0.0.1", 0)
217
+        with pytest.raises(typer.BadParameter):
218
+            parse_host_port("127.0.0.1", 70_000)
219
+
220
+    def test_accepts_valid_host_port(self) -> None:
221
+        host, port = parse_host_port("127.0.0.1", 8787)
222
+        assert host == "127.0.0.1"
223
+        assert port == 8787
224
+
225
+
226
+class TestSpecRoundTrip:
227
+    def test_spec_dump_load_preserves_fields(self, tmp_path: Path) -> None:
228
+        """RunRequest must roundtrip a SwaySpec through JSON without
229
+        losing fields — risk #4 in the sprint plan."""
230
+        from dlm_sway.serve.app import RunRequest
231
+
232
+        original = SwaySpec.model_validate(_spec_payload())
233
+        # Round-trip through JSON like FastAPI's wire format does.
234
+        req = RunRequest(spec=original)
235
+        body = req.model_dump_json()
236
+        reparsed = RunRequest.model_validate_json(body)
237
+        assert reparsed.spec == original
238
+        assert reparsed.spec_path == "<serve>"
239
+
240
+    def test_spec_with_dlm_source_roundtrip(self) -> None:
241
+        """dlm_source is optional and must survive round-trip."""
242
+        from dlm_sway.serve.app import RunRequest
243
+
244
+        payload = _spec_payload()
245
+        payload["dlm_source"] = "/path/to/foo.dlm"
246
+        original = SwaySpec.model_validate(payload)
247
+        req = RunRequest(spec=original)
248
+        reparsed = RunRequest.model_validate_json(req.model_dump_json())
249
+        assert reparsed.spec.dlm_source == "/path/to/foo.dlm"