@@ -1,7 +1,7 @@ |
| 1 | | -"""PEFT safetensors → MLX-LM `.npz` LoRA-adapter converter. |
| 1 | +"""PEFT safetensors → MLX-LM LoRA-adapter converter. |
| 2 | 2 | |
| 3 | | -Sprint 21 ships MLX as a second inference backend on Apple Silicon. |
| 4 | | -PEFT writes LoRA weights as `adapter_model.safetensors` with keys like: |
| 3 | +Ships MLX as a second inference backend on Apple Silicon. PEFT writes |
| 4 | +LoRA weights as `adapter_model.safetensors` with keys like: |
| 5 | 5 | |
| 6 | 6 | base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight |
| 7 | 7 | base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight |
@@ -14,23 +14,26 @@ PEFT writes LoRA weights as `adapter_model.safetensors` with keys like: |
| 14 | 14 | (no `base_model` prefix, lowercase `lora_a`/`lora_b`, `.weight` stripped |
| 15 | 15 | because mlx-lm adapter files store bare tensors under the final segment). |
| 16 | 16 | |
| 17 | | -The converter is split so the key-mapping logic is pure and |
| 18 | | -unit-testable; the tensor I/O layer is a thin wrapper around |
| 19 | | -`safetensors.torch.load_file` + `numpy.savez`. We write `.npz` via |
| 20 | | -numpy (not `mx.savez`) so conversion itself does not require MLX to |
| 21 | | -be importable — the artifact is still MLX-loadable because `mx.load` |
| 22 | | -understands numpy's `.npz` format. |
| 17 | +**Output format: `adapters.safetensors`.** Current mlx-lm |
| 18 | +(`tuner/utils.py:137`) hardcodes `model.load_weights(adapters.safetensors)` |
| 19 +— earlier mlx-lm releases accepted `.npz`, current ones do not.
| 20 | +We write safetensors directly. Conversion itself doesn't require MLX |
| 21 +to be importable — the file is written via `safetensors.torch`,
| 22 +which has no MLX dependency.
| 23 | + |
| 24 | +**mlx-lm `adapter_config.json` schema.** mlx-lm's loader builds a |
| 25 | +`types.SimpleNamespace` from the file and reads `config.num_layers` + |
| 26 | +`config.lora_parameters`. PEFT's config has neither. `build_mlx_adapter_config` |
| 27 | +translates PEFT-shape into mlx-lm-shape using the HF base config's |
| 28 | +`num_hidden_layers` + PEFT's `r` / `lora_alpha` / `lora_dropout` / |
| 29 | +`target_modules`. |
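| +
| +**Typical staging flow** (a sketch; the paths, the `json` import, and
| +the `num_hidden_layers` lookup are caller-side details, not part of
| +this module):
| +
| +    mapping = peft_safetensors_to_mlx_safetensors(
| +        peft_dir, out_dir / "adapters.safetensors")  # for --verbose logging
| +    peft_cfg = json.loads((peft_dir / "adapter_config.json").read_text())
| +    base_cfg = json.loads((base_model_dir / "config.json").read_text())
| +    mlx_cfg = build_mlx_adapter_config(
| +        peft_cfg, base_cfg["num_hidden_layers"])
| +    (out_dir / "adapter_config.json").write_text(json.dumps(mlx_cfg))
| +
| +mlx-lm can then be pointed at `out_dir` (e.g. via
| +`mlx_lm.load(..., adapter_path=str(out_dir))`).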
| 23 | 30 | """ |
| 24 | 31 | |
| 25 | 32 | from __future__ import annotations |
| 26 | 33 | |
| 27 | 34 | import re |
| 28 | 35 | from pathlib import Path |
| 29 | | -from typing import TYPE_CHECKING |
| 30 | | - |
| 31 | | -if TYPE_CHECKING: |
| 32 | | - from numpy.typing import NDArray |
| 33 | | - |
| 36 | +from typing import Any |
| 34 | 37 | |
| 35 | 38 | _PEFT_PREFIX = re.compile(r"^base_model\.model\.") |
| 36 | 39 | """PEFT wraps the HF base in a double `base_model.model.` — the outer |
@@ -93,11 +96,15 @@ def map_all_keys(peft_keys: list[str]) -> dict[str, str]: |
| 93 | 96 | return mapping |
| 94 | 97 | |
| 95 | 98 | |
| 96 | | -def peft_safetensors_to_mlx_npz( # pragma: no cover - I/O + torch deps |
| 99 | +def peft_safetensors_to_mlx_safetensors( # pragma: no cover - I/O + torch deps |
| 97 | 100 | peft_adapter_dir: Path, |
| 98 | | - mlx_npz_path: Path, |
| 101 | + mlx_safetensors_path: Path, |
| 99 | 102 | ) -> dict[str, str]: |
| 100 | | - """Convert `<adapter>/adapter_model.safetensors` → `<mlx_npz_path>`. |
| 103 | + """Convert `<adapter>/adapter_model.safetensors` → `<mlx_safetensors_path>`. |
| 104 | + |
| 105 | + mlx-lm's current loader reads `adapters.safetensors` (not `.npz`); |
| 106 | + we write safetensors with MLX-shaped keys so `model.load_weights` |
| 107 | + accepts the file without further translation. |
| 101 | 108 | |
| 102 | 109 | Returns the key mapping actually written (peft_key → mlx_key) so |
| 103 | 110 | callers can log it when `--verbose`. |
@@ -105,8 +112,7 @@ def peft_safetensors_to_mlx_npz( # pragma: no cover - I/O + torch deps |
| 105 | 112 | Pragma'd: exercised end-to-end by the slow parity integration test |
| 106 | 113 | (covered via `map_all_keys` unit tests for the logic). |
| 107 | 114 | """ |
| 108 | | - import numpy as np |
| 109 | | - from safetensors.torch import load_file |
| 115 | + from safetensors.torch import load_file, save_file |
| 110 | 116 | |
| 111 | 117 | src = peft_adapter_dir / "adapter_model.safetensors" |
| 112 | 118 | if not src.exists(): |
@@ -115,13 +121,66 @@ def peft_safetensors_to_mlx_npz( # pragma: no cover - I/O + torch deps |
| 115 | 121 | tensors = load_file(str(src)) |
| 116 | 122 | mapping = map_all_keys(list(tensors.keys())) |
| 117 | 123 | |
| 118 | | - np_tensors: dict[str, NDArray[np.float32]] = {} |
| 119 | | - for peft_key, mlx_key in mapping.items(): |
| 120 | | - tensor = tensors[peft_key] |
| 121 | | - # safetensors.torch.load_file returns torch.Tensor; .numpy() |
| 122 | | - # is the standard bridge. fp16 is preserved across the write. |
| 123 | | - np_tensors[mlx_key] = tensor.detach().cpu().numpy() |
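| +    # Only the keys change; tensor dtypes (e.g. fp16) pass through untouched.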
| 124 | + mlx_tensors = {mlx_key: tensors[peft_key] for peft_key, mlx_key in mapping.items()} |
| 124 | 125 | |
| 125 | | - mlx_npz_path.parent.mkdir(parents=True, exist_ok=True) |
| 126 | | - np.savez(str(mlx_npz_path), **np_tensors) # type: ignore[arg-type] |
| 126 | + mlx_safetensors_path.parent.mkdir(parents=True, exist_ok=True) |
| 127 | + save_file(mlx_tensors, str(mlx_safetensors_path)) |
| 127 | 128 | return mapping |
| 129 | + |
| 130 | + |
| 131 | +def build_mlx_adapter_config( |
| 132 | + peft_config: dict[str, Any], |
| 133 | + base_num_hidden_layers: int, |
| 134 | +) -> dict[str, Any]: |
| 135 | + """Translate a PEFT `adapter_config.json` into mlx-lm's schema. |
| 136 | + |
| 137 | + mlx-lm's `load_adapters` (tuner/utils.py) reads: |
| 138 | + |
| 139 | + - `config.num_layers` — how many trailing layers receive LoRA |
| 140 | + (matches mlx-lm convention: -1 = all; a positive N = last N). |
| 141 | + We emit the base model's `num_hidden_layers` so every layer |
| 142 +      gets the adapter, matching PEFT's default when `layers_to_transform`
| 143 +      isn't set.
| 144 | + - `config.lora_parameters.rank` ← PEFT `r` |
| 145 | + - `config.lora_parameters.scale` ← PEFT `lora_alpha / r` |
| 146 | + - `config.lora_parameters.dropout` ← PEFT `lora_dropout` |
| 147 | + - `config.lora_parameters.keys` ← PEFT `target_modules` |
| 148 | + - `config.fine_tune_type` — "lora" unless PEFT `use_dora=True`, |
| 149 | + in which case "dora". |
| 150 | + |
| 151 | + Fails loud (`MlxConversionError`) when the PEFT config is missing |
| 152 | + fields we cannot substitute — `r` and `target_modules` in particular |
| 153 | + are load-bearing. |
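| +
| +    Worked example (illustrative values, not from any real adapter):
| +    a PEFT config of `{"r": 16, "lora_alpha": 32, "target_modules":
| +    ["q_proj", "v_proj"]}` on a 32-layer base yields `num_layers=32`,
| +    `fine_tune_type="lora"`, and `lora_parameters={"rank": 16,
| +    "scale": 2.0, "dropout": 0.0, "keys": ["q_proj", "v_proj"]}`.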
| 154 | + """ |
| 155 | + try: |
| 156 | + rank = int(peft_config["r"]) |
| 157 | + except (KeyError, TypeError, ValueError) as exc: |
| 158 | + raise MlxConversionError( |
| 159 | + f"PEFT adapter_config.json missing or non-integer 'r' (LoRA rank): {exc}" |
| 160 | + ) from exc |
| 161 | + target_modules = peft_config.get("target_modules") |
| 162 | + if not isinstance(target_modules, list) or not target_modules: |
| 163 | + raise MlxConversionError( |
| 164 | + "PEFT adapter_config.json 'target_modules' must be a non-empty list; got " |
| 165 | + f"{target_modules!r}. mlx-lm needs this to wire LoRA into the right ops." |
| 166 | + ) |
| 167 | + lora_alpha = float(peft_config.get("lora_alpha", rank)) |
| 168 | + lora_dropout = float(peft_config.get("lora_dropout", 0.0)) |
| 169 | + use_dora = bool(peft_config.get("use_dora", False)) |
| 170 | + |
| 171 | + if base_num_hidden_layers < 1: |
| 172 | + raise MlxConversionError( |
| 173 | + f"base model reports num_hidden_layers={base_num_hidden_layers} (expected >=1); " |
| 174 | + "cannot stage mlx adapter without a valid layer count" |
| 175 | + ) |
| 176 | + |
| 177 | + return { |
| 178 | + "fine_tune_type": "dora" if use_dora else "lora", |
| 179 | + "num_layers": int(base_num_hidden_layers), |
| 180 | + "lora_parameters": { |
| 181 | + "rank": rank, |
| 182 | + "scale": lora_alpha / rank if rank else float(lora_alpha), |
| 183 | + "dropout": lora_dropout, |
| 184 | + "keys": list(target_modules), |
| 185 | + }, |
| 186 | + } |