Python · 10840 bytes Raw Blame History
1 """FastAPI app for the ``sway serve`` daemon.
2
3 Three concerns sit in this module:
4
5 1. ``create_app(...)`` — factory that returns a FastAPI instance with
6 the four documented endpoints wired up. Used by the CLI's uvicorn
7 launcher and unit tests' :class:`fastapi.testclient.TestClient`.
8 2. Pydantic request/response models for ``/run`` and ``/score``.
9 3. The endpoint handlers themselves, which delegate to
10 :func:`dlm_sway.suite.runner.run` against backends fetched from
11 :class:`dlm_sway.serve.cache.BackendCache`.
12
13 Auth is intentionally minimal in v1: localhost-only by default; binding
14 to ``0.0.0.0`` is rejected at the CLI layer unless ``--api-key`` is
15 set, and an API-key middleware validates the bearer token on every
16 request when an API key was provided. Full OAuth is out of scope.
17 """
18
19 from __future__ import annotations
20
21 import logging
22 import time
23 from collections.abc import Awaitable, Callable
24 from typing import Any
25
26 import typer
27 from pydantic import BaseModel
28
29 from dlm_sway import __version__
30 from dlm_sway.core.errors import SwayError
31 from dlm_sway.serve.cache import BackendCache
32 from dlm_sway.suite.spec import SwaySpec
33
34 _LOG = logging.getLogger(__name__)
35
36
37 # ---------------------------------------------------------------------------
38 # Request / response models
39 # ---------------------------------------------------------------------------
40
41
class RunRequest(BaseModel):
    """Body of ``POST /run`` — a full spec plus optional overrides.

    The ``/run`` handler executes ``spec`` against a cached (or, for
    non-differential specs, freshly built) backend and returns the same
    JSON shape that ``report.to_json`` produces, plus ``request_seconds``.
    """

    # The suite specification to execute, sent in-memory rather than as a path.
    spec: SwaySpec
    # Label forwarded to the runner as ``spec_path``; there is no real file
    # behind a served request, so this is a placeholder by default.
    spec_path: str = "<serve>"
    # NOTE(review): no handler in this module reads ``seed`` — a value sent
    # here is currently ignored. Confirm whether it should override the
    # spec's seed before documenting it as an override.
    seed: int | None = None
48
49
class ScoreRequest(BaseModel):
    """Body of ``POST /score`` — a full spec plus an optional probe filter.

    When ``probe_names`` is set, only suite entries whose ``name`` appears
    in the list execute; names that match no entry are silently ignored.
    Regardless of filtering, the ``/score`` response is the per-probe entry
    list plus ``request_seconds`` (no folded SwayScore).
    """

    # The suite specification; a filtered copy is run when probe_names is set.
    spec: SwaySpec
    # Probe names to keep. ``None`` runs the whole suite; an empty list
    # filters out every probe.
    probe_names: list[str] | None = None
59
60
61 # ---------------------------------------------------------------------------
62 # App factory
63 # ---------------------------------------------------------------------------
64
65
def create_app(
    *,
    cache: BackendCache | None = None,
    api_key: str | None = None,
) -> Any:
    """Build a FastAPI app wired to a :class:`BackendCache`.

    Parameters
    ----------
    cache:
        Optional pre-built cache. When omitted, a new cache with
        ``max_size=2`` is created. Tests inject a dummy-backed cache
        via this parameter; the CLI passes one with the user's
        ``--max-loaded-models``.
    api_key:
        Optional bearer token. When set, every request except
        ``GET /health`` must carry ``Authorization: Bearer <key>`` or it
        gets a 401. When ``None`` (the default), no auth — only safe for
        localhost binds, which the CLI is expected to enforce.

    Returns
    -------
    A :class:`fastapi.FastAPI` instance ready to hand to uvicorn or
    a TestClient. The cache is exposed as ``app.state.sway_cache``.

    Raises
    ------
    SwayError
        If FastAPI (the ``[serve]`` extra) is not installed.
    """
    try:
        from contextlib import asynccontextmanager

        from fastapi import FastAPI, HTTPException, Request, status
        from fastapi.responses import JSONResponse
    except ImportError as exc:
        raise SwayError(
            "sway serve requires the [serve] extra: pip install 'dlm-sway[serve]'"
        ) from exc

    import json
    import secrets

    # One non-optional binding shared by every closure below (lifespan,
    # helpers, endpoints), so none of them has to re-narrow ``cache``.
    backend_cache: BackendCache = cache if cache is not None else BackendCache(max_size=2)

    started_at = time.monotonic()
    # Updated by ``_timed_run`` (via ``nonlocal``) after each successful
    # /run or /score; read by /stats.
    request_count = 0
    total_run_seconds = 0.0

    @asynccontextmanager
    async def _lifespan(app: Any) -> Any:
        # Evict every cached backend on shutdown. A lifespan handler is used
        # because FastAPI's ``on_event`` hooks are deprecated (0.110+).
        del app
        try:
            yield
        finally:
            backend_cache.evict_all()

    app = FastAPI(
        title="sway",
        version=__version__,
        description=(
            "Warm-backend HTTP API for sway. POST /run accepts a spec; "
            "the daemon keeps backends loaded between calls."
        ),
        lifespan=_lifespan,
    )

    # -- auth middleware --------------------------------------------------

    if api_key is not None:
        # Compared as bytes: ``secrets.compare_digest`` is constant-time (no
        # token-prefix timing leak), and on ``str`` it raises TypeError for
        # non-ASCII input — which any client controls via the header.
        expected_header = f"Bearer {api_key}".encode()

        @app.middleware("http")
        async def _auth_middleware(
            request: Request,
            call_next: Callable[[Request], Awaitable[Any]],
        ) -> Any:
            # Health is unauthenticated so an external probe (e.g. a
            # k8s liveness check) can hit it without distributing the
            # token. Every other endpoint requires the bearer.
            if request.url.path == "/health":
                return await call_next(request)
            got = request.headers.get("Authorization", "").encode()
            if not secrets.compare_digest(got, expected_header):
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "missing or invalid Authorization header"},
                )
            return await call_next(request)

    # -- helpers ---------------------------------------------------------

    def _run_with_warm_backend(spec: SwaySpec, *, spec_path: str) -> Any:
        """Pick a backend for ``spec`` and run the suite against it.

        Differential specs reuse a cached backend keyed on ``spec.models.ft``.
        Anything else builds a fresh two-separate backend per request and
        closes it afterwards (best-effort).

        NOTE(review): this is synchronous and the async handlers call it
        directly, so a long run blocks the event loop — including /health —
        until it finishes. Offloading to a worker thread needs
        BackendCache's thread-safety confirmed first.
        """
        from dlm_sway.suite.runner import run as run_suite

        if spec.defaults.differential:
            cached = backend_cache.get_or_load(spec.models.ft)
            return run_suite(spec, cached.backend, spec_path=spec_path)

        # Two-separate (base + ft as distinct backends) isn't cached at
        # this layer yet — it needs a multi-key cache entry shape.
        from dlm_sway.backends import build_two_separate

        backend = build_two_separate(spec.models)
        try:
            return run_suite(spec, backend, spec_path=spec_path)
        finally:
            close = getattr(backend, "close", None)
            if callable(close):
                try:
                    close()
                except Exception as exc:  # noqa: BLE001
                    _LOG.warning("two-separate backend close raised: %s", exc)

    def _timed_run(spec: SwaySpec, *, spec_path: str) -> tuple[Any, float]:
        """Run ``spec``, map SwayError to HTTP 400, and record timing stats.

        Returns ``(result, elapsed_seconds)``. Counters are only bumped on
        success, matching what /stats reports.
        """
        nonlocal request_count, total_run_seconds
        started = time.monotonic()
        try:
            result = _run_with_warm_backend(spec, spec_path=spec_path)
        except SwayError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        elapsed = time.monotonic() - started
        request_count += 1
        total_run_seconds += elapsed
        return result, elapsed

    def _to_report_dict(result: Any) -> dict[str, Any]:
        """Serialize ``result`` via the on-disk JSON reporter.

        Reusing ``report.to_json`` keeps the daemon's response shape identical
        to what ``sway run --json-out`` writes — one JSON contract, not two.
        """
        from dlm_sway.suite import report
        from dlm_sway.suite.score import compute as compute_score

        score_obj = compute_score(result)
        payload: dict[str, Any] = json.loads(report.to_json(result, score_obj))
        return payload

    # -- /health ---------------------------------------------------------

    @app.get("/health")
    async def health() -> dict[str, Any]:
        return {
            "status": "ok",
            "sway_version": __version__,
            "uptime_seconds": time.monotonic() - started_at,
            "loaded_models": [
                {
                    "kind": s.kind,
                    "base": s.base,
                    "adapter": str(s.adapter) if s.adapter is not None else None,
                    "dtype": s.dtype,
                    "device": s.device,
                }
                for s in backend_cache.loaded_specs()
            ],
            "max_loaded_models": backend_cache.max_size,
        }

    # -- /stats ----------------------------------------------------------

    @app.get("/stats")
    async def stats() -> dict[str, Any]:
        return {
            "uptime_seconds": time.monotonic() - started_at,
            "request_count": request_count,
            "total_run_seconds": total_run_seconds,
            "mean_run_seconds": (total_run_seconds / request_count if request_count > 0 else None),
            "cached_backends": len(backend_cache.loaded_keys()),
            "max_loaded_models": backend_cache.max_size,
        }

    # -- /run ------------------------------------------------------------

    @app.post("/run")
    async def run_endpoint(req: RunRequest) -> dict[str, Any]:
        # NOTE(review): ``req.seed`` is accepted by the schema but never
        # applied here; confirm where a seed override should be threaded.
        result, elapsed = _timed_run(req.spec, spec_path=req.spec_path)
        bundled = _to_report_dict(result)
        bundled["request_seconds"] = elapsed
        return bundled

    # -- /score ----------------------------------------------------------

    @app.post("/score")
    async def score_endpoint(req: ScoreRequest) -> dict[str, Any]:
        # /score is a thin filter over /run: same code path, but only the
        # named probes execute. Filtering happens at the spec level (on a
        # copy of spec.suite) so the runner's null-calibration logic still
        # sees the right downstream kinds. Frozen pydantic models require
        # .model_copy rather than in-place mutation.
        spec = req.spec
        if req.probe_names is not None:
            wanted = set(req.probe_names)
            kept = [p for p in spec.suite if p.get("name") in wanted]
            spec = spec.model_copy(update={"suite": kept})
        result, elapsed = _timed_run(spec, spec_path="<serve:/score>")
        # /score returns just the probe entries — no folded SwayScore.
        return {
            "probes": _to_report_dict(result)["probes"],
            "request_seconds": elapsed,
        }

    # Expose the cache so tests can introspect what is loaded.
    app.state.sway_cache = backend_cache

    return app
285
286
287 def parse_host_port(host: str, port: int) -> tuple[str, int]:
288 """Validate host + port at CLI layer; raise via typer for clean exit.
289
290 Centralizes the "0.0.0.0 without auth = refuse" rule so the CLI
291 and any test harness share it.
292 """
293 if not (1 <= port <= 65535):
294 raise typer.BadParameter(f"port must be 1..65535, got {port}")
295 return host, port