feat(inference): PEFT → MLX .npz adapter converter (key mapping + npz writer)
- SHA: 06a0cf2d8e245f01300dc493800e34f9a9954b98
- Parent: f1b691a
- Tree: 1ee4da9

| Status | File | + | - |
|---|---|---|---|
| M | src/dlm/inference/backends/pytorch_backend.py | 1 | 3 |
| A | src/dlm/inference/mlx_adapter.py | 127 | 0 |
| M | tests/unit/inference/test_backends_base.py | 1 | 3 |
| A | tests/unit/inference/test_mlx_adapter_conversion.py | 79 | 0 |
`src/dlm/inference/backends/pytorch_backend.py` (modified):

```diff
@@ -36,9 +36,7 @@ class PyTorchBackend(InferenceBackend):
     ) -> None:
         from dlm.inference.loader import load_for_inference
 
-        self._loaded = load_for_inference(
-            store, base, self._caps, adapter_name=adapter_name
-        )
+        self._loaded = load_for_inference(store, base, self._caps, adapter_name=adapter_name)
 
     def generate(self, prompt: str, **gen_kwargs: Any) -> str:
         if self._loaded is None:
```
`src/dlm/inference/mlx_adapter.py` (added, 127 lines). The module docstring states the key-layout contract and the design split:

```python
"""PEFT safetensors → MLX-LM `.npz` LoRA-adapter converter.

Sprint 21 ships MLX as a second inference backend on Apple Silicon.
PEFT writes LoRA weights as `adapter_model.safetensors` with keys like:

    base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight

`mlx-lm`'s `load_adapters` expects the flattened, lowercased layout:

    model.layers.0.self_attn.q_proj.lora_a
    model.layers.0.self_attn.q_proj.lora_b

(no `base_model` prefix, lowercase `lora_a`/`lora_b`, `.weight` stripped
because mlx-lm adapter files store bare tensors under the final segment).

The converter is split so the key-mapping logic is pure and
unit-testable; the tensor I/O layer is a thin wrapper around
`safetensors.torch.load_file` + `numpy.savez`. We write `.npz` via
numpy (not `mx.savez`) so conversion itself does not require MLX to
be importable — the artifact is still MLX-loadable because `mx.load`
understands numpy's `.npz` format.
"""
```
```python
from __future__ import annotations

import re
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from numpy.typing import NDArray


_PEFT_PREFIX = re.compile(r"^base_model\.model\.")
"""PEFT wraps the HF base in a double `base_model.model.` — the outer
`base_model` is the PEFT wrapper, the inner `model` is the HF model
attribute. We strip only the outer `base_model.model.` once so the inner
`model.` (HF's actual attribute name) remains."""

_LORA_AB = re.compile(r"\.lora_([AB])\.weight$")
"""Matches the trailing `.lora_A.weight` / `.lora_B.weight` suffix."""


class MlxConversionError(RuntimeError):
    """Raised when a PEFT adapter cannot be converted to the MLX layout."""
```
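To make the wrapper-stripping concrete, here is what the two patterns do to the docstring's example key (plain `re`, nothing beyond the module's own regexes):

```python
import re

_PEFT_PREFIX = re.compile(r"^base_model\.model\.")
_LORA_AB = re.compile(r"\.lora_([AB])\.weight$")

key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

# Anchored at ^, so a single sub() removes only the outer wrapper;
# the inner `model.` (the HF attribute name) survives.
print(_PEFT_PREFIX.sub("", key, count=1))
# model.layers.0.self_attn.q_proj.lora_A.weight

# The suffix pattern captures the A/B letter so a replacement can
# lowercase it and drop `.weight` in one pass.
match = _LORA_AB.search(key)
assert match is not None and match.group(1) == "A"
```

With the patterns pinned down, the per-key mapper is a skip check plus two substitutions.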
```python
def map_peft_key_to_mlx(peft_key: str) -> str | None:
    """Return the MLX-LM key for a PEFT tensor, or None if the key should be skipped.

    Rules:
    - Strip the leading `base_model.model.` wrapper prefix. PEFT's
      outer wrapper is redundant under MLX-LM's flattened naming.
    - Rewrite `...lora_A.weight` → `...lora_a` and `...lora_B.weight`
      → `...lora_b`. Case change + suffix drop in one pass.
    - Return None for any tensor that isn't a LoRA A/B pair (e.g.
      `modules_to_save` copies of `embed_tokens` — MLX-LM handles those
      via the base model, not the adapter file).

    Pure string transformation; no tensor shape changes happen here.
    """
    if not _LORA_AB.search(peft_key):
        return None
    stripped = _PEFT_PREFIX.sub("", peft_key, count=1)
    return _LORA_AB.sub(lambda m: f".lora_{m.group(1).lower()}", stripped)
```
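Usage at a glance (the unit test file added in this commit pins the same behavior):

```python
from dlm.inference.mlx_adapter import map_peft_key_to_mlx

# LoRA tensors are rewritten to the flattened MLX-LM layout...
assert map_peft_key_to_mlx(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
) == "model.layers.0.self_attn.q_proj.lora_a"

# ...and anything that is not a LoRA A/B tensor maps to None.
assert map_peft_key_to_mlx("something.else.bias") is None
```

The whole-file mapper below wraps this with duplicate and emptiness checks.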
```python
def map_all_keys(peft_keys: list[str]) -> dict[str, str]:
    """Build the peft_key → mlx_key mapping for a whole adapter file.

    Non-LoRA keys are silently dropped (see `map_peft_key_to_mlx`).
    Duplicate output keys trigger `MlxConversionError`: that would
    mean two PEFT tensors collapsed to the same MLX name, and
    silently overwriting one would mask a real adapter-layout bug.
    """
    mapping: dict[str, str] = {}
    seen: dict[str, str] = {}
    for key in peft_keys:
        mapped = map_peft_key_to_mlx(key)
        if mapped is None:
            continue
        if mapped in seen:
            raise MlxConversionError(
                f"two PEFT keys map to the same MLX key {mapped!r}: {seen[mapped]!r} and {key!r}"
            )
        seen[mapped] = key
        mapping[key] = mapped
    if not mapping:
        raise MlxConversionError(
            "PEFT adapter has no LoRA A/B weight tensors — not a convertible LoRA checkpoint"
        )
    return mapping
```
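Driving the whole-file mapper, including the hard-error path for an adapter with nothing convertible (the key list is illustrative):

```python
from dlm.inference.mlx_adapter import MlxConversionError, map_all_keys

keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
    "base_model.model.model.embed_tokens.modules_to_save.default.weight",
]
mapping = map_all_keys(keys)  # the modules_to_save copy is dropped
assert len(mapping) == 2

try:
    map_all_keys(["just.a.bias"])  # no LoRA tensors at all -> hard error
except MlxConversionError as err:
    print(err)
```

The I/O layer ties this mapping to the actual tensors.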
```python
def peft_safetensors_to_mlx_npz(  # pragma: no cover - I/O + torch deps
    peft_adapter_dir: Path,
    mlx_npz_path: Path,
) -> dict[str, str]:
    """Convert `<adapter>/adapter_model.safetensors` → `<mlx_npz_path>`.

    Returns the key mapping actually written (peft_key → mlx_key) so
    callers can log it when `--verbose`.

    Pragma'd: the I/O path is exercised end-to-end by the slow parity
    integration test; the mapping logic is covered by the `map_all_keys`
    unit tests.
    """
    import numpy as np
    from safetensors.torch import load_file

    src = peft_adapter_dir / "adapter_model.safetensors"
    if not src.exists():
        raise MlxConversionError(f"no adapter_model.safetensors in {peft_adapter_dir}")

    tensors = load_file(str(src))
    mapping = map_all_keys(list(tensors.keys()))

    np_tensors: dict[str, NDArray[np.float32]] = {}
    for peft_key, mlx_key in mapping.items():
        tensor = tensors[peft_key]
        # safetensors.torch.load_file returns torch.Tensor; .numpy()
        # is the standard bridge. fp16 is preserved across the write.
        np_tensors[mlx_key] = tensor.detach().cpu().numpy()

    mlx_npz_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(str(mlx_npz_path), **np_tensors)  # type: ignore[arg-type]
    return mapping
```
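How a caller might drive the converter, e.g. from a verbose CLI step (paths are invented for the sketch; the returned mapping is exactly what a `--verbose` flag would log):

```python
from pathlib import Path

from dlm.inference.mlx_adapter import peft_safetensors_to_mlx_npz

# Hypothetical locations; any dir containing adapter_model.safetensors works.
adapter_dir = Path("artifacts/adapters/knowledge")
npz_path = Path("artifacts/adapters/knowledge/adapters.npz")

mapping = peft_safetensors_to_mlx_npz(adapter_dir, npz_path)
for peft_key, mlx_key in sorted(mapping.items()):
    print(f"{peft_key} -> {mlx_key}")
```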
`tests/unit/inference/test_backends_base.py` (modified):

```diff
@@ -40,9 +40,7 @@ class TestPyTorchBackendLoadDelegation:
         backend = PyTorchBackend(caps)
         backend.load(base, store, adapter_name="knowledge")
 
-        m_load.assert_called_once_with(
-            store, base, caps, adapter_name="knowledge"
-        )
+        m_load.assert_called_once_with(store, base, caps, adapter_name="knowledge")
 
     def test_generate_after_load_delegates(self) -> None:
         with (
```
`tests/unit/inference/test_mlx_adapter_conversion.py` (added, 79 lines):

```python
"""Key mapping for PEFT → MLX LoRA adapter conversion."""

from __future__ import annotations

import pytest

from dlm.inference.mlx_adapter import (
    MlxConversionError,
    map_all_keys,
    map_peft_key_to_mlx,
)


class TestMapPeftKey:
    def test_lora_a_lowercases_and_strips_weight(self) -> None:
        got = map_peft_key_to_mlx("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight")
        assert got == "model.layers.0.self_attn.q_proj.lora_a"

    def test_lora_b_lowercases_and_strips_weight(self) -> None:
        got = map_peft_key_to_mlx("base_model.model.model.layers.5.mlp.down_proj.lora_B.weight")
        assert got == "model.layers.5.mlp.down_proj.lora_b"

    def test_base_model_prefix_stripped_once_only(self) -> None:
        # Inner `model.` (HF attribute) survives; only the outer
        # `base_model.model.` wrapper is dropped.
        got = map_peft_key_to_mlx("base_model.model.model.embed_tokens.lora_A.weight")
        assert got == "model.embed_tokens.lora_a"

    def test_non_lora_key_returns_none(self) -> None:
        # modules_to_save duplicates, bias tensors, etc.
        assert (
            map_peft_key_to_mlx(
                "base_model.model.model.embed_tokens.modules_to_save.default.weight"
            )
            is None
        )
        assert map_peft_key_to_mlx("something.else.bias") is None

    def test_bare_bias_key_returns_none(self) -> None:
        assert map_peft_key_to_mlx("base_model.model.model.layers.0.self_attn.q_proj.bias") is None


class TestMapAllKeys:
    def test_pair_mapping(self) -> None:
        keys = [
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
            "base_model.model.model.layers.1.mlp.up_proj.lora_A.weight",
            "base_model.model.model.layers.1.mlp.up_proj.lora_B.weight",
        ]
        mapping = map_all_keys(keys)
        assert len(mapping) == 4
        assert mapping[keys[0]] == "model.layers.0.self_attn.q_proj.lora_a"
        assert mapping[keys[1]] == "model.layers.0.self_attn.q_proj.lora_b"

    def test_non_lora_keys_skipped_silently(self) -> None:
        keys = [
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
            "base_model.model.model.embed_tokens.modules_to_save.default.weight",
        ]
        mapping = map_all_keys(keys)
        assert len(mapping) == 2

    def test_empty_adapter_raises(self) -> None:
        with pytest.raises(MlxConversionError, match="no LoRA A/B"):
            map_all_keys(["just.a.bias"])

    def test_duplicate_mlx_key_raises(self) -> None:
        # Two keys that both resolve to `q_proj.lora_a` after the
        # outer `base_model.model.` strip — one already unwrapped,
        # one wrapped. Defensive branch that matters if PEFT ever
        # emits the same tensor under both wrapped + unwrapped names.
        collision = [
            "q_proj.lora_A.weight",
            "base_model.model.q_proj.lora_A.weight",
        ]
        with pytest.raises(MlxConversionError, match="map to the same"):
            map_all_keys(collision)
```