1 """LRU cache for warm differential backends.
2
3 The point of the daemon is keeping backends loaded across requests.
4 This module owns the cache: keyed by an immutable identity tuple over
5 the ModelSpec fields that determine backend identity, capped at a
6 configurable size, evicts least-recently-used on overflow with proper
7 ``close()`` so weights actually get freed.
8
9 The cache is **process-local** — there's no on-disk component. Restart
10 the daemon and the cache resets cold. That's intentional: warm-backend
11 caching is a memory tradeoff, and persisting weights to disk would
12 duplicate what HuggingFace's own cache already does at the file level.
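
A minimal wiring sketch (hypothetical names; the real daemon builds
specs from its request parsing)::

    _CACHE = BackendCache(max_size=2)

    def warm_backend(spec: ModelSpec) -> DifferentialBackend:
        # Warm after the first call for the same identity key.
        return _CACHE.get_or_load(spec).backend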
13 """
14
15 from __future__ import annotations
16
17 import logging
18 import threading
19 from collections import OrderedDict
20 from collections.abc import Hashable
21 from dataclasses import dataclass
22 from pathlib import Path
23
24 from dlm_sway.core.errors import SwayError
25 from dlm_sway.core.model import ModelSpec
26 from dlm_sway.core.scoring import DifferentialBackend
27
28 _LOG = logging.getLogger(__name__)
29
30
31 def cache_key_for(spec: ModelSpec) -> tuple[Hashable, ...]:
32 """Identity tuple for a ModelSpec.
33
34 Two ModelSpecs that differ only in fields that don't affect the
35 loaded backend (e.g. ``trust_remote_code`` on the same already-
36 cached model) hash to the same key. We're conservative — any field
37 that touches model identity (``base``, ``adapter``, ``dtype``,
38 ``device``, ``kind``) goes into the key. Path normalization
39 happens upstream in ModelSpec's validator.
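
    A sketch of the key this produces (assuming ``ModelSpec`` accepts
    these fields as keyword arguments; values are illustrative)::

        spec = ModelSpec(kind="hf", base="gpt2", adapter=None,
                         dtype="float16", device="cpu")
        cache_key_for(spec)  # -> ("hf", "gpt2", None, "float16", "cpu")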
40 """
41 return (
42 spec.kind,
43 spec.base,
44 str(spec.adapter) if spec.adapter is not None else None,
45 spec.dtype,
46 spec.device,
47 )
48
49
50 @dataclass(frozen=True, slots=True)
51 class CachedBackend:
52 """One entry in the cache. Frozen — fields don't mutate after load.
53
54 ``key`` is the identity tuple; ``backend`` is the live object;
55 ``model_spec`` is kept for introspection (``GET /health`` lists
56 the loaded models so users can see what's warm).
57 """
58
59 key: tuple[Hashable, ...]
60 backend: DifferentialBackend
61 model_spec: ModelSpec
62 load_seconds: float
63
64
65 class BackendCache:
66 """LRU cache of differential backends.
67
68 Thread-safe via a single internal lock. The cache contract:
69
70 1. ``get_or_load(spec)`` returns a backend; on miss, builds it
71 (paying the load cost) and admits to the cache.
72 2. On overflow (``len > max_size``), evict LRU. Eviction calls
73 ``backend.close()`` if the backend implements it; otherwise
74 drops the reference and lets GC handle it.
75 3. ``get_or_load`` is single-flight per key — concurrent requests
76 for the same model wait on the loader thread instead of
77 building the backend twice.
78 """
79
80 def __init__(self, max_size: int = 2) -> None:
81 if max_size < 1:
82 raise ValueError(f"max_size must be >= 1; got {max_size}")
83 self._max = int(max_size)
84 self._entries: OrderedDict[tuple[Hashable, ...], CachedBackend] = OrderedDict()
85 self._lock = threading.RLock()
86 # Per-key load locks so concurrent requests for the same model
87 # serialize at the loader instead of building twice.
88 self._key_locks: dict[tuple[Hashable, ...], threading.Lock] = {}
89
90 @property
91 def max_size(self) -> int:
92 return self._max
93
94 def loaded_keys(self) -> list[tuple[Hashable, ...]]:
95 """Snapshot of currently-cached keys, MRU last. Used by /health."""
96 with self._lock:
97 return list(self._entries.keys())
98
99 def loaded_specs(self) -> list[ModelSpec]:
100 """Snapshot of currently-cached model specs, MRU last."""
101 with self._lock:
102 return [entry.model_spec for entry in self._entries.values()]
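
    # A /health handler might assemble its payload from these snapshots,
    # e.g. (sketch; the real handler lives in the daemon module):
    #
    #     {"loaded": [str(spec.base) for spec in cache.loaded_specs()]}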

    def get_or_load(self, spec: ModelSpec, *, adapter_path: Path | None = None) -> CachedBackend:
        """Return a cached backend for ``spec`` or build + admit one.

        ``adapter_path`` overrides ``spec.adapter`` for the build call —
        it mirrors the upstream :func:`dlm_sway.backends.build` contract
        so callers handing in a separately-resolved adapter (e.g. via
        the dlm bridge) don't have to construct a copy of the spec.
        The cache key uses ``spec.adapter``, NOT the override; if you
        want a different adapter to cache distinctly, pass a spec that
        encodes it.
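
        For example, both calls below share one cache entry because the
        key ignores the override (sketch)::

            cache.get_or_load(spec)
            cache.get_or_load(spec, adapter_path=Path("/tmp/adapter"))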
114 """
115 key = cache_key_for(spec)
116
117 # Fast path — spec already cached.
118 with self._lock:
119 entry = self._entries.get(key)
120 if entry is not None:
121 # Touch LRU position.
122 self._entries.move_to_end(key)
123 return entry
124 key_lock = self._key_locks.setdefault(key, threading.Lock())
125
126 # Slow path — single-flight load.
127 with key_lock:
128 with self._lock:
129 # Recheck after acquiring the load lock — another thread
130 # may have completed the load while we waited.
131 entry = self._entries.get(key)
132 if entry is not None:
133 self._entries.move_to_end(key)
134 return entry
135
136 entry = _build_entry(spec, key=key, adapter_path=adapter_path)
137
138 with self._lock:
                # Evict to fit before admitting so the cache never
                # exceeds max_size, even transiently.
                while len(self._entries) >= self._max:
                    self._evict_lru_locked()
                self._entries[key] = entry
                self._entries.move_to_end(key)
                return entry

    def evict_all(self) -> None:
        """Close every backend. Called on daemon shutdown."""
        with self._lock:
            keys = list(self._entries.keys())
            for k in keys:
                self._evict_locked(k)

    # -- internals -----------------------------------------------------

    def _evict_lru_locked(self) -> None:
        # Caller holds self._lock. ``OrderedDict.__iter__`` yields
        # insertion order; LRU is the first key.
        if not self._entries:
            return
        lru_key = next(iter(self._entries))
        self._evict_locked(lru_key)

    def _evict_locked(self, key: tuple[Hashable, ...]) -> None:
        # Caller holds self._lock.
        entry = self._entries.pop(key, None)
        if entry is None:
            return
        # Backends carry a ``close()`` when they own GPU memory or
        # network connections (HF, MLX, API). Dummy doesn't —
        # so don't require it. Failure during close is logged and
        # swallowed: a daemon stays up even if one backend's close()
        # raises.
        close = getattr(entry.backend, "close", None)
        if callable(close):
            try:
                close()
            except Exception as exc:  # noqa: BLE001
                _LOG.warning("backend close raised on eviction: %s", exc)
        self._key_locks.pop(key, None)


def _build_entry(
    spec: ModelSpec,
    *,
    key: tuple[Hashable, ...],
    adapter_path: Path | None,
) -> CachedBackend:
    """Materialize a backend from a spec, timing the load."""
    import time

    from dlm_sway.backends import build as build_backend

    started = time.monotonic()
    try:
        backend = build_backend(spec, adapter_path=adapter_path)
    except SwayError:
        raise
    except Exception as exc:  # noqa: BLE001 — surface load failures as SwayError
        raise SwayError(
            f"backend load failed for kind={spec.kind} base={spec.base!r}: "
            f"{type(exc).__name__}: {exc}"
        ) from exc
    elapsed = time.monotonic() - started
    return CachedBackend(key=key, backend=backend, model_spec=spec, load_seconds=elapsed)