tenseleyflow/sway / bc9b356

Browse files

Add FastAPI app factory for sway serve

Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
bc9b3568997e4a900b38e8a987a77a0bc2e78a8e
Parents
8b59c25
Tree
2b7a811

1 changed file

StatusFile+-
A src/dlm_sway/serve/app.py 288 0
src/dlm_sway/serve/app.pyadded
@@ -0,0 +1,288 @@
1
+"""FastAPI app for the ``sway serve`` daemon.
2
+
3
+Three concerns sit in this module:
4
+
5
+1. ``create_app(...)`` — factory that returns a FastAPI instance with
6
+   the four documented endpoints wired up. Used by the CLI's uvicorn
7
+   launcher and unit tests' :class:`fastapi.testclient.TestClient`.
8
+2. Pydantic request/response models for ``/run`` and ``/score``.
9
+3. The endpoint handlers themselves, which delegate to
10
+   :func:`dlm_sway.suite.runner.run` against backends fetched from
11
+   :class:`dlm_sway.serve.cache.BackendCache`.
12
+
13
+Auth is intentionally minimal in v1: localhost-only by default; binding
14
+to ``0.0.0.0`` is rejected at the CLI layer unless ``--api-key`` is
15
+set, and an API-key middleware validates the bearer token on every
16
+request when an API key was provided. Full OAuth is out of scope.
17
+"""
18
+
19
+from __future__ import annotations
20
+
21
+import logging
22
+import time
23
+from collections.abc import Awaitable, Callable
24
+from typing import Any
25
+
26
+import typer
27
+from pydantic import BaseModel
28
+
29
+from dlm_sway import __version__
30
+from dlm_sway.core.errors import SwayError
31
+from dlm_sway.serve.cache import BackendCache
32
+from dlm_sway.suite.spec import SwaySpec
33
+
34
+_LOG = logging.getLogger(__name__)
35
+
36
+
37
+# ---------------------------------------------------------------------------
38
+# Request / response models
39
+# ---------------------------------------------------------------------------
40
+
41
+
42
+class RunRequest(BaseModel):
43
+    """Body of ``POST /run`` — full spec plus optional overrides."""
44
+
45
+    spec: SwaySpec
46
+    spec_path: str = "<serve>"
47
+    seed: int | None = None
48
+
49
+
50
+class ScoreRequest(BaseModel):
51
+    """Body of ``POST /score`` — full spec plus a subset filter.
52
+
53
+    When ``probe_names`` is set, only those probes execute; the
54
+    response shape is the per-probe entry list (no folded SwayScore).
55
+    """
56
+
57
+    spec: SwaySpec
58
+    probe_names: list[str] | None = None
59
+
60
+
61
+# ---------------------------------------------------------------------------
62
+# App factory
63
+# ---------------------------------------------------------------------------
64
+
65
+
66
+def create_app(
67
+    *,
68
+    cache: BackendCache | None = None,
69
+    api_key: str | None = None,
70
+) -> Any:
71
+    """Build a FastAPI app wired to a :class:`BackendCache`.
72
+
73
+    Parameters
74
+    ----------
75
+    cache:
76
+        Optional pre-built cache. When omitted, a new cache with
77
+        ``max_size=2`` is created. Tests inject a dummy-backed cache
78
+        via this parameter; the CLI passes one with the user's
79
+        ``--max-loaded-models``.
80
+    api_key:
81
+        Optional bearer token. When set, every request must carry
82
+        ``Authorization: Bearer <key>`` or it gets a 401. When
83
+        ``None`` (the default), no auth — only safe for localhost
84
+        binds, which the CLI enforces.
85
+
86
+    Returns
87
+    -------
88
+    A :class:`fastapi.FastAPI` instance ready to hand to uvicorn or
89
+    a TestClient.
90
+    """
91
+    try:
92
+        from fastapi import FastAPI, HTTPException, Request, status
93
+        from fastapi.responses import JSONResponse
94
+    except ImportError as exc:
95
+        raise SwayError(
96
+            "sway serve requires the [serve] extra: pip install 'dlm-sway[serve]'"
97
+        ) from exc
98
+
99
+    if cache is None:
100
+        cache = BackendCache(max_size=2)
101
+    started_at = time.monotonic()
102
+    request_count = 0
103
+    total_run_seconds = 0.0
104
+
105
+    app = FastAPI(
106
+        title="sway",
107
+        version=__version__,
108
+        description=(
109
+            "Warm-backend HTTP API for sway. POST /run accepts a spec; "
110
+            "the daemon keeps backends loaded between calls."
111
+        ),
112
+    )
113
+
114
+    # -- auth middleware --------------------------------------------------
115
+
116
+    if api_key is not None:
117
+        expected_header = f"Bearer {api_key}"
118
+
119
+        @app.middleware("http")
120
+        async def _auth_middleware(
121
+            request: Request,
122
+            call_next: Callable[[Request], Awaitable[Any]],
123
+        ) -> Any:
124
+            # Health is unauthenticated so an external probe (e.g. a
125
+            # k8s liveness check) can hit it without distributing the
126
+            # token. Every other endpoint requires the bearer.
127
+            if request.url.path == "/health":
128
+                return await call_next(request)
129
+            got = request.headers.get("Authorization", "")
130
+            if got != expected_header:
131
+                return JSONResponse(
132
+                    status_code=status.HTTP_401_UNAUTHORIZED,
133
+                    content={"detail": "missing or invalid Authorization header"},
134
+                )
135
+            return await call_next(request)
136
+
137
+    # -- helpers ---------------------------------------------------------
138
+
139
+    def _run_with_warm_backend(spec: SwaySpec, *, spec_path: str) -> Any:
140
+        """Common path for /run and /score: pick the right backend
141
+        from cache, call the runner, return a serializable result."""
142
+        from dlm_sway.suite.runner import run as run_suite
143
+
144
+        # Same logic as cli/_execute_spec but takes an in-memory spec
145
+        # instead of a path, and uses the cache for backend identity.
146
+        if spec.defaults.differential:
147
+            cached = cache.get_or_load(spec.models.ft)
148
+        else:
149
+            # Two-separate (base + ft as distinct backends) isn't yet
150
+            # cached at this layer — falls through to a fresh build
151
+            # per request. Cleanest path until a multi-key cache
152
+            # entry shape lands.
153
+            from dlm_sway.backends import build_two_separate
154
+
155
+            backend = build_two_separate(spec.models)
156
+            try:
157
+                return run_suite(spec, backend, spec_path=spec_path)
158
+            finally:
159
+                close = getattr(backend, "close", None)
160
+                if callable(close):
161
+                    try:
162
+                        close()
163
+                    except Exception as exc:  # noqa: BLE001
164
+                        _LOG.warning("two-separate backend close raised: %s", exc)
165
+        return run_suite(spec, cached.backend, spec_path=spec_path)
166
+
167
+    # -- /health ---------------------------------------------------------
168
+
169
+    @app.get("/health")
170
+    async def health() -> dict[str, Any]:
171
+        return {
172
+            "status": "ok",
173
+            "sway_version": __version__,
174
+            "uptime_seconds": time.monotonic() - started_at,
175
+            "loaded_models": [
176
+                {
177
+                    "kind": s.kind,
178
+                    "base": s.base,
179
+                    "adapter": str(s.adapter) if s.adapter is not None else None,
180
+                    "dtype": s.dtype,
181
+                    "device": s.device,
182
+                }
183
+                for s in cache.loaded_specs()
184
+            ],
185
+            "max_loaded_models": cache.max_size,
186
+        }
187
+
188
+    # -- /stats ----------------------------------------------------------
189
+
190
+    @app.get("/stats")
191
+    async def stats() -> dict[str, Any]:
192
+        return {
193
+            "uptime_seconds": time.monotonic() - started_at,
194
+            "request_count": request_count,
195
+            "total_run_seconds": total_run_seconds,
196
+            "mean_run_seconds": (
197
+                total_run_seconds / request_count if request_count > 0 else None
198
+            ),
199
+            "cached_backends": len(cache.loaded_keys()),
200
+            "max_loaded_models": cache.max_size,
201
+        }
202
+
203
+    # -- /run ------------------------------------------------------------
204
+
205
+    @app.post("/run")
206
+    async def run_endpoint(req: RunRequest) -> dict[str, Any]:
207
+        nonlocal request_count, total_run_seconds
208
+        started = time.monotonic()
209
+        try:
210
+            result = _run_with_warm_backend(req.spec, spec_path=req.spec_path)
211
+        except SwayError as exc:
212
+            raise HTTPException(status_code=400, detail=str(exc)) from exc
213
+        elapsed = time.monotonic() - started
214
+        request_count += 1
215
+        total_run_seconds += elapsed
216
+        # Reuse the on-disk-JSON serializer so the daemon's response
217
+        # shape matches what ``sway run --json-out`` would write — one
218
+        # less JSON contract to maintain.
219
+        import json as _json
220
+
221
+        from dlm_sway.suite import report
222
+        from dlm_sway.suite.score import compute as compute_score
223
+
224
+        score_obj = compute_score(result)
225
+        bundled: dict[str, Any] = _json.loads(report.to_json(result, score_obj))
226
+        bundled["request_seconds"] = elapsed
227
+        return bundled
228
+
229
+    # -- /score ----------------------------------------------------------
230
+
231
+    @app.post("/score")
232
+    async def score_endpoint(req: ScoreRequest) -> dict[str, Any]:
233
+        nonlocal request_count, total_run_seconds
234
+        # /score is a thin filter over /run: same code path, but only
235
+        # the named probes execute. Filtering happens at the spec
236
+        # level (we mutate a copy of spec.suite) so the runner's
237
+        # null-calibration logic still sees the right downstream
238
+        # kinds. Frozen pydantic models require .model_copy.
239
+        if req.probe_names is not None:
240
+            wanted = set(req.probe_names)
241
+            filtered_suite = [p for p in req.spec.suite if p.get("name") in wanted]
242
+            spec = req.spec.model_copy(update={"suite": filtered_suite})
243
+        else:
244
+            spec = req.spec
245
+        started = time.monotonic()
246
+        try:
247
+            result = _run_with_warm_backend(spec, spec_path="<serve:/score>")
248
+        except SwayError as exc:
249
+            raise HTTPException(status_code=400, detail=str(exc)) from exc
250
+        elapsed = time.monotonic() - started
251
+        request_count += 1
252
+        total_run_seconds += elapsed
253
+        # /score returns just the probe entries — no folded SwayScore.
254
+        # Reuse the same per-probe serializer the to_json path uses
255
+        # for consistency.
256
+        import json as _json
257
+
258
+        from dlm_sway.suite import report
259
+        from dlm_sway.suite.score import compute as compute_score
260
+
261
+        score_obj = compute_score(result)
262
+        bundled: dict[str, Any] = _json.loads(report.to_json(result, score_obj))
263
+        return {
264
+            "probes": bundled["probes"],
265
+            "request_seconds": elapsed,
266
+        }
267
+
268
+    # -- shutdown --------------------------------------------------------
269
+
270
+    @app.on_event("shutdown")
271
+    async def _on_shutdown() -> None:
272
+        cache.evict_all()
273
+
274
+    # Stash the cache + counters on the app so tests can introspect.
275
+    app.state.sway_cache = cache
276
+
277
+    return app
278
+
279
+
280
+def parse_host_port(host: str, port: int) -> tuple[str, int]:
281
+    """Validate host + port at CLI layer; raise via typer for clean exit.
282
+
283
+    Centralizes the "0.0.0.0 without auth = refuse" rule so the CLI
284
+    and any test harness share it.
285
+    """
286
+    if not (1 <= port <= 65535):
287
+        raise typer.BadParameter(f"port must be 1..65535, got {port}")
288
+    return host, port