tenseleyflow/documentlanguagemodel / 089eee4

fix(inference): MLX adapter — write adapters.safetensors + mlx-lm config schema (audit-11 B1)

mlx_lm.load_adapters (current API) reads adapters.safetensors and expects
adapter_config.json with `num_layers` + `lora_parameters` — neither of
which PEFT's config has. The prior path copied PEFT's adapter_config.json
verbatim and wrote adapters.npz; result: every first-run `dlm prompt`
under the default auto-selected backend crashed with `AttributeError:
'SimpleNamespace' object has no attribute 'num_layers'`.

- peft_safetensors_to_mlx_safetensors: writes .safetensors instead of
.npz (mlx-lm hardcodes the former now).
- build_mlx_adapter_config: translates PEFT shape (r, lora_alpha,
lora_dropout, target_modules, use_dora) → mlx-lm shape (num_layers,
lora_parameters.{rank,scale,dropout,keys}, fine_tune_type); sketched below.
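
For illustration, the translation `build_mlx_adapter_config` performs, shown
with made-up PEFT values (not defaults from this repo):

```python
# Hypothetical PEFT adapter_config.json contents (values are illustrative only).
peft_config = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],
    "use_dora": False,
}

# Assuming the base model's HF config.json reports num_hidden_layers=32,
# build_mlx_adapter_config(peft_config, 32) returns the mlx-lm shape:
# {
#     "fine_tune_type": "lora",
#     "num_layers": 32,
#     "lora_parameters": {
#         "rank": 16,
#         "scale": 2.0,               # lora_alpha / r = 32 / 16
#         "dropout": 0.05,
#         "keys": ["q_proj", "v_proj"],
#     },
# }
```
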
Authored by espadonne
SHA: 089eee4af4bd434027320c1fca89f81625857a43
Parents: dab9a73
Tree: e566c6e

1 changed file

Status  File                              +   -
M       src/dlm/inference/mlx_adapter.py  86  27

src/dlm/inference/mlx_adapter.py (modified)
@@ -1,7 +1,7 @@
-"""PEFT safetensors → MLX-LM `.npz` LoRA-adapter converter.
+"""PEFT safetensors → MLX-LM LoRA-adapter converter.
 
-Sprint 21 ships MLX as a second inference backend on Apple Silicon.
-PEFT writes LoRA weights as `adapter_model.safetensors` with keys like:
+Ships MLX as a second inference backend on Apple Silicon. PEFT writes
+LoRA weights as `adapter_model.safetensors` with keys like:
 
     base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
     base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight
@@ -14,23 +14,26 @@ PEFT writes LoRA weights as `adapter_model.safetensors` with keys like:
 (no `base_model` prefix, lowercase `lora_a`/`lora_b`, `.weight` stripped
 because mlx-lm adapter files store bare tensors under the final segment).
 
-The converter is split so the key-mapping logic is pure and
-unit-testable; the tensor I/O layer is a thin wrapper around
-`safetensors.torch.load_file` + `numpy.savez`. We write `.npz` via
-numpy (not `mx.savez`) so conversion itself does not require MLX to
-be importable — the artifact is still MLX-loadable because `mx.load`
-understands numpy's `.npz` format.
+**Output format: `adapters.safetensors`.** Current mlx-lm
+(`tuner/utils.py:137`) hardcodes `model.load_weights(adapters.safetensors)`
+— the `.npz` path worked on earlier mlx_lm releases but no longer.
+We write safetensors directly. Conversion itself doesn't require MLX
+to be importable — the safetensors file is written via the pure
+`safetensors.torch` dependency.
+
+**mlx-lm `adapter_config.json` schema.** mlx-lm's loader builds a
+`types.SimpleNamespace` from the file and reads `config.num_layers` +
+`config.lora_parameters`. PEFT's config has neither. `build_mlx_adapter_config`
+translates PEFT-shape into mlx-lm-shape using the HF base config's
+`num_hidden_layers` + PEFT's `r` / `lora_alpha` / `lora_dropout` /
+`target_modules`.
 """
 
 from __future__ import annotations
 
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
+from typing import Any
 
 _PEFT_PREFIX = re.compile(r"^base_model\.model\.")
 """PEFT wraps the HF base in a double `base_model.model.` — the outer
@@ -93,11 +96,15 @@ def map_all_keys(peft_keys: list[str]) -> dict[str, str]:
     return mapping
 
 
-def peft_safetensors_to_mlx_npz(  # pragma: no cover - I/O + torch deps
+def peft_safetensors_to_mlx_safetensors(  # pragma: no cover - I/O + torch deps
     peft_adapter_dir: Path,
-    mlx_npz_path: Path,
+    mlx_safetensors_path: Path,
 ) -> dict[str, str]:
-    """Convert `<adapter>/adapter_model.safetensors` → `<mlx_npz_path>`.
+    """Convert `<adapter>/adapter_model.safetensors` → `<mlx_safetensors_path>`.
+
+    mlx-lm's current loader reads `adapters.safetensors` (not `.npz`);
+    we write safetensors with MLX-shaped keys so `model.load_weights`
+    accepts the file without further translation.
 
     Returns the key mapping actually written (peft_key → mlx_key) so
     callers can log it when `--verbose`.
@@ -105,8 +112,7 @@ def peft_safetensors_to_mlx_npz( # pragma: no cover - I/O + torch deps
     Pragma'd: exercised end-to-end by the slow parity integration test
     (covered via `map_all_keys` unit tests for the logic).
    """
-    import numpy as np
-    from safetensors.torch import load_file
+    from safetensors.torch import load_file, save_file
 
     src = peft_adapter_dir / "adapter_model.safetensors"
     if not src.exists():
@@ -115,13 +121,66 @@ def peft_safetensors_to_mlx_npz( # pragma: no cover - I/O + torch deps
     tensors = load_file(str(src))
     mapping = map_all_keys(list(tensors.keys()))
 
-    np_tensors: dict[str, NDArray[np.float32]] = {}
-    for peft_key, mlx_key in mapping.items():
-        tensor = tensors[peft_key]
-        # safetensors.torch.load_file returns torch.Tensor; .numpy()
-        # is the standard bridge. fp16 is preserved across the write.
-        np_tensors[mlx_key] = tensor.detach().cpu().numpy()
+    mlx_tensors = {mlx_key: tensors[peft_key] for peft_key, mlx_key in mapping.items()}
 
-    mlx_npz_path.parent.mkdir(parents=True, exist_ok=True)
-    np.savez(str(mlx_npz_path), **np_tensors)  # type: ignore[arg-type]
+    mlx_safetensors_path.parent.mkdir(parents=True, exist_ok=True)
+    save_file(mlx_tensors, str(mlx_safetensors_path))
     return mapping
+
+
+def build_mlx_adapter_config(
+    peft_config: dict[str, Any],
+    base_num_hidden_layers: int,
+) -> dict[str, Any]:
+    """Translate a PEFT `adapter_config.json` into mlx-lm's schema.
+
+    mlx-lm's `load_adapters` (tuner/utils.py) reads:
+
+    - `config.num_layers` — how many trailing layers receive LoRA
+      (matches mlx-lm convention: -1 = all; a positive N = last N).
+      We emit the base model's `num_hidden_layers` so every layer
+      gets the adapter — matches PEFT default when `layers_to_transform`
+      isn't set.
+    - `config.lora_parameters.rank` ← PEFT `r`
+    - `config.lora_parameters.scale` ← PEFT `lora_alpha / r`
+    - `config.lora_parameters.dropout` ← PEFT `lora_dropout`
+    - `config.lora_parameters.keys` ← PEFT `target_modules`
+    - `config.fine_tune_type` — "lora" unless PEFT `use_dora=True`,
+      in which case "dora".
+
+    Fails loud (`MlxConversionError`) when the PEFT config is missing
+    fields we cannot substitute — `r` and `target_modules` in particular
+    are load-bearing.
+    """
+    try:
+        rank = int(peft_config["r"])
+    except (KeyError, TypeError, ValueError) as exc:
+        raise MlxConversionError(
+            f"PEFT adapter_config.json missing or non-integer 'r' (LoRA rank): {exc}"
+        ) from exc
+    target_modules = peft_config.get("target_modules")
+    if not isinstance(target_modules, list) or not target_modules:
+        raise MlxConversionError(
+            "PEFT adapter_config.json 'target_modules' must be a non-empty list; got "
+            f"{target_modules!r}. mlx-lm needs this to wire LoRA into the right ops."
+        )
+    lora_alpha = float(peft_config.get("lora_alpha", rank))
+    lora_dropout = float(peft_config.get("lora_dropout", 0.0))
+    use_dora = bool(peft_config.get("use_dora", False))
+
+    if base_num_hidden_layers < 1:
+        raise MlxConversionError(
+            f"base model reports num_hidden_layers={base_num_hidden_layers} (expected >=1); "
+            "cannot stage mlx adapter without a valid layer count"
+        )
+
+    return {
+        "fine_tune_type": "dora" if use_dora else "lora",
+        "num_layers": int(base_num_hidden_layers),
+        "lora_parameters": {
+            "rank": rank,
+            "scale": lora_alpha / rank if rank else float(lora_alpha),
+            "dropout": lora_dropout,
+            "keys": list(target_modules),
+        },
+    }
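
Taken together, a sketch of how a caller could stage an mlx-lm-loadable adapter
directory with the two new helpers. The paths, the `dlm.inference.mlx_adapter`
import path, and the surrounding glue are illustrative assumptions, not code
from this commit:

```python
import json
from pathlib import Path

# The two helpers come from this commit; everything else below is assumed glue.
from dlm.inference.mlx_adapter import (
    build_mlx_adapter_config,
    peft_safetensors_to_mlx_safetensors,
)

peft_dir = Path("runs/lora-adapter")      # holds adapter_model.safetensors + adapter_config.json
mlx_dir = Path("runs/lora-adapter-mlx")   # adapter dir the MLX backend hands to mlx-lm

# 1. Weights: remap PEFT keys and write adapters.safetensors (the filename mlx-lm reads).
mapping = peft_safetensors_to_mlx_safetensors(peft_dir, mlx_dir / "adapters.safetensors")
print(f"converted {len(mapping)} LoRA tensors")

# 2. Config: translate PEFT's schema into the num_layers / lora_parameters shape.
peft_config = json.loads((peft_dir / "adapter_config.json").read_text())
num_hidden_layers = 32  # in practice read from the base model's HF config.json
mlx_config = build_mlx_adapter_config(peft_config, num_hidden_layers)
(mlx_dir / "adapter_config.json").write_text(json.dumps(mlx_config, indent=2))
```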