| 1 | """FastAPI app for the ``sway serve`` daemon. |
| 2 | |
| 3 | Three concerns sit in this module: |
| 4 | |
| 5 | 1. ``create_app(...)`` — factory that returns a FastAPI instance with |
| 6 | the four documented endpoints wired up. Used by the CLI's uvicorn |
| 7 | launcher and unit tests' :class:`fastapi.testclient.TestClient`. |
| 8 | 2. Pydantic request/response models for ``/run`` and ``/score``. |
| 9 | 3. The endpoint handlers themselves, which delegate to |
| 10 | :func:`dlm_sway.suite.runner.run` against backends fetched from |
| 11 | :class:`dlm_sway.serve.cache.BackendCache`. |
| 12 | |
| 13 | Auth is intentionally minimal in v1: localhost-only by default; binding |
| 14 | to ``0.0.0.0`` is rejected at the CLI layer unless ``--api-key`` is |
set, and an API-key middleware validates the bearer token on every
request (except ``/health``, which stays open for liveness probes) when
an API key was provided. Full OAuth is out of scope.
| 17 | """ |
| 18 | |
| 19 | from __future__ import annotations |
| 20 | |
| 21 | import logging |
| 22 | import time |
| 23 | from collections.abc import Awaitable, Callable |
| 24 | from typing import Any |
| 25 | |
| 26 | import typer |
| 27 | from pydantic import BaseModel |
| 28 | |
| 29 | from dlm_sway import __version__ |
| 30 | from dlm_sway.core.errors import SwayError |
| 31 | from dlm_sway.serve.cache import BackendCache |
| 32 | from dlm_sway.suite.spec import SwaySpec |
| 33 | |
# Module logger; used below to report best-effort backend cleanup failures.
_LOG = logging.getLogger(__name__)
| 35 | |
| 36 | |
| 37 | # --------------------------------------------------------------------------- |
| 38 | # Request / response models |
| 39 | # --------------------------------------------------------------------------- |
| 40 | |
| 41 | |
class RunRequest(BaseModel):
    """Body of ``POST /run`` — full spec plus optional overrides.

    Fields
    ------
    spec:
        The suite spec to execute, sent inline in the request body.
    spec_path:
        Label forwarded to the suite runner as ``spec_path``; defaults to
        ``"<serve>"`` because the spec did not come from a file.
    seed:
        Optional seed override. NOTE(review): the ``/run`` handler in this
        module never reads this field, so it currently has no effect —
        confirm the intended wiring.
    """

    spec: SwaySpec
    # Passed through to the runner in place of an on-disk spec path.
    spec_path: str = "<serve>"
    # NOTE(review): accepted but not consumed by ``/run`` — TODO wire up or drop.
    seed: int | None = None
| 49 | |
class ScoreRequest(BaseModel):
    """Body of ``POST /score`` — full spec plus a subset filter.

    When ``probe_names`` is set, only suite entries whose ``"name"`` is in
    that list execute; names that match no entry are silently ignored.
    When it is ``None``, the whole suite runs. Either way the response
    shape is the per-probe entry list (no folded SwayScore).
    """

    spec: SwaySpec
    # Optional allow-list of probe names; ``None`` means "run every probe".
    probe_names: list[str] | None = None
| 59 | |
| 60 | |
| 61 | # --------------------------------------------------------------------------- |
| 62 | # App factory |
| 63 | # --------------------------------------------------------------------------- |
| 64 | |
| 65 | |
def create_app(
    *,
    cache: BackendCache | None = None,
    api_key: str | None = None,
) -> Any:
    """Build a FastAPI app wired to a :class:`BackendCache`.

    Parameters
    ----------
    cache:
        Optional pre-built cache. When omitted, a new cache with
        ``max_size=2`` is created. Tests inject a dummy-backed cache
        via this parameter; the CLI passes one with the user's
        ``--max-loaded-models``.
    api_key:
        Optional bearer token. When set, every request except those to
        ``/health`` must carry ``Authorization: Bearer <key>`` or it gets
        a 401. When ``None`` (the default), no auth — only safe for
        localhost binds, which the CLI enforces.

    Returns
    -------
    A :class:`fastapi.FastAPI` instance ready to hand to uvicorn or
    a TestClient.

    Raises
    ------
    SwayError
        When FastAPI (the ``[serve]`` extra) is not importable.

    Notes
    -----
    Suite runs are synchronous and potentially long. ``/run`` and
    ``/score`` hand them to a dedicated single-thread executor instead of
    running them on the event loop, so ``/health`` (used by liveness
    probes) and ``/stats`` keep answering while a run is in flight. The
    single worker keeps runs strictly one-at-a-time and keeps all backend
    loading and inference on one thread.
    """
    try:
        from contextlib import asynccontextmanager

        from fastapi import FastAPI, HTTPException, Request, status
        from fastapi.responses import JSONResponse
    except ImportError as exc:
        raise SwayError(
            "sway serve requires the [serve] extra: pip install 'dlm-sway[serve]'"
        ) from exc

    import asyncio
    import functools
    import hmac
    import json
    from concurrent.futures import ThreadPoolExecutor

    # Non-Optional alias shared by every closure below (lifespan + handlers).
    backend_cache: BackendCache = cache if cache is not None else BackendCache(max_size=2)
    started_at = time.monotonic()
    request_count = 0  # successful /run + /score calls
    total_run_seconds = 0.0  # wall time spent inside those runs

    # Exactly one worker: runs stay one-at-a-time (as when they executed
    # inline on the event loop) and all backend work happens on one thread.
    # The worker thread is only spawned on first submit.
    run_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sway-run")

    async def _off_loop(fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Any:
        """Await blocking ``fn(*args, **kwargs)`` executed on ``run_executor``."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(run_executor, functools.partial(fn, *args, **kwargs))

    @asynccontextmanager
    async def _lifespan(app: Any) -> Any:
        # Lifespan handler (the on_event API is deprecated in FastAPI
        # 0.110+): free every cached backend on shutdown. Called directly
        # rather than via the executor so cleanup still runs if shutdown
        # is being cancelled.
        del app
        try:
            yield
        finally:
            backend_cache.evict_all()

    app = FastAPI(
        title="sway",
        version=__version__,
        description=(
            "Warm-backend HTTP API for sway. POST /run accepts a spec; "
            "the daemon keeps backends loaded between calls."
        ),
        lifespan=_lifespan,
    )

    # -- auth middleware --------------------------------------------------

    if api_key is not None:
        # Compare UTF-8 bytes: equality of the encodings is equivalent to
        # str equality, and bytes sidestep compare_digest's ASCII-only
        # restriction on str arguments.
        expected_header = f"Bearer {api_key}".encode("utf-8", "surrogatepass")

        @app.middleware("http")
        async def _auth_middleware(
            request: Request,
            call_next: Callable[[Request], Awaitable[Any]],
        ) -> Any:
            # Health is unauthenticated so an external probe (e.g. a
            # k8s liveness check) can hit it without distributing the
            # token. Every other endpoint requires the bearer.
            if request.url.path == "/health":
                return await call_next(request)
            got = request.headers.get("Authorization", "").encode("utf-8", "surrogatepass")
            # Constant-time check so response timing can't reveal how much
            # of the token a guess got right.
            if not hmac.compare_digest(got, expected_header):
                return JSONResponse(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    content={"detail": "missing or invalid Authorization header"},
                )
            return await call_next(request)

    # -- helpers ---------------------------------------------------------

    def _run_with_warm_backend(spec: SwaySpec, *, spec_path: str) -> Any:
        """Blocking: pick the backend, run the suite, return the raw result.

        Same logic as cli/_execute_spec but takes an in-memory spec and
        uses the cache for backend identity. Always invoked through
        ``_off_loop`` so it runs on the run executor thread.
        """
        from dlm_sway.suite.runner import run as run_suite

        if not spec.defaults.differential:
            # Two-separate (base + ft as distinct backends) isn't yet
            # cached at this layer — build fresh per request and close it
            # afterwards. Cleanest path until a multi-key cache entry
            # shape lands.
            from dlm_sway.backends import build_two_separate

            backend = build_two_separate(spec.models)
            try:
                return run_suite(spec, backend, spec_path=spec_path)
            finally:
                close = getattr(backend, "close", None)
                if callable(close):
                    try:
                        close()
                    except Exception as exc:  # noqa: BLE001 - best-effort cleanup
                        _LOG.warning("two-separate backend close raised: %s", exc)

        cached = backend_cache.get_or_load(spec.models.ft)
        return run_suite(spec, cached.backend, spec_path=spec_path)

    def _bundle(result: Any) -> dict[str, Any]:
        """Serialize ``result`` exactly as ``sway run --json-out`` would.

        Reusing the on-disk-JSON serializer keeps one JSON contract for
        both the CLI and the daemon.
        """
        from dlm_sway.suite import report
        from dlm_sway.suite.score import compute as compute_score

        score_obj = compute_score(result)
        bundled: dict[str, Any] = json.loads(report.to_json(result, score_obj))
        return bundled

    async def _timed_run(spec: SwaySpec, *, spec_path: str) -> tuple[dict[str, Any], float]:
        """Run ``spec`` off-loop, update the counters, return (bundle, seconds).

        ``SwayError`` becomes HTTP 400; other exceptions propagate (500).
        Only successful runs are counted in ``/stats``.
        """
        nonlocal request_count, total_run_seconds
        started = time.monotonic()
        try:
            result = await _off_loop(_run_with_warm_backend, spec, spec_path=spec_path)
        except SwayError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        elapsed = time.monotonic() - started
        # Counters are only touched here, on the event loop thread.
        request_count += 1
        total_run_seconds += elapsed
        return _bundle(result), elapsed

    # -- /health ---------------------------------------------------------

    @app.get("/health")
    async def health() -> dict[str, Any]:
        # NOTE(review): this reads the cache on the event loop while a run
        # may be loading/evicting on the run thread — assumes BackendCache's
        # read accessors tolerate that; TODO confirm.
        return {
            "status": "ok",
            "sway_version": __version__,
            "uptime_seconds": time.monotonic() - started_at,
            "loaded_models": [
                {
                    "kind": s.kind,
                    "base": s.base,
                    "adapter": str(s.adapter) if s.adapter is not None else None,
                    "dtype": s.dtype,
                    "device": s.device,
                }
                for s in backend_cache.loaded_specs()
            ],
            "max_loaded_models": backend_cache.max_size,
        }

    # -- /stats ----------------------------------------------------------

    @app.get("/stats")
    async def stats() -> dict[str, Any]:
        return {
            "uptime_seconds": time.monotonic() - started_at,
            "request_count": request_count,
            "total_run_seconds": total_run_seconds,
            "mean_run_seconds": (total_run_seconds / request_count if request_count > 0 else None),
            "cached_backends": len(backend_cache.loaded_keys()),
            "max_loaded_models": backend_cache.max_size,
        }

    # -- /run ------------------------------------------------------------

    @app.post("/run")
    async def run_endpoint(req: RunRequest) -> dict[str, Any]:
        # NOTE(review): ``req.seed`` is accepted by the model but not applied
        # here — TODO wire it into the spec/runner or drop the field.
        bundled, elapsed = await _timed_run(req.spec, spec_path=req.spec_path)
        bundled["request_seconds"] = elapsed
        return bundled

    # -- /score ----------------------------------------------------------

    @app.post("/score")
    async def score_endpoint(req: ScoreRequest) -> dict[str, Any]:
        # /score is a thin filter over /run: same code path, but only
        # the named probes execute. Filtering happens at the spec
        # level (we replace spec.suite on a copy) so the runner's
        # null-calibration logic still sees the right downstream
        # kinds. Frozen pydantic models require .model_copy.
        if req.probe_names is not None:
            wanted = set(req.probe_names)
            filtered_suite = [p for p in req.spec.suite if p.get("name") in wanted]
            spec = req.spec.model_copy(update={"suite": filtered_suite})
        else:
            spec = req.spec
        bundled, elapsed = await _timed_run(spec, spec_path="<serve:/score>")
        # /score returns just the probe entries — no folded SwayScore.
        return {
            "probes": bundled["probes"],
            "request_seconds": elapsed,
        }

    # Stash the cache on the app so tests can introspect it.
    app.state.sway_cache = backend_cache

    return app
| 285 | |
| 286 | |
| 287 | def parse_host_port(host: str, port: int) -> tuple[str, int]: |
| 288 | """Validate host + port at CLI layer; raise via typer for clean exit. |
| 289 | |
| 290 | Centralizes the "0.0.0.0 without auth = refuse" rule so the CLI |
| 291 | and any test harness share it. |
| 292 | """ |
| 293 | if not (1 <= port <= 65535): |
| 294 | raise typer.BadParameter(f"port must be 1..65535, got {port}") |
| 295 | return host, port |