Python · 13242 bytes Raw Blame History
1 """MLX backend for Apple Silicon (darwin-arm64).
2
3 Partial implementation covering the common case: a PEFT adapter that's
4 already been converted to MLX's ``.npz`` format. Unlike the HF backend,
5 MLX has no runtime ``disable_adapter`` context — adapters get fused into
6 the linear layers at load time — so this backend keeps **both** a base
7 model and an adapted model in memory. Fine for the small (<3B) models
8 MLX is typically used with on Apple Silicon; document the cost clearly.
9
10 If users point this backend at raw PEFT safetensors, ``mlx_lm.load``
11 will refuse them with its own error. A future milestone can wire a
12 PEFT-→-MLX converter; for now the contract is "bring your own .npz".
13 """
14
15 from __future__ import annotations
16
17 from collections.abc import Iterator, Sequence
18 from contextlib import contextmanager
19 from dataclasses import dataclass
20 from pathlib import Path
21 from typing import TYPE_CHECKING, Any
22
23 import numpy as np
24
25 from dlm_sway.backends._instrumentation import BackendInstrumentation
26 from dlm_sway.core.errors import BackendNotAvailableError, ProbeError
27 from dlm_sway.core.model import ModelSpec
28 from dlm_sway.core.scoring import RollingLogprob, TokenDist
29
30 if TYPE_CHECKING:
31 pass
32
33
def _require_mlx() -> tuple[Any, Any]:
    """Import and return ``(mlx.core, mlx_lm)``.

    Imported lazily so the rest of the package stays usable on
    platforms without MLX.

    Raises:
        BackendNotAvailableError: when mlx / mlx-lm are not installed.
    """
    try:
        import mlx.core
        import mlx_lm
    except ImportError as exc:
        raise BackendNotAvailableError(
            "mlx",
            extra="mlx",
            hint="MLX backend needs mlx + mlx-lm on darwin-arm64.",
        ) from exc
    return mlx.core, mlx_lm
45
46
@dataclass(slots=True)
class _MLXView:
    """One side (base or ft) of the MLX backend.

    Both sides carry the same tokenizer (MLX stores it alongside the
    converted model files, so sharing avoids double-loading). Scoring
    methods route through ``_inst`` for the shared cache + trace +
    stats (Sprint 07).
    """

    # View identity ("base" or "ft"); also the per-view cache namespace
    # passed as the second argument to ``_inst.cached``.
    id: str
    _model: Any
    _tokenizer: Any
    _inst: BackendInstrumentation

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 1.0,
        seed: int = 0,
    ) -> str:
        """Generate a completion for ``prompt`` via ``mlx_lm.generate``.

        ``seed`` is accepted for interface parity but ignored here;
        ``temp``/``top_p`` are only forwarded when ``temperature > 0``
        (mlx_lm's default is greedy decoding otherwise).
        """
        del seed  # mlx_lm.generate seeds via its own global state
        _, mlx_lm = _require_mlx()
        kwargs: dict[str, Any] = {"max_tokens": max_new_tokens, "verbose": False}
        if temperature > 0.0:
            kwargs["temp"] = temperature
            kwargs["top_p"] = top_p
        out = mlx_lm.generate(self._model, self._tokenizer, prompt=prompt, **kwargs)
        return str(out)

    def close(self) -> None:
        # Nothing to release: the owning backend holds the model and
        # instrumentation; views are cheap throwaway wrappers.
        return None

    # -- ScoringBackend ------------------------------------------------

    def _forward_logits(self, prompt: str) -> np.ndarray:
        """Run the model once and return ``(seq_len, vocab)`` logits."""
        mx, _ = _require_mlx()
        input_ids = self._tokenizer.encode(prompt)
        tokens = mx.array(input_ids)[None, :]  # (1, T)
        out = self._model(tokens)
        # mlx_lm models often emit bf16/fp16 arrays whose buffer
        # protocol numpy doesn't understand directly. Cast to fp32
        # in MLX before handing to numpy — keeps scoring math in
        # fp32 anyway (downstream _log_softmax up-casts to fp64).
        return np.asarray(out[0].astype(mx.float32))

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Total log-probability of ``completion`` given ``prompt`` (cached).

        The cache key joins prompt and completion with NUL so the pair
        ("ab", "c") never collides with ("a", "bc").
        """
        key_prompt = f"{prompt}\x00{completion}"
        return self._inst.cached(
            "logprob_of",
            self.id,
            key_prompt,
            0,
            lambda: self._compute_logprob_of(prompt, completion),
        )

    def _compute_logprob_of(self, prompt: str, completion: str) -> float:
        """Score the completion span of one forward pass over prompt+completion.

        NOTE(review): assumes ``encode(prompt + completion)`` extends
        ``encode(prompt)`` token-for-token; tokenizers that merge across
        the boundary would mis-align the slice — confirm per tokenizer.
        """
        input_ids = self._tokenizer.encode(prompt)
        full_ids = self._tokenizer.encode(prompt + completion)
        if len(full_ids) <= len(input_ids):
            raise ProbeError(
                "logprob_of",
                f"completion tokenized to zero tokens (prompt={prompt!r}, completion={completion!r})",
            )
        logits = self._forward_logits(prompt + completion)  # (T, V)
        # Position t predicts token t+1 — slice off the last row and the prompt span.
        shift = logits[len(input_ids) - 1 : -1, :]
        target_ids = np.asarray(full_ids[len(input_ids) :], dtype=np.int64)
        log_probs = _log_softmax(shift.astype(np.float64), axis=-1)
        gathered = log_probs[np.arange(len(target_ids)), target_ids]
        return float(gathered.sum())

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Per-token log-probabilities over ``text`` (cached by text)."""
        return self._inst.cached(
            "rolling_logprob",
            self.id,
            text,
            0,
            lambda: self._compute_rolling_logprob(text),
        )

    def _compute_rolling_logprob(self, text: str) -> RollingLogprob:
        """One forward pass; token i's logprob is predicted from position i-1."""
        ids = self._tokenizer.encode(text)
        if len(ids) < 2:
            # Fewer than two tokens: nothing is predicted from context,
            # so report an empty logprob array and zero total.
            return RollingLogprob(
                token_ids=np.asarray(ids, dtype=np.int64),
                logprobs=np.array([], dtype=np.float32),
                num_tokens=len(ids),
                total_logprob=0.0,
            )
        logits = self._forward_logits(text)
        # Drop the last row (predicts beyond the text) and score ids[1:].
        log_probs = _log_softmax(logits[:-1].astype(np.float64), axis=-1)
        ids_arr = np.asarray(ids, dtype=np.int64)
        gathered = log_probs[np.arange(len(ids) - 1), ids_arr[1:]]
        return RollingLogprob(
            token_ids=ids_arr,
            logprobs=gathered.astype(np.float32),
            num_tokens=len(ids),
            total_logprob=float(gathered.sum()),
        )

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Top-k next-token distribution after ``prompt`` (cached by prompt, k)."""
        return self._inst.cached(
            "next_token_dist",
            self.id,
            prompt,
            top_k,
            lambda: self._compute_next_token_dist(prompt, top_k=top_k),
        )

    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
        """MLX batched variant — per-prompt loop for now.

        MLX's per-prompt forward on Apple Silicon is already fast
        enough that the kernel-launch amortization a real batched
        forward (padded ``mx.array`` + attention mask) would buy is
        small relative to the HF CUDA/MPS case. The S07 cache still
        short-circuits repeat prompts. Future work can swap in an
        ``mx.array`` padded forward and route through ``cached_batch``
        for counter bookkeeping parity with HF.
        """
        return [self.next_token_dist(p, top_k=top_k) for p in prompts]

    def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Forward once, softmax the final position, and keep the top k."""
        logits = self._forward_logits(prompt)
        last_logits = logits[-1].astype(np.float64)
        log_probs = _log_softmax(last_logits, axis=-1)
        vocab = int(log_probs.shape[0])
        k = min(top_k, vocab)
        # np.argpartition for top-k then sort the partition.
        part = np.argpartition(log_probs, -k)[-k:]
        top_ids = part[np.argsort(log_probs[part])[::-1]]  # descending by logprob
        top_lp = log_probs[top_ids]
        # B6: see TokenDist.tail_logprob — None means k covers vocab,
        # 0.0 means measurable tail underflowed to zero.
        if k == vocab:
            tail_logprob: float | None = None
        else:
            tail_mass = float(1.0 - np.exp(top_lp).sum())
            tail_logprob = float(np.log(tail_mass)) if tail_mass > 1e-12 else 0.0
        return TokenDist(
            token_ids=top_ids.astype(np.int64),
            logprobs=top_lp.astype(np.float32),
            vocab_size=vocab,
            tail_logprob=tail_logprob,
        )
197
198
class MLXDifferentialBackend:
    """A :class:`~dlm_sway.core.scoring.DifferentialBackend` for MLX models.

    Loads two copies of the same base model — one bare, one with the
    adapter fused — because MLX has no runtime toggle. Memory cost: 2×
    base weights. On typical Apple Silicon workloads with ≤3B models
    this is acceptable.
    """

    #: MLX holds two distinct model objects — threading against them
    #: is still unsafe today because the ``_active`` flag is single-slot
    #: and ``mlx_lm`` isn't documented as thread-safe. False by default;
    #: see ``.docs/design/backend-concurrency.md``.
    safe_for_concurrent_views: bool = False

    def __init__(self, *, base_spec: ModelSpec, adapter_path: Path) -> None:
        """Load the bare base model and the adapter-fused copy.

        Args:
            base_spec: model identity; ``base_spec.base`` is what
                ``mlx_lm.load`` resolves.
            adapter_path: MLX-format adapter dir, or a PEFT adapter dir
                that will be auto-converted into the user cache (S24).

        Raises:
            BackendNotAvailableError: when mlx / mlx-lm are missing.
        """
        mx, mlx_lm = _require_mlx()
        self._mx = mx
        self._spec = base_spec
        raw_path = Path(adapter_path).expanduser().resolve()
        # S24: when the user points us at a PEFT adapter (typical
        # `dlm export` output), auto-convert into the user's cache
        # so the headline `.dlm → sway` flow on MLX just works.
        # Cached by content hash so repeated runs skip the convert.
        self._adapter_path = _ensure_mlx_adapter(raw_path)

        # Load bare base (no adapter).
        self._base_model, self._tokenizer = mlx_lm.load(base_spec.base)
        # Load ft with adapter attached. ``adapter_path`` is mlx_lm's kwarg.
        self._ft_model, _ = mlx_lm.load(base_spec.base, adapter_path=str(self._adapter_path))
        self._active: str | None = None
        # Sprint 07: shared cache + trace + stats. See the HF backend
        # for the design note on view-id-based invalidation.
        self._inst = BackendInstrumentation()

    @contextmanager
    def as_base(self) -> Iterator[_MLXView]:
        """Yield a scoring view over the bare base model."""
        self._enter("base")
        try:
            yield _MLXView(
                id="base",
                _model=self._base_model,
                _tokenizer=self._tokenizer,
                _inst=self._inst,
            )
        finally:
            self._exit()

    @contextmanager
    def as_finetuned(self) -> Iterator[_MLXView]:
        """Yield a scoring view over the adapter-fused model."""
        self._enter("ft")
        try:
            yield _MLXView(
                id="ft",
                _model=self._ft_model,
                _tokenizer=self._tokenizer,
                _inst=self._inst,
            )
        finally:
            self._exit()

    def close(self) -> None:
        """Flush the tracer and release both model copies.

        Fix: the old implementation only flushed instrumentation while
        claiming "MLX reclaims memory when references drop" — but this
        object itself kept both weight copies alive, so close() freed
        nothing. Dropping the references here lets MLX actually reclaim
        the ~2× base-weight footprint. The backend must not be used
        after close(); idempotent if called twice.
        """
        inst = getattr(self, "_inst", None)
        if inst is not None:
            inst.close()
        # Drop the only strong references to the two weight copies.
        self._base_model = None
        self._ft_model = None

    def _enter(self, mode: str) -> None:
        """Mark ``mode`` active; reject nested or overlapping views."""
        if self._active is not None:
            raise RuntimeError(
                f"MLXDifferentialBackend view {self._active!r} already active; "
                f"exit it before entering {mode!r}."
            )
        self._active = mode

    def _exit(self) -> None:
        """Clear the active-view flag (runs in the views' finally blocks)."""
        self._active = None
276
277
278 def _ensure_mlx_adapter(adapter_path: Path) -> Path:
279 """Auto-convert PEFT adapters to MLX-LM format on first load (S24).
280
281 Detection is structural: if ``adapter_path/adapter_model.safetensors``
282 exists, we treat it as PEFT and run the converter. If it already
283 contains ``adapters.safetensors`` (mlx-lm's filename), we leave it
284 alone — assumes the user converted manually or the dir is already
285 MLX-shaped.
286
287 Cached at ``${XDG_CACHE_HOME:-$HOME/.cache}/dlm-sway/mlx-converted/<sha>/``
288 keyed on a hash of the source ``adapter_model.safetensors`` bytes.
289 Repeated runs on the same adapter version skip conversion entirely
290 (~10 ms hash + dir lookup).
291 """
292 if (adapter_path / "adapters.safetensors").exists():
293 # Already in MLX format — pass through unchanged.
294 return adapter_path
295 if not (adapter_path / "adapter_model.safetensors").exists():
296 # Neither MLX nor PEFT shape; let mlx_lm.load surface its own error.
297 return adapter_path
298
299 # Compute a content hash of the source PEFT safetensors. blake2b
300 # in 16-byte digest mode is overkill on file IO but unambiguous —
301 # different adapter versions never collide.
302 import hashlib
303
304 src_st = adapter_path / "adapter_model.safetensors"
305 h = hashlib.blake2b(digest_size=16)
306 with src_st.open("rb") as fh:
307 for chunk in iter(lambda: fh.read(1024 * 1024), b""):
308 h.update(chunk)
309 sha = h.hexdigest()
310
311 cache_root = _mlx_cache_root() / sha
312 if (cache_root / "adapters.safetensors").exists() and (
313 cache_root / "adapter_config.json"
314 ).exists():
315 return cache_root
316
317 # First-run conversion. Import here to keep the cycle off the
318 # import path of users who never touch MLX.
319 from dlm_sway.backends._mlx_convert import convert_peft_to_mlx
320
321 cache_root.mkdir(parents=True, exist_ok=True)
322 convert_peft_to_mlx(adapter_path, cache_root, overwrite=True)
323 return cache_root
324
325
326 def _mlx_cache_root() -> Path:
327 """``$XDG_CACHE_HOME/dlm-sway/mlx-converted/`` (or ``~/.cache/...``).
328
329 Honors XDG so Linux users get their conventional cache location;
330 macOS users get ``~/.cache/...`` (XDG isn't standard on darwin
331 but uv + many Python tools follow this convention there too).
332 """
333 import os
334
335 base = os.environ.get("XDG_CACHE_HOME") or str(Path.home() / ".cache")
336 return Path(base) / "dlm-sway" / "mlx-converted"
337
338
339 def _log_softmax(x: np.ndarray, *, axis: int) -> np.ndarray:
340 x_max = np.max(x, axis=axis, keepdims=True)
341 y = x - x_max
342 log_sum = np.log(np.sum(np.exp(y), axis=axis, keepdims=True))
343 return np.asarray(y - log_sum, dtype=np.float64)
344
345
# Public surface: only the backend class; views and helpers are internal.
__all__ = ["MLXDifferentialBackend"]