1 """PEFT → MLX-LM LoRA adapter converter (Sprint 24, audit F01).
2
3 Closes the headline doc-vs-code gap from Audit 03: the README pitched
4 HF / MLX / HTTP backends as co-equal, but the MLX path needed a
5 pre-converted ``.npz`` adapter that nothing in the toolchain produced.
6 With this module, ``dlm train`` → ``sway run`` on the MLX backend
7 works end-to-end.
8
9 The converter is **pure I/O** — it reads PEFT's ``adapter_model.safetensors``
10 + ``adapter_config.json``, transposes the LoRA matrices to MLX's layout,
11 and writes ``adapters.safetensors`` + ``adapter_config.json`` in
12 mlx-lm's expected shape. No torch dependency, no model load — runs on
13 CPU in milliseconds even for 100M-param adapters.
14
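A minimal usage sketch (the import path and directories are illustrative;
point ``src`` at whatever directory the PEFT training run wrote):

    from pathlib import Path

    # import path is an assumption; import from wherever this module lives
    from dlm_sway.backends.mlx_convert import convert_peft_to_mlx

    report = convert_peft_to_mlx(Path("runs/peft"), Path("runs/mlx"), overwrite=True)
    print(report["rank"], report["num_keys"], report["dst_bytes"])
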
## Format reference (verified against PEFT >= 0.13 + mlx-lm 0.31)

**PEFT layout:**

- ``adapter_model.safetensors`` — weight names follow
  ``base_model.model.<dotted_path>.lora_<A|B>.weight`` (modern PEFT
  has dropped the ``.default`` adapter-name suffix that older
  versions shipped).
- Shapes: ``lora_A: [r, in_features]``, ``lora_B: [out_features, r]``.
- ``adapter_config.json`` — fields we care about: ``r``, ``lora_alpha``,
  ``lora_dropout``, ``target_modules``, ``modules_to_save``.

**MLX-LM layout (verified by reading
``mlx_lm.tuner.utils.load_adapters``):**

- ``adapters.safetensors`` — weight names mirror Python attribute paths
  on the model: ``<dotted_path>.lora_a`` / ``<dotted_path>.lora_b``.
  Shapes: ``lora_a: [in_features, r]``, ``lora_b: [r, out_features]``.
  (Both are transposed vs PEFT; see the worked mapping below.)
- ``adapter_config.json`` — fields ``fine_tune_type`` (= ``"lora"``),
  ``num_layers`` (count of layers from the end to wrap; a generous
  default is fine), and ``lora_parameters`` = ``{rank, scale, dropout,
  keys}`` where ``keys`` is the per-layer attribute subset (e.g.
  ``["self_attn.q_proj", "self_attn.v_proj"]``).

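A worked mapping for one q_proj factor pair (LLaMA-style attribute paths;
the 4096-wide / ``r = 16`` dimensions are illustrative):

    PEFT  base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight  [16, 4096]
          base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight  [4096, 16]
    MLX   model.layers.0.self_attn.q_proj.lora_a                          [4096, 16]
          model.layers.0.self_attn.q_proj.lora_b                          [16, 4096]
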
## Limitations (called out by the sprint risks)

- **QLoRA 4-bit adapters** are not supported. PEFT QLoRA stores
  weights as quantized values + scales; conversion would need
  dequantize-on-convert. Defer until a user asks.
- **modules_to_save** (e.g. ``embed_tokens``, ``lm_head``) is skipped
  and listed in the conversion report — MLX-LM's LoRA loader ignores
  extra full-tensor modules; the converter writes only the LoRA factors.
- **Invalid ranks** (missing or non-positive ``r``) raise
  ``MlxConvertError``. mlx-lm supports arbitrary rank in principle;
  the audit's prove-the-value targets ``r ∈ {8, 16, 32, 64}``, the
  canonical PEFT defaults.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from dlm_sway.core.errors import SwayError


class MlxConvertError(SwayError):
    """Raised when a PEFT adapter cannot be converted to MLX-LM format.

    Surfaces structural problems (unsupported PEFT shape, missing
    files, malformed config) the user can act on, vs. silent fallbacks
    that produce a broken MLX adapter.
    """


def convert_peft_to_mlx(src: Path, dst: Path, *, overwrite: bool = False) -> dict[str, Any]:
72 """Convert a PEFT LoRA adapter directory to MLX-LM format.
73
74 Parameters
75 ----------
76 src:
77 PEFT adapter directory containing ``adapter_model.safetensors``
78 and ``adapter_config.json``.
79 dst:
80 Output directory. Created if missing. Must be empty unless
81 ``overwrite=True``.
82 overwrite:
83 Whether to overwrite an existing ``adapters.safetensors`` /
84 ``adapter_config.json`` in ``dst``. Default refuses to clobber.
85
86 Returns
87 -------
88 dict
89 Summary report: ``{rank, scale, num_keys, target_modules,
90 modules_to_save_skipped}``. Useful for the CLI's before/after
91 size + shape print.
92
93 Raises
94 ------
95 MlxConvertError
96 On any structural mismatch (missing files, unsupported PEFT
97 config, dst not empty without overwrite).
98 """
    src = Path(src)
    dst = Path(dst)

    src_safetensors = src / "adapter_model.safetensors"
    src_config = src / "adapter_config.json"
    if not src_safetensors.exists():
        raise MlxConvertError(
            f"PEFT adapter missing {src_safetensors.name}: not a PEFT "
            f"adapter directory? (looked at {src})"
        )
    if not src_config.exists():
        raise MlxConvertError(
            f"PEFT adapter missing {src_config.name}: cannot determine "
            f"rank/alpha/target modules (looked at {src})"
        )

    config = json.loads(src_config.read_text(encoding="utf-8"))
    if config.get("peft_type", "").upper() != "LORA":
        raise MlxConvertError(
            f"unsupported PEFT type {config.get('peft_type')!r} — only "
            f"'LORA' is supported (QLoRA / DoRA / IA3 not yet handled)"
        )

    rank = int(config.get("r") or 0)
    if rank <= 0:
        raise MlxConvertError(f"invalid LoRA rank in config: {config.get('r')!r}")
    lora_alpha = float(config.get("lora_alpha", rank))
    lora_dropout = float(config.get("lora_dropout", 0.0))
    target_modules = config.get("target_modules") or []
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    modules_to_save = config.get("modules_to_save") or []

    dst.mkdir(parents=True, exist_ok=True)
    dst_safetensors = dst / "adapters.safetensors"
    dst_config = dst / "adapter_config.json"
    if (dst_safetensors.exists() or dst_config.exists()) and not overwrite:
        raise MlxConvertError(
            f"destination {dst} already contains an MLX adapter; pass overwrite=True"
        )

    # Lazy-import safetensors so a bare ``import dlm_sway`` doesn't pay
    # the safetensors import cost. The HF backend already pulls
    # safetensors via the [hf] extra; the converter only needs the
    # numpy I/O variant, which lives in the same package.
    try:
        from safetensors.numpy import load_file, save_file
    except ImportError as exc:
        raise MlxConvertError(
            "safetensors not installed — required for the MLX converter. "
            "Install with: pip install 'dlm-sway[hf]' or pip install safetensors"
        ) from exc

    weights = load_file(str(src_safetensors))
    converted: dict[str, Any] = {}
    per_layer_keys: set[str] = set()
    layer_indices: set[int] = set()
    skipped_full_modules: list[str] = []

    for key, tensor in weights.items():
        # PEFT puts everything under base_model.model.* (the wrapper
        # adds two levels: PeftModel.base_model and the wrapped HF
        # model). Strip cleanly; bail if a key doesn't match the
        # expected shape — we'd rather raise than silently emit an
        # incorrectly-named MLX tensor.
        if not key.startswith("base_model.model."):
            raise MlxConvertError(
                f"unexpected weight key {key!r}: missing 'base_model.model.' "
                f"prefix the PEFT wrapper produces"
            )
        path = key[len("base_model.model.") :]

        # PEFT key tail: <attr_path>.lora_A.weight or .lora_B.weight.
        # Some older PEFT versions added .default in between; tolerate
        # that for forward-compat with stored artifacts.
        for suffix, mlx_field in (
            (".lora_A.default.weight", "lora_a"),
            (".lora_B.default.weight", "lora_b"),
            (".lora_A.weight", "lora_a"),
            (".lora_B.weight", "lora_b"),
        ):
            if path.endswith(suffix):
                parent = path[: -len(suffix)]
                # Transpose: PEFT lora_A is (r, in) → MLX lora_a (in, r);
                # PEFT lora_B is (out, r) → MLX lora_b (r, out).
                converted[f"{parent}.{mlx_field}"] = tensor.T.copy()
                # Track the per-layer attribute path for MLX's `keys`
                # field. Strip the model.layers.<N>. prefix when
                # present so keys are layer-relative.
                rel_key = _strip_layer_prefix(parent)
                per_layer_keys.add(rel_key)
                layer_idx = _extract_layer_index(parent)
                if layer_idx is not None:
                    layer_indices.add(layer_idx)
                break
        else:
            # modules_to_save (full-weight overrides like embed_tokens,
            # lm_head) won't end in .lora_A/B — flag and skip. MLX's
            # LoRA loader ignores them; copying wouldn't help.
            if modules_to_save and any(m in path for m in modules_to_save):
                skipped_full_modules.append(key)
                continue
            raise MlxConvertError(
                f"unexpected PEFT weight {key!r}: doesn't match "
                f"lora_A/lora_B suffix and isn't in modules_to_save"
            )

    if not converted:
        raise MlxConvertError(f"no LoRA weights extracted from {src_safetensors} — empty adapter?")

    # num_layers — generous default that covers the actual base model.
    # mlx-lm slices model.layers[-num_layers:] so over-counting just
    # picks up the whole list.
    num_layers = max(layer_indices, default=0) + 1 if layer_indices else 32
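    # e.g. LoRA factors on layers 0..31 → num_layers = 32; an over-count is
    # harmless because the [-num_layers:] slice still covers every layer.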

    # MLX scale = lora_alpha / r (matches PEFT's effective scale).
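    # Worked example (illustrative values): r=16, lora_alpha=32 → scale=2.0,
    # the same alpha/r factor PEFT applied during training.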
    scale = lora_alpha / rank

    mlx_config: dict[str, Any] = {
        "fine_tune_type": "lora",
        "num_layers": num_layers,
        "lora_parameters": {
            "rank": rank,
            "scale": scale,
            "dropout": lora_dropout,
            "keys": sorted(per_layer_keys),
        },
    }
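    # The written adapter_config.json then looks roughly like this
    # (values are illustrative, from a hypothetical r=16 run):
    #   {
    #     "fine_tune_type": "lora",
    #     "num_layers": 32,
    #     "lora_parameters": {
    #       "rank": 16, "scale": 2.0, "dropout": 0.05,
    #       "keys": ["self_attn.q_proj", "self_attn.v_proj"]
    #     }
    #   }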

    save_file(converted, str(dst_safetensors))
    dst_config.write_text(json.dumps(mlx_config, indent=2) + "\n", encoding="utf-8")

    return {
        "rank": rank,
        "scale": scale,
        "num_keys": len(converted),
        "num_layers": num_layers,
        "target_modules": list(target_modules),
        "modules_to_save_skipped": skipped_full_modules,
        "src_bytes": src_safetensors.stat().st_size,
        "dst_bytes": dst_safetensors.stat().st_size,
    }


def _strip_layer_prefix(attr_path: str) -> str:
    """Return the per-layer-relative path for a full attribute path.

    ``model.layers.5.self_attn.q_proj`` → ``self_attn.q_proj``
    ``transformer.h.0.attn.c_attn`` → ``attn.c_attn``
    ``model.embed_tokens`` (no layer prefix) → unchanged.

    MLX's ``adapter_config.json`` ``keys`` field is checked against
    each layer's ``named_modules()`` paths — those are layer-relative,
    not model-rooted.
    """
    parts = attr_path.split(".")
    for i, p in enumerate(parts):
        if p.isdigit() and i + 1 < len(parts):
            return ".".join(parts[i + 1 :])
    return attr_path


def _extract_layer_index(attr_path: str) -> int | None:
    """Return the value of the first numeric segment in ``attr_path``, or
    ``None`` when no layer index is present (e.g. embedding overrides).
    """
    for p in attr_path.split("."):
        if p.isdigit():
            return int(p)
    return None