"""PEFT → MLX-LM LoRA adapter converter (Sprint 24, audit F01). Closes the headline doc-vs-code gap from Audit 03: the README pitched HF / MLX / HTTP backends as co-equal, but the MLX path needed a pre-converted ``.npz`` adapter that nothing in the toolchain produced. With this module, ``dlm train`` → ``sway run`` on the MLX backend works end-to-end. The converter is **pure I/O** — it reads PEFT's ``adapter_model.safetensors`` + ``adapter_config.json``, transposes the LoRA matrices to MLX's layout, and writes ``adapters.safetensors`` + ``adapter_config.json`` in mlx-lm's expected shape. No torch dependency, no model load — runs on CPU in milliseconds even for 100M-param adapters. ## Format reference (verified against PEFT >= 0.13 + mlx-lm 0.31) **PEFT layout:** - ``adapter_model.safetensors`` — weight names follow ``base_model.model..lora_.weight`` (modern PEFT has dropped the ``.default`` adapter-name suffix that older versions shipped). - Shapes: ``lora_A: [r, in_features]``, ``lora_B: [out_features, r]`` - ``adapter_config.json`` — fields we care about: ``r``, ``lora_alpha``, ``lora_dropout``, ``target_modules``, ``modules_to_save``. **MLX-LM layout (verified by reading ``mlx_lm.tuner.utils.load_adapters``):** - ``adapters.safetensors`` — weight names mirror Python attribute paths on the model: ``.lora_a`` / ``.lora_b``. Shapes: ``lora_a: [in_features, r]``, ``lora_b: [r, out_features]``. (Both layouts transposed vs PEFT.) - ``adapter_config.json`` — fields ``fine_tune_type`` (= ``"lora"``), ``num_layers`` (count of layers from end to wrap, generous default is fine), and ``lora_parameters`` ``{rank, scale, dropout, keys}`` where ``keys`` is the per-layer attribute subset (e.g. ``["self_attn.q_proj", "self_attn.v_proj"]``). ## Limitations (called out by the sprint risks) - **QLoRA 4-bit adapters** are not supported. PEFT QLoRA stores weights as quantized + scales; conversion would need dequantize- on-convert. Defer until a user asks. - **modules_to_save** (e.g. ``embed_tokens``, ``lm_head``) emits a clear warning — MLX-LM's LoRA loader ignores extra full-tensor modules; the converter writes only the LoRA factors. - **Unsupported ranks** raise ``MlxConvertError``. mlx-lm supports arbitrary rank in principle; the audit's prove-the-value targets ``r ∈ {8, 16, 32, 64}``, the canonical PEFT defaults. """ from __future__ import annotations import json from pathlib import Path from typing import Any from dlm_sway.core.errors import SwayError class MlxConvertError(SwayError): """Raised when a PEFT adapter cannot be converted to MLX-LM format. Surfaces structural problems (unsupported PEFT shape, missing files, malformed config) the user can act on, vs. silent fallbacks that produce a broken MLX adapter. """ def convert_peft_to_mlx(src: Path, dst: Path, *, overwrite: bool = False) -> dict[str, Any]: """Convert a PEFT LoRA adapter directory to MLX-LM format. Parameters ---------- src: PEFT adapter directory containing ``adapter_model.safetensors`` and ``adapter_config.json``. dst: Output directory. Created if missing. Must be empty unless ``overwrite=True``. overwrite: Whether to overwrite an existing ``adapters.safetensors`` / ``adapter_config.json`` in ``dst``. Default refuses to clobber. Returns ------- dict Summary report: ``{rank, scale, num_keys, target_modules, modules_to_save_skipped}``. Useful for the CLI's before/after size + shape print. Raises ------ MlxConvertError On any structural mismatch (missing files, unsupported PEFT config, dst not empty without overwrite). """ src = Path(src) dst = Path(dst) src_safetensors = src / "adapter_model.safetensors" src_config = src / "adapter_config.json" if not src_safetensors.exists(): raise MlxConvertError( f"PEFT adapter missing {src_safetensors.name}: not a PEFT " f"adapter directory? (looked at {src})" ) if not src_config.exists(): raise MlxConvertError( f"PEFT adapter missing {src_config.name}: cannot determine " f"rank/alpha/target modules (looked at {src})" ) config = json.loads(src_config.read_text(encoding="utf-8")) if config.get("peft_type", "").upper() != "LORA": raise MlxConvertError( f"unsupported PEFT type {config.get('peft_type')!r} — only " f"'LORA' is supported (QLoRA / DoRA / IA3 not yet handled)" ) rank = int(config.get("r") or 0) if rank <= 0: raise MlxConvertError(f"invalid LoRA rank in config: {config.get('r')!r}") lora_alpha = float(config.get("lora_alpha", rank)) lora_dropout = float(config.get("lora_dropout", 0.0)) target_modules = config.get("target_modules") or [] if isinstance(target_modules, str): target_modules = [target_modules] modules_to_save = config.get("modules_to_save") or [] dst.mkdir(parents=True, exist_ok=True) dst_safetensors = dst / "adapters.safetensors" dst_config = dst / "adapter_config.json" if (dst_safetensors.exists() or dst_config.exists()) and not overwrite: raise MlxConvertError( f"destination {dst} already contains an MLX adapter; pass overwrite=True" ) # Lazy-import safetensors so the bare ``import dlm_sway`` cost # doesn't pay the safetensors load. The HF backend already pulls # safetensors via the [hf] extra; the converter only needs the # numpy I/O variant which is in the same package. try: from safetensors.numpy import load_file, save_file except ImportError as exc: raise MlxConvertError( "safetensors not installed — required for the MLX converter. " "Install with: pip install 'dlm-sway[hf]' or pip install safetensors" ) from exc weights = load_file(str(src_safetensors)) converted: dict[str, Any] = {} per_layer_keys: set[str] = set() layer_indices: set[int] = set() skipped_full_modules: list[str] = [] for key, tensor in weights.items(): # PEFT puts everything under base_model.model.* (the wrapper # adds two levels: PeftModel.base_model and the wrapped HF # model). Strip cleanly; bail if a key doesn't match the # expected shape — we'd rather raise than silently emit an # incorrectly-named MLX tensor. if not key.startswith("base_model.model."): raise MlxConvertError( f"unexpected weight key {key!r}: missing 'base_model.model.' " f"prefix the PEFT wrapper produces" ) path = key[len("base_model.model.") :] # PEFT key tail: .lora_A.weight or .lora_B.weight. # Some older PEFT versions added .default in between; tolerate # that for forward-compat with stored artifacts. for suffix, mlx_field in ( (".lora_A.default.weight", "lora_a"), (".lora_B.default.weight", "lora_b"), (".lora_A.weight", "lora_a"), (".lora_B.weight", "lora_b"), ): if path.endswith(suffix): parent = path[: -len(suffix)] # Transpose: PEFT lora_A is (r, in) → MLX lora_a (in, r); # PEFT lora_B is (out, r) → MLX lora_b (r, out). converted[f"{parent}.{mlx_field}"] = tensor.T.copy() # Track the per-layer attribute path for MLX's `keys` # field. Strip the model.layers.. prefix when # present so keys are layer-relative. rel_key = _strip_layer_prefix(parent) per_layer_keys.add(rel_key) layer_idx = _extract_layer_index(parent) if layer_idx is not None: layer_indices.add(layer_idx) break else: # Modules_to_save (full-weight overrides like embed_tokens, # lm_head) won't end in .lora_A/B — flag and skip. MLX's # LoRA loader ignores them; copying wouldn't help. if any(m in path for m in modules_to_save) if modules_to_save else False: skipped_full_modules.append(key) continue raise MlxConvertError( f"unexpected PEFT weight {key!r}: doesn't match " f"lora_A/lora_B suffix and isn't in modules_to_save" ) if not converted: raise MlxConvertError(f"no LoRA weights extracted from {src_safetensors} — empty adapter?") # num_layers — generous default that covers the actual base model. # mlx-lm slices model.layers[-num_layers:] so over-counting just # picks up the whole list. num_layers = max(layer_indices, default=0) + 1 if layer_indices else 32 # MLX scale = lora_alpha / r (matches PEFT's effective scale). scale = lora_alpha / rank mlx_config: dict[str, Any] = { "fine_tune_type": "lora", "num_layers": num_layers, "lora_parameters": { "rank": rank, "scale": scale, "dropout": lora_dropout, "keys": sorted(per_layer_keys), }, } save_file(converted, str(dst_safetensors)) dst_config.write_text(json.dumps(mlx_config, indent=2) + "\n", encoding="utf-8") return { "rank": rank, "scale": scale, "num_keys": len(converted), "num_layers": num_layers, "target_modules": list(target_modules), "modules_to_save_skipped": skipped_full_modules, "src_bytes": src_safetensors.stat().st_size, "dst_bytes": dst_safetensors.stat().st_size, } def _strip_layer_prefix(attr_path: str) -> str: """Return the per-layer-relative path for a full attribute path. ``model.layers.5.self_attn.q_proj`` → ``self_attn.q_proj`` ``transformer.h.0.attn.c_attn`` → ``attn.c_attn`` ``model.embed_tokens`` (no layer prefix) → unchanged. MLX's ``adapter_config.json`` ``keys`` field is checked against each layer's ``named_modules()`` paths — those are layer-relative, not model-rooted. """ parts = attr_path.split(".") for i, p in enumerate(parts): if p.isdigit() and i + 1 < len(parts): return ".".join(parts[i + 1 :]) return attr_path def _extract_layer_index(attr_path: str) -> int | None: """Return the first numeric-segment index in ``attr_path``, or None when no layer index is present (e.g. embedding overrides). """ for p in attr_path.split("."): if p.isdigit(): return int(p) return None