tenseleyflow/documentlanguagemodel / 91d00d3

Browse files

sway(backends): partial MLX backend for pre-converted .npz adapters

Authored by espadonne
SHA
91d00d30ad26471d43ed050a2d95cd2348214591
Parents
2285cc2
Tree
1554acf

3 changed files

Status | File | Lines added | Lines removed
M sway/src/dlm_sway/backends/__init__.py 9 6
A sway/src/dlm_sway/backends/mlx.py 205 0
M sway/tests/unit/test_backend_registry.py 12 1
sway/src/dlm_sway/backends/__init__.py — modified
@@ -11,7 +11,7 @@ from __future__ import annotations
1111
 from pathlib import Path
1212
 from typing import TYPE_CHECKING
1313
 
14
-from dlm_sway.core.errors import BackendNotAvailableError, SpecValidationError
14
+from dlm_sway.core.errors import SpecValidationError
1515
 from dlm_sway.core.model import ModelSpec
1616
 
1717
 if TYPE_CHECKING:
@@ -48,11 +48,14 @@ def build(base_spec: ModelSpec, *, adapter_path: Path | None = None) -> Differen
4848
         return HuggingFaceDifferentialBackend(base_spec=base_spec, adapter_path=effective_adapter)
4949
 
5050
     if base_spec.kind == "mlx":
51
-        raise BackendNotAvailableError(
52
-            "mlx",
53
-            extra="mlx",
54
-            hint="MLX backend shipping in a later milestone.",
55
-        )
51
+        if effective_adapter is None:
52
+            raise SpecValidationError(
53
+                "mlx backend requires an adapter path (set `adapter:` on the ft model; "
54
+                "must be an MLX .npz adapter — use dlm's peft→mlx converter if needed)"
55
+            )
56
+        from dlm_sway.backends.mlx import MLXDifferentialBackend
57
+
58
+        return MLXDifferentialBackend(base_spec=base_spec, adapter_path=effective_adapter)
5659
 
5760
     if base_spec.kind == "custom":
5861
         return _load_custom(base_spec, effective_adapter)
sway/src/dlm_sway/backends/mlx.py — added
@@ -0,0 +1,205 @@
1
+"""MLX backend for Apple Silicon (darwin-arm64).
2
+
3
+Partial implementation covering the common case: a PEFT adapter that's
4
+already been converted to MLX's ``.npz`` format. Unlike the HF backend,
5
+MLX has no runtime ``disable_adapter`` context — adapters get fused into
6
+the linear layers at load time — so this backend keeps **both** a base
7
+model and an adapted model in memory. Fine for the small (<3B) models
8
+MLX is typically used with on Apple Silicon; document the cost clearly.
9
+
10
+If users point this backend at raw PEFT safetensors, ``mlx_lm.load``
11
+will refuse them with its own error. A future milestone can wire a
12
+PEFT-→-MLX converter; for now the contract is "bring your own .npz".
13
+"""
14
+
15
+from __future__ import annotations
16
+
17
+from collections.abc import Iterator
18
+from contextlib import contextmanager
19
+from dataclasses import dataclass
20
+from pathlib import Path
21
+from typing import TYPE_CHECKING, Any
22
+
23
+import numpy as np
24
+
25
+from dlm_sway.core.errors import BackendNotAvailableError, ProbeError
26
+from dlm_sway.core.model import ModelSpec
27
+from dlm_sway.core.scoring import RollingLogprob, TokenDist
28
+
29
+if TYPE_CHECKING:
30
+    pass
31
+
32
+
33
def _require_mlx() -> tuple[Any, Any]:
    """Import and return ``(mlx.core, mlx_lm)``.

    Raises:
        BackendNotAvailableError: when mlx / mlx-lm cannot be imported
            (wrong platform or extra not installed), carrying a pip hint.
    """
    try:
        import mlx.core
        import mlx_lm
    except ImportError as exc:
        raise BackendNotAvailableError(
            "mlx",
            extra="mlx",
            hint="MLX backend needs mlx + mlx-lm on darwin-arm64.",
        ) from exc
    return mlx.core, mlx_lm
44
+
45
+
46
@dataclass(slots=True)
class _MLXView:
    """One side (base or ft) of the MLX backend.

    Both sides carry the same tokenizer (MLX stores it alongside the
    converted model files, so sharing avoids double-loading).

    Instances are constructed and scoped by
    ``MLXDifferentialBackend.as_base`` / ``as_finetuned``.
    """

    id: str  # view label: "base" or "ft"
    _model: Any  # loaded mlx_lm model; called directly for logits in _forward_logits
    _tokenizer: Any  # shared tokenizer; must expose encode(str) -> list[int]

    def generate(
        self,
        prompt: str,
        *,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 1.0,
        seed: int = 0,
    ) -> str:
        """Return a text completion of *prompt* via ``mlx_lm.generate``.

        With ``temperature == 0.0`` no sampling kwargs are passed (mlx_lm's
        default path); otherwise ``temp``/``top_p`` are forwarded. ``seed``
        is accepted for interface parity but ignored.
        """
        del seed  # mlx_lm.generate seeds via its own global state
        _, mlx_lm = _require_mlx()
        kwargs: dict[str, Any] = {"max_tokens": max_new_tokens, "verbose": False}
        if temperature > 0.0:
            # NOTE(review): "temp"/"top_p" match older mlx_lm.generate kwargs;
            # newer mlx-lm releases moved sampling behind a `sampler=` argument
            # — confirm against the pinned mlx-lm version.
            kwargs["temp"] = temperature
            kwargs["top_p"] = top_p
        out = mlx_lm.generate(self._model, self._tokenizer, prompt=prompt, **kwargs)
        return str(out)

    def close(self) -> None:
        # Nothing view-local to release; model/tokenizer are owned by the backend.
        return None

    # -- ScoringBackend ------------------------------------------------

    def _forward_logits(self, prompt: str) -> np.ndarray:
        """Run the model once and return ``(seq_len, vocab)`` logits."""
        mx, _ = _require_mlx()
        input_ids = self._tokenizer.encode(prompt)
        tokens = mx.array(input_ids)[None, :]  # (1, T)
        out = self._model(tokens)
        # mlx_lm models return an mx.array; convert to numpy for downstream math.
        # ``out[0]`` strips the batch dim added above, leaving (T, vocab).
        return np.asarray(out[0])

    def logprob_of(self, prompt: str, completion: str) -> float:
        """Sum of per-token log-probabilities of *completion* given *prompt*.

        Raises:
            ProbeError: when ``prompt + completion`` tokenizes to no more
                tokens than the prompt alone.
        """
        input_ids = self._tokenizer.encode(prompt)
        # NOTE(review): assumes encode(prompt) is a strict prefix of
        # encode(prompt + completion); BPE merges at the boundary could break
        # that — confirm the tokenizer's behavior for these models.
        full_ids = self._tokenizer.encode(prompt + completion)
        if len(full_ids) <= len(input_ids):
            raise ProbeError(
                "logprob_of",
                f"completion tokenized to zero tokens (prompt={prompt!r}, completion={completion!r})",
            )
        logits = self._forward_logits(prompt + completion)  # (T, V)
        # Position t predicts token t+1 — slice off the last row and the prompt span.
        shift = logits[len(input_ids) - 1 : -1, :]
        target_ids = np.asarray(full_ids[len(input_ids) :], dtype=np.int64)
        # float64 before the softmax to keep the log-sum-exp numerically tight.
        log_probs = _log_softmax(shift.astype(np.float64), axis=-1)
        gathered = log_probs[np.arange(len(target_ids)), target_ids]
        return float(gathered.sum())

    def rolling_logprob(self, text: str) -> RollingLogprob:
        """Per-token conditional log-probabilities over *text*.

        ``logprobs[i]`` scores token ``i+1`` given tokens ``0..i``. Texts
        shorter than two tokens yield an empty logprob array and total 0.0.
        """
        ids = self._tokenizer.encode(text)
        if len(ids) < 2:
            # No conditional positions exist — return an empty result.
            return RollingLogprob(
                token_ids=np.asarray(ids, dtype=np.int64),
                logprobs=np.array([], dtype=np.float32),
                num_tokens=len(ids),
                total_logprob=0.0,
            )
        logits = self._forward_logits(text)
        # Drop the last position: it predicts a token beyond the text.
        log_probs = _log_softmax(logits[:-1].astype(np.float64), axis=-1)
        ids_arr = np.asarray(ids, dtype=np.int64)
        # Gather each position's logprob for the *next* token (ids shifted by one).
        gathered = log_probs[np.arange(len(ids) - 1), ids_arr[1:]]
        return RollingLogprob(
            token_ids=ids_arr,
            logprobs=gathered.astype(np.float32),
            num_tokens=len(ids),
            total_logprob=float(gathered.sum()),
        )

    def next_token_dist(self, prompt: str, *, top_k: int = 256) -> TokenDist:
        """Top-*k* next-token distribution after *prompt*.

        Returns the ``k`` highest-logprob token ids in descending order plus
        a single ``tail_logprob`` summarizing the remaining mass.
        """
        logits = self._forward_logits(prompt)
        last_logits = logits[-1].astype(np.float64)
        log_probs = _log_softmax(last_logits, axis=-1)
        k = min(top_k, log_probs.shape[0])
        # np.argpartition for top-k then sort the partition.
        part = np.argpartition(log_probs, -k)[-k:]
        top_ids = part[np.argsort(log_probs[part])[::-1]]
        top_lp = log_probs[top_ids]
        tail_mass = float(1.0 - np.exp(top_lp).sum())
        # NOTE(review): when top-k covers (numerically) all mass, tail_logprob
        # is 0.0 — literally log-probability of 1. Presumably TokenDist treats
        # 0.0 as a "no tail" sentinel; confirm downstream consumers agree.
        tail_logprob = float(np.log(max(tail_mass, 1e-12))) if tail_mass > 1e-12 else 0.0
        return TokenDist(
            token_ids=top_ids.astype(np.int64),
            logprobs=top_lp.astype(np.float32),
            vocab_size=int(log_probs.shape[0]),
            tail_logprob=tail_logprob,
        )
143
+
144
+
145
class MLXDifferentialBackend:
    """A :class:`~dlm_sway.core.scoring.DifferentialBackend` for MLX models.

    Loads two copies of the same base model — one bare, one with the
    adapter fused — because MLX has no runtime toggle. Memory cost: 2×
    base weights. On typical Apple Silicon workloads with ≤3B models
    this is acceptable.
    """

    def __init__(self, *, base_spec: ModelSpec, adapter_path: Path) -> None:
        mx, mlx_lm = _require_mlx()
        self._mx = mx
        self._spec = base_spec
        self._adapter_path = Path(adapter_path).expanduser().resolve()
        # Two loads of the same checkpoint: first bare, then with the .npz
        # adapter fused in (``adapter_path`` is mlx_lm's kwarg). The tokenizer
        # from the first load is shared by both views.
        self._base_model, self._tokenizer = mlx_lm.load(base_spec.base)
        self._ft_model, _ = mlx_lm.load(base_spec.base, adapter_path=str(self._adapter_path))
        self._active: str | None = None

    @contextmanager
    def _scoped_view(self, mode: str, model: Any) -> Iterator[_MLXView]:
        """Mark *mode* active for the duration of the ``with`` body and
        hand out a view over *model* plus the shared tokenizer."""
        self._enter(mode)
        try:
            yield _MLXView(id=mode, _model=model, _tokenizer=self._tokenizer)
        finally:
            self._exit()

    @contextmanager
    def as_base(self) -> Iterator[_MLXView]:
        """Scoped view over the bare (no-adapter) model."""
        with self._scoped_view("base", self._base_model) as view:
            yield view

    @contextmanager
    def as_finetuned(self) -> Iterator[_MLXView]:
        """Scoped view over the adapter-fused model."""
        with self._scoped_view("ft", self._ft_model) as view:
            yield view

    def close(self) -> None:
        """MLX reclaims memory when references drop; nothing to do here."""
        return

    def _enter(self, mode: str) -> None:
        # Views are exclusive: refuse to nest one inside another.
        active = self._active
        if active is not None:
            raise RuntimeError(
                f"MLXDifferentialBackend view {active!r} already active; "
                f"exit it before entering {mode!r}."
            )
        self._active = mode

    def _exit(self) -> None:
        self._active = None
196
+
197
+
198
+def _log_softmax(x: np.ndarray, *, axis: int) -> np.ndarray:
199
+    x_max = np.max(x, axis=axis, keepdims=True)
200
+    y = x - x_max
201
+    log_sum = np.log(np.sum(np.exp(y), axis=axis, keepdims=True))
202
+    return np.asarray(y - log_sum, dtype=np.float64)
203
+
204
+
205
__all__ = ["MLXDifferentialBackend"]  # public surface: the backend class only
sway/tests/unit/test_backend_registry.py — modified
@@ -26,7 +26,18 @@ class TestRegistry:
2626
         with pytest.raises(SpecValidationError, match="adapter"):
2727
             build(ModelSpec(base="x", kind="hf"))
2828
 
29
    def test_mlx_requires_adapter(self) -> None:
        # mlx dispatch validates the adapter up front, before importing the
        # (optional) mlx backend module.
        with pytest.raises(SpecValidationError, match="adapter"):
            build(ModelSpec(base="x", kind="mlx"))
32
+
33
    def test_mlx_dispatch_raises_when_mlx_missing(self) -> None:
        # On non-Apple-Silicon (or Apple without mlx installed), constructing
        # the MLX backend raises BackendNotAvailableError with a pip hint.
        # We skip this assertion if mlx happens to be installed.
        import importlib.util

        # find_spec only checks importability of the top-level package; that is
        # enough to decide which branch this environment will take.
        if importlib.util.find_spec("mlx") is not None:
            pytest.skip("mlx is installed; error path not exercised")
        with pytest.raises(BackendNotAvailableError) as exc_info:
            build(ModelSpec(base="x", kind="mlx", adapter=Path("/tmp/a")))
        assert exc_info.value.backend == "mlx"