@@ -0,0 +1,288 @@ |
| 1 | +"""FastAPI app for the ``sway serve`` daemon. |
| 2 | + |
| 3 | +Three concerns sit in this module: |
| 4 | + |
| 5 | +1. ``create_app(...)`` — factory that returns a FastAPI instance with |
| 6 | + the four documented endpoints wired up. Used by the CLI's uvicorn |
| 7 | + launcher and unit tests' :class:`fastapi.testclient.TestClient`. |
| 8 | +2. Pydantic request/response models for ``/run`` and ``/score``. |
| 9 | +3. The endpoint handlers themselves, which delegate to |
| 10 | + :func:`dlm_sway.suite.runner.run` against backends fetched from |
| 11 | + :class:`dlm_sway.serve.cache.BackendCache`. |
| 12 | + |
| 13 | +Auth is intentionally minimal in v1: localhost-only by default; binding |
| 14 | +to ``0.0.0.0`` is rejected at the CLI layer unless ``--api-key`` is |
| 15 | +set, and an API-key middleware validates the bearer token on every |
| 16 | +request when an API key was provided. Full OAuth is out of scope. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import logging |
| 22 | +import time |
| 23 | +from collections.abc import Awaitable, Callable |
| 24 | +from typing import Any |
| 25 | + |
| 26 | +import typer |
| 27 | +from pydantic import BaseModel |
| 28 | + |
| 29 | +from dlm_sway import __version__ |
| 30 | +from dlm_sway.core.errors import SwayError |
| 31 | +from dlm_sway.serve.cache import BackendCache |
| 32 | +from dlm_sway.suite.spec import SwaySpec |
| 33 | + |
| 34 | +_LOG = logging.getLogger(__name__) |
| 35 | + |
| 36 | + |
| 37 | +# --------------------------------------------------------------------------- |
| 38 | +# Request / response models |
| 39 | +# --------------------------------------------------------------------------- |
| 40 | + |
| 41 | + |
class RunRequest(BaseModel):
    """JSON payload accepted by ``POST /run``.

    The complete suite spec travels inline in the request body, together
    with two optional extras described on the fields below.
    """

    # The suite specification to execute.
    spec: SwaySpec
    # Forwarded to the suite runner as its ``spec_path`` argument; the
    # default is the ``"<serve>"`` placeholder.
    spec_path: str = "<serve>"
    # Optional seed override.
    # NOTE(review): the ``/run`` handler in this module never reads this
    # field — confirm whether it is meant to be applied to the spec.
    seed: int | None = None
| 48 | + |
| 49 | + |
class ScoreRequest(BaseModel):
    """JSON payload accepted by ``POST /score``.

    Carries the complete suite spec plus an optional probe filter. The
    response contains only the per-probe entries; the folded SwayScore
    is not included.
    """

    # The suite specification to execute.
    spec: SwaySpec
    # Names of the probes to keep. Suite entries whose ``"name"`` is not
    # in this list are dropped before the run; ``None`` (the default)
    # leaves the suite untouched.
    probe_names: list[str] | None = None
| 59 | + |
| 60 | + |
| 61 | +# --------------------------------------------------------------------------- |
| 62 | +# App factory |
| 63 | +# --------------------------------------------------------------------------- |
| 64 | + |
| 65 | + |
def create_app(
    *,
    cache: BackendCache | None = None,
    api_key: str | None = None,
) -> Any:
    """Build a FastAPI app wired to a :class:`BackendCache`.

    The returned app exposes ``GET /health``, ``GET /stats``,
    ``POST /run`` and ``POST /score``, and evicts every cached backend
    when the app shuts down.

    Parameters
    ----------
    cache:
        Optional pre-built cache. When omitted, a new cache with
        ``max_size=2`` is created. Tests inject a dummy-backed cache
        via this parameter; the CLI passes one with the user's
        ``--max-loaded-models``.
    api_key:
        Optional bearer token. When set, every request except
        ``/health`` must carry ``Authorization: Bearer <key>`` or it
        gets a 401. When ``None`` (the default), no auth — only safe
        for localhost binds, which the CLI enforces.

    Returns
    -------
    A :class:`fastapi.FastAPI` instance ready to hand to uvicorn or
    a TestClient. The cache is also reachable as
    ``app.state.sway_cache``.

    Raises
    ------
    SwayError
        If FastAPI cannot be imported (the ``[serve]`` extra is not
        installed).
    """
    try:
        from fastapi import FastAPI, HTTPException, Request, status
        from fastapi.responses import JSONResponse
    except ImportError as exc:
        raise SwayError(
            "sway serve requires the [serve] extra: pip install 'dlm-sway[serve]'"
        ) from exc

    if cache is None:
        cache = BackendCache(max_size=2)
    started_at = time.monotonic()
    # Counters reported by /stats; only successful /run and /score calls
    # are counted (see ``_timed_run``).
    request_count = 0
    total_run_seconds = 0.0

    app = FastAPI(
        title="sway",
        version=__version__,
        description=(
            "Warm-backend HTTP API for sway. POST /run accepts a spec; "
            "the daemon keeps backends loaded between calls."
        ),
    )

    # -- auth middleware --------------------------------------------------

    if api_key is not None:
        import hmac

        # Encode once up front. ``hmac.compare_digest`` only accepts
        # ASCII ``str``, so both sides are compared as bytes; the
        # encoding is injective, so equality is unchanged from a plain
        # string comparison.
        expected_header = f"Bearer {api_key}".encode("utf-8", "surrogatepass")

        @app.middleware("http")
        async def _auth_middleware(
            request: Request,
            call_next: Callable[[Request], Awaitable[Any]],
        ) -> Any:
            # Health is unauthenticated so an external probe (e.g. a
            # k8s liveness check) can hit it without distributing the
            # token. Every other endpoint requires the bearer.
            if request.url.path == "/health":
                return await call_next(request)
            got = request.headers.get("Authorization", "").encode(
                "utf-8", "surrogatepass"
            )
            # Constant-time comparison: a plain ``!=`` short-circuits on
            # the first differing character, which leaks timing
            # information about the token to an unauthenticated caller.
            if not hmac.compare_digest(got, expected_header):
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "missing or invalid Authorization header"},
                )
            return await call_next(request)

    # -- helpers ---------------------------------------------------------

    def _run_with_warm_backend(spec: SwaySpec, *, spec_path: str) -> Any:
        """Common path for /run and /score: pick the right backend
        from cache, call the runner, return the runner's result."""
        from dlm_sway.suite.runner import run as run_suite

        # Same logic as cli/_execute_spec but takes an in-memory spec
        # instead of a path, and uses the cache for backend identity.
        if spec.defaults.differential:
            cached = cache.get_or_load(spec.models.ft)
            return run_suite(spec, cached.backend, spec_path=spec_path)

        # Two-separate (base + ft as distinct backends) isn't yet
        # cached at this layer — falls through to a fresh build
        # per request. Cleanest path until a multi-key cache
        # entry shape lands.
        from dlm_sway.backends import build_two_separate

        backend = build_two_separate(spec.models)
        try:
            return run_suite(spec, backend, spec_path=spec_path)
        finally:
            # Best-effort cleanup of the throwaway backend: a failing
            # close() is logged and must not mask the run's outcome.
            close = getattr(backend, "close", None)
            if callable(close):
                try:
                    close()
                except Exception as exc:  # noqa: BLE001
                    _LOG.warning("two-separate backend close raised: %s", exc)

    def _timed_run(spec: SwaySpec, *, spec_path: str) -> tuple[Any, float]:
        """Run the suite, update the /stats counters, return (result, seconds).

        A :class:`SwayError` from the run becomes an HTTP 400; in that
        case the counters are left untouched.
        """
        nonlocal request_count, total_run_seconds
        # NOTE(review): the suite executes synchronously inside the
        # calling async handler, so other requests served by the same
        # event loop (including /health) appear to wait until it
        # returns — confirm this is acceptable for liveness probes.
        started = time.monotonic()
        try:
            result = _run_with_warm_backend(spec, spec_path=spec_path)
        except SwayError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        elapsed = time.monotonic() - started
        request_count += 1
        total_run_seconds += elapsed
        return result, elapsed

    def _serialize_result(result: Any) -> dict[str, Any]:
        """Turn a suite result into the same JSON shape the report module emits."""
        # Reuse the on-disk-JSON serializer so the daemon's response
        # shape matches what ``sway run --json-out`` would write — one
        # less JSON contract to maintain.
        import json as _json

        from dlm_sway.suite import report
        from dlm_sway.suite.score import compute as compute_score

        score_obj = compute_score(result)
        bundled: dict[str, Any] = _json.loads(report.to_json(result, score_obj))
        return bundled

    # -- /health ---------------------------------------------------------

    @app.get("/health")
    async def health() -> dict[str, Any]:
        """Report liveness, version, uptime and the currently cached model specs."""
        return {
            "status": "ok",
            "sway_version": __version__,
            "uptime_seconds": time.monotonic() - started_at,
            "loaded_models": [
                {
                    "kind": s.kind,
                    "base": s.base,
                    "adapter": str(s.adapter) if s.adapter is not None else None,
                    "dtype": s.dtype,
                    "device": s.device,
                }
                for s in cache.loaded_specs()
            ],
            "max_loaded_models": cache.max_size,
        }

    # -- /stats ----------------------------------------------------------

    @app.get("/stats")
    async def stats() -> dict[str, Any]:
        """Report request counters and cache occupancy."""
        return {
            "uptime_seconds": time.monotonic() - started_at,
            "request_count": request_count,
            "total_run_seconds": total_run_seconds,
            # ``None`` rather than a division by zero before the first run.
            "mean_run_seconds": (
                total_run_seconds / request_count if request_count > 0 else None
            ),
            "cached_backends": len(cache.loaded_keys()),
            "max_loaded_models": cache.max_size,
        }

    # -- /run ------------------------------------------------------------

    @app.post("/run")
    async def run_endpoint(req: RunRequest) -> dict[str, Any]:
        """Run the full suite and return the report JSON plus ``request_seconds``."""
        # NOTE(review): ``req.seed`` is accepted by the request model but
        # is not applied anywhere in this handler — confirm whether it
        # should override the spec before the run.
        result, elapsed = _timed_run(req.spec, spec_path=req.spec_path)
        bundled = _serialize_result(result)
        bundled["request_seconds"] = elapsed
        return bundled

    # -- /score ----------------------------------------------------------

    @app.post("/score")
    async def score_endpoint(req: ScoreRequest) -> dict[str, Any]:
        """Run the (optionally filtered) suite and return only the probe entries."""
        # /score is a thin filter over /run: same code path, but only
        # the named probes execute. Filtering happens at the spec
        # level (we build a copy with a narrowed suite) so the runner's
        # null-calibration logic still sees the right downstream
        # kinds. Frozen pydantic models require .model_copy.
        if req.probe_names is not None:
            wanted = set(req.probe_names)
            filtered_suite = [p for p in req.spec.suite if p.get("name") in wanted]
            spec = req.spec.model_copy(update={"suite": filtered_suite})
        else:
            spec = req.spec
        result, elapsed = _timed_run(spec, spec_path="<serve:/score>")
        # /score returns just the probe entries — no folded SwayScore.
        # Reuse the same per-probe serializer the to_json path uses
        # for consistency.
        bundled = _serialize_result(result)
        return {
            "probes": bundled["probes"],
            "request_seconds": elapsed,
        }

    # -- shutdown --------------------------------------------------------

    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in
    # favour of lifespan handlers — consider migrating once the minimum
    # supported FastAPI version is confirmed.
    @app.on_event("shutdown")
    async def _on_shutdown() -> None:
        cache.evict_all()

    # Stash the cache on the app so tests can introspect it. The request
    # counters live in this closure and are only visible through /stats.
    app.state.sway_cache = cache

    return app
| 278 | + |
| 279 | + |
def parse_host_port(host: str, port: int) -> tuple[str, int]:
    """Check that ``port`` is a valid TCP port and hand the pair back.

    Only the port range (1 through 65535, inclusive) is validated here;
    ``host`` is returned exactly as given. An out-of-range port raises
    :class:`typer.BadParameter` so the CLI exits with a clean message.

    NOTE(review): the "0.0.0.0 without auth = refuse" rule is not
    checked in this function (no API key reaches it) — confirm where
    the CLI enforces it.
    """
    if 1 <= port <= 65535:
        return host, port
    raise typer.BadParameter(f"port must be 1..65535, got {port}")