| 1 | """MLX backend for Apple Silicon (darwin-arm64). |
| 2 | |
| 3 | Partial implementation covering the common case: a PEFT adapter that's |
| 4 | already been converted to MLX's ``.npz`` format. Unlike the HF backend, |
| 5 | MLX has no runtime ``disable_adapter`` context — adapters get fused into |
| 6 | the linear layers at load time — so this backend keeps **both** a base |
| 7 | model and an adapted model in memory. Fine for the small (<3B) models |
| 8 | MLX is typically used with on Apple Silicon; document the cost clearly. |
| 9 | |
If users point this backend at a raw PEFT adapter directory
(``adapter_model.safetensors``), the S24 auto-converter
(``_ensure_mlx_adapter``) transparently converts it to mlx-lm's adapter
layout and caches the result by content hash, so repeated runs skip the
conversion. Directories already containing ``adapters.safetensors`` are
passed through untouched.
| 13 | """ |
| 14 | |
| 15 | from __future__ import annotations |
| 16 | |
| 17 | from collections.abc import Iterator, Sequence |
| 18 | from contextlib import contextmanager |
| 19 | from dataclasses import dataclass |
| 20 | from pathlib import Path |
| 21 | from typing import TYPE_CHECKING, Any |
| 22 | |
| 23 | import numpy as np |
| 24 | |
| 25 | from dlm_sway.backends._instrumentation import BackendInstrumentation |
| 26 | from dlm_sway.core.errors import BackendNotAvailableError, ProbeError |
| 27 | from dlm_sway.core.model import ModelSpec |
| 28 | from dlm_sway.core.scoring import RollingLogprob, TokenDist |
| 29 | |
| 30 | if TYPE_CHECKING: |
| 31 | pass |
| 32 | |
| 33 | |
def _require_mlx() -> tuple[Any, Any]:
    """Return the ``(mlx.core, mlx_lm)`` module pair, importing lazily.

    Raises:
        BackendNotAvailableError: when the ``mlx``/``mlx-lm`` packages are
            not installed (they are darwin-arm64 extras).
    """
    try:
        import mlx.core
        import mlx_lm
    except ImportError as exc:
        raise BackendNotAvailableError(
            "mlx",
            extra="mlx",
            hint="MLX backend needs mlx + mlx-lm on darwin-arm64.",
        ) from exc
    return mlx.core, mlx_lm
| 45 | |
| 46 | |
@dataclass(slots=True)
class _MLXView:
    """A single view — base or fine-tuned — over the MLX backend.

    The base and ft views share one tokenizer object (MLX ships the
    tokenizer next to the converted weights, so sharing avoids loading
    it twice). All scoring entry points go through ``_inst`` so the
    shared cache, trace, and stats machinery (Sprint 07) sees them.
    """

    id: str
    _model: Any
    _tokenizer: Any
    _inst: BackendInstrumentation

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 1.0,
        seed: int = 0,
    ) -> str:
        """Sample a completion for *prompt* via ``mlx_lm.generate``."""
        del seed  # mlx_lm.generate seeds via its own global state
        _, mlx_lm = _require_mlx()
        gen_kwargs: dict[str, Any] = {"max_tokens": max_new_tokens, "verbose": False}
        # Greedy decode at temperature 0: omit sampling knobs entirely.
        if temperature > 0.0:
            gen_kwargs["temp"] = temperature
            gen_kwargs["top_p"] = top_p
        completion = mlx_lm.generate(self._model, self._tokenizer, prompt=prompt, **gen_kwargs)
        return str(completion)

    def close(self) -> None:
        """Nothing to release; model memory is owned by the parent backend."""
        return None

    # -- ScoringBackend ------------------------------------------------

    def _forward_logits(self, prompt: str) -> np.ndarray:
        """One forward pass; returns ``(seq_len, vocab)`` logits as fp32 numpy."""
        mx, _ = _require_mlx()
        batch = mx.array(self._tokenizer.encode(prompt))[None, :]  # (1, T)
        raw = self._model(batch)
        # bf16/fp16 MLX arrays often lack a numpy-readable buffer, so
        # up-cast to fp32 on the MLX side first. Scoring math below
        # re-casts to fp64 anyway, so nothing is lost here.
        return np.asarray(raw[0].astype(mx.float32))

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Total log-probability of *completion* conditioned on *prompt* (cached)."""
        cache_key = f"{prompt}\x00{completion}"
        return self._inst.cached(
            "logprob_of",
            self.id,
            cache_key,
            0,
            lambda: self._compute_logprob_of(prompt, completion),
        )

    def _compute_logprob_of(self, prompt: str, completion: str) -> float:
        prompt_ids = self._tokenizer.encode(prompt)
        full_ids = self._tokenizer.encode(prompt + completion)
        n_prompt = len(prompt_ids)
        if len(full_ids) <= n_prompt:
            raise ProbeError(
                "logprob_of",
                f"completion tokenized to zero tokens (prompt={prompt!r}, completion={completion!r})",
            )
        # NOTE(review): assumes tokenizing prompt+completion yields the
        # prompt's ids as a prefix (usually true, but BPE can merge at
        # the boundary) — TODO confirm against the tokenizers in use.
        logits = self._forward_logits(prompt + completion)  # (T, V)
        # Row t predicts token t+1: drop the final row and the prompt span.
        completion_rows = logits[n_prompt - 1 : -1, :]
        targets = np.asarray(full_ids[n_prompt:], dtype=np.int64)
        lp = _log_softmax(completion_rows.astype(np.float64), axis=-1)
        picked = lp[np.arange(targets.shape[0]), targets]
        return float(picked.sum())

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Per-token log-probabilities over *text* (cached)."""
        return self._inst.cached(
            "rolling_logprob",
            self.id,
            text,
            0,
            lambda: self._compute_rolling_logprob(text),
        )

    def _compute_rolling_logprob(self, text: str) -> RollingLogprob:
        token_ids = self._tokenizer.encode(text)
        count = len(token_ids)
        if count < 2:
            # Fewer than two tokens means no transition to score.
            return RollingLogprob(
                token_ids=np.asarray(token_ids, dtype=np.int64),
                logprobs=np.array([], dtype=np.float32),
                num_tokens=count,
                total_logprob=0.0,
            )
        logits = self._forward_logits(text)
        lp = _log_softmax(logits[:-1].astype(np.float64), axis=-1)
        ids64 = np.asarray(token_ids, dtype=np.int64)
        per_token = lp[np.arange(count - 1), ids64[1:]]
        return RollingLogprob(
            token_ids=ids64,
            logprobs=per_token.astype(np.float32),
            num_tokens=count,
            total_logprob=float(per_token.sum()),
        )

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Top-k next-token distribution after *prompt* (cached)."""
        return self._inst.cached(
            "next_token_dist",
            self.id,
            prompt,
            top_k,
            lambda: self._compute_next_token_dist(prompt, top_k=top_k),
        )

    def next_token_dist_batch(self, prompts: Sequence[str], *, top_k: int = 256) -> list[TokenDist]:
        """Batched variant — currently a per-prompt loop.

        On Apple Silicon, MLX's single-prompt forward is fast enough
        that a padded batched forward (``mx.array`` + attention mask)
        would buy little compared with the HF CUDA/MPS case, and the
        S07 cache already short-circuits repeated prompts. A future
        change can add a real padded forward routed through
        ``cached_batch`` for counter parity with HF.
        """
        return [self.next_token_dist(p, top_k=top_k) for p in prompts]

    def _compute_next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        last_row = self._forward_logits(prompt)[-1].astype(np.float64)
        log_probs = _log_softmax(last_row, axis=-1)
        vocab = int(log_probs.shape[0])
        k = min(top_k, vocab)
        # Partial partition picks the k best; then fully sort just that slice.
        candidates = np.argpartition(log_probs, -k)[-k:]
        order = np.argsort(log_probs[candidates])[::-1]
        top_ids = candidates[order]
        top_lp = log_probs[top_ids]
        # B6: see TokenDist.tail_logprob — None means k covers vocab,
        # 0.0 means measurable tail underflowed to zero.
        tail_logprob: float | None
        if k == vocab:
            tail_logprob = None
        else:
            leftover = float(1.0 - np.exp(top_lp).sum())
            tail_logprob = float(np.log(leftover)) if leftover > 1e-12 else 0.0
        return TokenDist(
            token_ids=top_ids.astype(np.int64),
            logprobs=top_lp.astype(np.float32),
            vocab_size=vocab,
            tail_logprob=tail_logprob,
        )
| 197 | |
| 198 | |
class MLXDifferentialBackend:
    """A :class:`~dlm_sway.core.scoring.DifferentialBackend` for MLX models.

    MLX fuses adapters into the linear layers at load time and has no
    runtime enable/disable toggle, so this backend loads two copies of
    the base model: one bare and one with the adapter fused. Memory
    cost is therefore 2× the base weights — acceptable for the ≤3B
    models MLX is typically used with on Apple Silicon.
    """

    #: Two distinct model objects exist, but concurrent views are still
    #: unsafe: ``_active`` is a single slot and ``mlx_lm`` makes no
    #: documented thread-safety guarantee. Defaults to False; see
    #: ``.docs/design/backend-concurrency.md``.
    safe_for_concurrent_views: bool = False

    def __init__(self, *, base_spec: ModelSpec, adapter_path: Path) -> None:
        mx, mlx_lm = _require_mlx()
        self._mx = mx
        self._spec = base_spec
        resolved = Path(adapter_path).expanduser().resolve()
        # S24: a PEFT adapter (the usual `dlm export` output) is
        # auto-converted into the user's cache so the headline
        # `.dlm → sway` flow just works on MLX. Keyed by content hash,
        # so repeat runs skip the conversion.
        self._adapter_path = _ensure_mlx_adapter(resolved)

        # One bare base model, one with the adapter fused in.
        # ``adapter_path`` is mlx_lm's own kwarg name.
        self._base_model, self._tokenizer = mlx_lm.load(base_spec.base)
        self._ft_model, _ = mlx_lm.load(base_spec.base, adapter_path=str(self._adapter_path))
        self._active: str | None = None
        # Sprint 07: shared cache + trace + stats. The HF backend holds
        # the design note on view-id-based invalidation.
        self._inst = BackendInstrumentation()

    @contextmanager
    def as_base(self) -> Iterator[_MLXView]:
        """Yield the bare-base view; only one view may be active at a time."""
        self._enter("base")
        try:
            view = _MLXView(
                id="base",
                _model=self._base_model,
                _tokenizer=self._tokenizer,
                _inst=self._inst,
            )
            yield view
        finally:
            self._exit()

    @contextmanager
    def as_finetuned(self) -> Iterator[_MLXView]:
        """Yield the adapter-fused view; only one view may be active at a time."""
        self._enter("ft")
        try:
            view = _MLXView(
                id="ft",
                _model=self._ft_model,
                _tokenizer=self._tokenizer,
                _inst=self._inst,
            )
            yield view
        finally:
            self._exit()

    def close(self) -> None:
        """MLX reclaims memory when references drop; flush the tracer."""
        tracer = getattr(self, "_inst", None)
        if tracer is not None:
            tracer.close()

    def _enter(self, mode: str) -> None:
        # Single-slot guard: nested/overlapping views are a programming error.
        if self._active is not None:
            raise RuntimeError(
                f"MLXDifferentialBackend view {self._active!r} already active; "
                f"exit it before entering {mode!r}."
            )
        self._active = mode

    def _exit(self) -> None:
        self._active = None
| 277 | |
| 278 | def _ensure_mlx_adapter(adapter_path: Path) -> Path: |
| 279 | """Auto-convert PEFT adapters to MLX-LM format on first load (S24). |
| 280 | |
| 281 | Detection is structural: if ``adapter_path/adapter_model.safetensors`` |
| 282 | exists, we treat it as PEFT and run the converter. If it already |
| 283 | contains ``adapters.safetensors`` (mlx-lm's filename), we leave it |
| 284 | alone — assumes the user converted manually or the dir is already |
| 285 | MLX-shaped. |
| 286 | |
| 287 | Cached at ``${XDG_CACHE_HOME:-$HOME/.cache}/dlm-sway/mlx-converted/<sha>/`` |
| 288 | keyed on a hash of the source ``adapter_model.safetensors`` bytes. |
| 289 | Repeated runs on the same adapter version skip conversion entirely |
| 290 | (~10 ms hash + dir lookup). |
| 291 | """ |
| 292 | if (adapter_path / "adapters.safetensors").exists(): |
| 293 | # Already in MLX format — pass through unchanged. |
| 294 | return adapter_path |
| 295 | if not (adapter_path / "adapter_model.safetensors").exists(): |
| 296 | # Neither MLX nor PEFT shape; let mlx_lm.load surface its own error. |
| 297 | return adapter_path |
| 298 | |
| 299 | # Compute a content hash of the source PEFT safetensors. blake2b |
| 300 | # in 16-byte digest mode is overkill on file IO but unambiguous — |
| 301 | # different adapter versions never collide. |
| 302 | import hashlib |
| 303 | |
| 304 | src_st = adapter_path / "adapter_model.safetensors" |
| 305 | h = hashlib.blake2b(digest_size=16) |
| 306 | with src_st.open("rb") as fh: |
| 307 | for chunk in iter(lambda: fh.read(1024 * 1024), b""): |
| 308 | h.update(chunk) |
| 309 | sha = h.hexdigest() |
| 310 | |
| 311 | cache_root = _mlx_cache_root() / sha |
| 312 | if (cache_root / "adapters.safetensors").exists() and ( |
| 313 | cache_root / "adapter_config.json" |
| 314 | ).exists(): |
| 315 | return cache_root |
| 316 | |
| 317 | # First-run conversion. Import here to keep the cycle off the |
| 318 | # import path of users who never touch MLX. |
| 319 | from dlm_sway.backends._mlx_convert import convert_peft_to_mlx |
| 320 | |
| 321 | cache_root.mkdir(parents=True, exist_ok=True) |
| 322 | convert_peft_to_mlx(adapter_path, cache_root, overwrite=True) |
| 323 | return cache_root |
| 324 | |
| 325 | |
| 326 | def _mlx_cache_root() -> Path: |
| 327 | """``$XDG_CACHE_HOME/dlm-sway/mlx-converted/`` (or ``~/.cache/...``). |
| 328 | |
| 329 | Honors XDG so Linux users get their conventional cache location; |
| 330 | macOS users get ``~/.cache/...`` (XDG isn't standard on darwin |
| 331 | but uv + many Python tools follow this convention there too). |
| 332 | """ |
| 333 | import os |
| 334 | |
| 335 | base = os.environ.get("XDG_CACHE_HOME") or str(Path.home() / ".cache") |
| 336 | return Path(base) / "dlm-sway" / "mlx-converted" |
| 337 | |
| 338 | |
| 339 | def _log_softmax(x: np.ndarray, *, axis: int) -> np.ndarray: |
| 340 | x_max = np.max(x, axis=axis, keepdims=True) |
| 341 | y = x - x_max |
| 342 | log_sum = np.log(np.sum(np.exp(y), axis=axis, keepdims=True)) |
| 343 | return np.asarray(y - log_sum, dtype=np.float64) |
| 344 | |
| 345 | |
| 346 | __all__ = ["MLXDifferentialBackend"] |