tenseleyflow/documentlanguagemodel / 06a0cf2


feat(inference): PEFT → MLX .npz adapter converter (key mapping + npz writer)

Authored by espadonne
SHA 06a0cf2d8e245f01300dc493800e34f9a9954b98
Parents f1b691a
Tree 1ee4da9

4 changed files

Status  File  +  -
M src/dlm/inference/backends/pytorch_backend.py 1 3
A src/dlm/inference/mlx_adapter.py 127 0
M tests/unit/inference/test_backends_base.py 1 3
A tests/unit/inference/test_mlx_adapter_conversion.py 79 0
src/dlm/inference/backends/pytorch_backend.py (modified)
@@ -36,9 +36,7 @@ class PyTorchBackend(InferenceBackend):
     ) -> None:
         from dlm.inference.loader import load_for_inference

-        self._loaded = load_for_inference(
-            store, base, self._caps, adapter_name=adapter_name
-        )
+        self._loaded = load_for_inference(store, base, self._caps, adapter_name=adapter_name)

     def generate(self, prompt: str, **gen_kwargs: Any) -> str:
         if self._loaded is None:
src/dlm/inference/mlx_adapter.py (added)
@@ -0,0 +1,127 @@
+"""PEFT safetensors → MLX-LM `.npz` LoRA-adapter converter.
+
+Sprint 21 ships MLX as a second inference backend on Apple Silicon.
+PEFT writes LoRA weights as `adapter_model.safetensors` with keys like:
+
+    base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
+    base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight
+
+`mlx-lm`'s `load_adapters` expects the flattened, lowercased layout:
+
+    model.layers.0.self_attn.q_proj.lora_a
+    model.layers.0.self_attn.q_proj.lora_b
+
+(no `base_model` prefix, lowercase `lora_a`/`lora_b`, `.weight` stripped
+because mlx-lm adapter files store bare tensors under the final segment).
+
+The converter is split so the key-mapping logic is pure and
+unit-testable; the tensor I/O layer is a thin wrapper around
+`safetensors.torch.load_file` + `numpy.savez`. We write `.npz` via
+numpy (not `mx.savez`) so conversion itself does not require MLX to
+be importable — the artifact is still MLX-loadable because `mx.load`
+understands numpy's `.npz` format.
+"""
+
+from __future__ import annotations
+
+import re
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+_PEFT_PREFIX = re.compile(r"^base_model\.model\.")
+"""PEFT wraps the HF base in a double `base_model.model.` — the outer
+`base_model` is the PEFT wrapper, the inner `model` is the HF model
+attribute. We strip only the outer `base_model.model.` once so inner
+`model.` (HF's actual attribute name) remains."""
+
+_LORA_AB = re.compile(r"\.lora_([AB])\.weight$")
+"""Matches the trailing `.lora_A.weight` / `.lora_B.weight` suffix."""
+
+
+class MlxConversionError(RuntimeError):
+    """Raised when a PEFT adapter cannot be converted to the MLX layout."""
+
+
+def map_peft_key_to_mlx(peft_key: str) -> str | None:
+    """Return the MLX-LM key for a PEFT tensor, or None if the key should be skipped.
+
+    Rules:
+    - Strip the leading `base_model.model.` wrapper prefix. PEFT's
+      outer wrapper is redundant under MLX-LM's flattened naming.
+    - Rewrite `...lora_A.weight` → `...lora_a` and `...lora_B.weight`
+      → `...lora_b`. Case change + suffix drop in one pass.
+    - Return None for any tensor that isn't a LoRA A/B pair (e.g.
+      `modules_to_save` copies of `embed_tokens` — MLX-LM handles those
+      via the base model, not the adapter file).
+
+    Pure string transformation; no tensor shape changes happen here.
+    """
+    if not _LORA_AB.search(peft_key):
+        return None
+    stripped = _PEFT_PREFIX.sub("", peft_key, count=1)
+    return _LORA_AB.sub(lambda m: f".lora_{m.group(1).lower()}", stripped)
+
+
+def map_all_keys(peft_keys: list[str]) -> dict[str, str]:
+    """Build the peft_key → mlx_key mapping for a whole adapter file.
+
+    Non-LoRA keys are silently dropped (see `map_peft_key_to_mlx`).
+    Duplicate output keys raise `MlxConversionError`: two PEFT tensors
+    collapsing to the same MLX name would indicate a real adapter-layout
+    bug, and silently overwriting one would mask it.
+    """
77
+    mapping: dict[str, str] = {}
78
+    seen: dict[str, str] = {}
79
+    for key in peft_keys:
80
+        mapped = map_peft_key_to_mlx(key)
81
+        if mapped is None:
82
+            continue
83
+        if mapped in seen:
84
+            raise MlxConversionError(
85
+                f"two PEFT keys map to the same MLX key {mapped!r}: {seen[mapped]!r} and {key!r}"
86
+            )
87
+        seen[mapped] = key
88
+        mapping[key] = mapped
89
+    if not mapping:
90
+        raise MlxConversionError(
91
+            "PEFT adapter has no LoRA A/B weight tensors — not a convertible LoRA checkpoint"
92
+        )
93
+    return mapping
94
+
95
+
96
+def peft_safetensors_to_mlx_npz(  # pragma: no cover - I/O + torch deps
97
+    peft_adapter_dir: Path,
98
+    mlx_npz_path: Path,
99
+) -> dict[str, str]:
100
+    """Convert `<adapter>/adapter_model.safetensors` → `<mlx_npz_path>`.
101
+
102
+    Returns the key mapping actually written (peft_key → mlx_key) so
103
+    callers can log it when `--verbose`.
104
+
105
+    Marked `# pragma: no cover`: the I/O path is exercised end-to-end by
+    the slow parity integration test, and the key-mapping logic is
+    covered by the `map_all_keys` unit tests.
+    """
108
+    import numpy as np
109
+    from safetensors.torch import load_file
110
+
111
+    src = peft_adapter_dir / "adapter_model.safetensors"
112
+    if not src.exists():
113
+        raise MlxConversionError(f"no adapter_model.safetensors in {peft_adapter_dir}")
114
+
115
+    tensors = load_file(str(src))
116
+    mapping = map_all_keys(list(tensors.keys()))
117
+
118
+    np_tensors: dict[str, NDArray[np.float32]] = {}
119
+    for peft_key, mlx_key in mapping.items():
120
+        tensor = tensors[peft_key]
121
+        # safetensors.torch.load_file returns torch.Tensor; .numpy()
122
+        # is the standard bridge. fp16 is preserved across the write.
123
+        np_tensors[mlx_key] = tensor.detach().cpu().numpy()
124
+
125
+    mlx_npz_path.parent.mkdir(parents=True, exist_ok=True)
126
+    np.savez(str(mlx_npz_path), **np_tensors)  # type: ignore[arg-type]
127
+    return mapping
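
For reviewers who want to try the converter locally, a minimal usage sketch (the checkpoint paths here are hypothetical; any directory containing a PEFT adapter_model.safetensors works):

from pathlib import Path

from dlm.inference.mlx_adapter import peft_safetensors_to_mlx_npz

# Convert the PEFT LoRA checkpoint into an MLX-loadable .npz next to it.
mapping = peft_safetensors_to_mlx_npz(
    Path("checkpoints/knowledge"),               # hypothetical PEFT output dir
    Path("checkpoints/knowledge/adapters.npz"),  # hypothetical target path
)

# The returned dict is peft_key -> mlx_key, useful for verbose logging.
for peft_key, mlx_key in sorted(mapping.items()):
    print(f"{peft_key} -> {mlx_key}")
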
tests/unit/inference/test_backends_base.py (modified)
@@ -40,9 +40,7 @@ class TestPyTorchBackendLoadDelegation:
             backend = PyTorchBackend(caps)
             backend.load(base, store, adapter_name="knowledge")

-            m_load.assert_called_once_with(
-                store, base, caps, adapter_name="knowledge"
-            )
+            m_load.assert_called_once_with(store, base, caps, adapter_name="knowledge")

     def test_generate_after_load_delegates(self) -> None:
         with (
tests/unit/inference/test_mlx_adapter_conversion.py (added)
@@ -0,0 +1,79 @@
+"""Key mapping for PEFT → MLX LoRA adapter conversion."""
+
+from __future__ import annotations
+
+import pytest
+
+from dlm.inference.mlx_adapter import (
+    MlxConversionError,
+    map_all_keys,
+    map_peft_key_to_mlx,
+)
+
+
+class TestMapPeftKey:
+    def test_lora_a_lowercases_and_strips_weight(self) -> None:
+        got = map_peft_key_to_mlx("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight")
+        assert got == "model.layers.0.self_attn.q_proj.lora_a"
+
+    def test_lora_b_lowercases_and_strips_weight(self) -> None:
+        got = map_peft_key_to_mlx("base_model.model.model.layers.5.mlp.down_proj.lora_B.weight")
+        assert got == "model.layers.5.mlp.down_proj.lora_b"
+
+    def test_base_model_prefix_stripped_once_only(self) -> None:
+        # Inner `model.` (HF attribute) survives; only the outer
+        # `base_model.model.` wrapper is dropped.
+        got = map_peft_key_to_mlx("base_model.model.model.embed_tokens.lora_A.weight")
+        assert got == "model.embed_tokens.lora_a"
+
+    def test_non_lora_key_returns_none(self) -> None:
+        # modules_to_save duplicates, bias tensors, etc.
+        assert (
+            map_peft_key_to_mlx(
+                "base_model.model.model.embed_tokens.modules_to_save.default.weight"
+            )
+            is None
+        )
+        assert map_peft_key_to_mlx("something.else.bias") is None
+
+    def test_bare_bias_key_returns_none(self) -> None:
+        assert map_peft_key_to_mlx("base_model.model.model.layers.0.self_attn.q_proj.bias") is None
+
+
+class TestMapAllKeys:
+    def test_pair_mapping(self) -> None:
+        keys = [
+            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
+            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
+            "base_model.model.model.layers.1.mlp.up_proj.lora_A.weight",
+            "base_model.model.model.layers.1.mlp.up_proj.lora_B.weight",
+        ]
+        mapping = map_all_keys(keys)
+        assert len(mapping) == 4
+        assert mapping[keys[0]] == "model.layers.0.self_attn.q_proj.lora_a"
+        assert mapping[keys[1]] == "model.layers.0.self_attn.q_proj.lora_b"
+
+    def test_non_lora_keys_skipped_silently(self) -> None:
+        keys = [
+            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
+            "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
+            "base_model.model.model.embed_tokens.modules_to_save.default.weight",
+        ]
+        mapping = map_all_keys(keys)
+        assert len(mapping) == 2
+
+    def test_empty_adapter_raises(self) -> None:
+        with pytest.raises(MlxConversionError, match="no LoRA A/B"):
+            map_all_keys(["just.a.bias"])
+
+    def test_duplicate_mlx_key_raises(self) -> None:
+        # Two keys that both resolve to `q_proj.lora_a` after the
+        # outer `base_model.model.` strip — one already unwrapped,
+        # one wrapped. Defensive branch that matters if PEFT ever
+        # emits the same tensor under both wrapped + unwrapped names.
+        collision = [
+            "q_proj.lora_A.weight",
+            "base_model.model.q_proj.lora_A.weight",
+        ]
+        with pytest.raises(MlxConversionError, match="map to the same"):
+            map_all_keys(collision)
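
As a quick sanity check on the written artifact without importing MLX or torch, the .npz can be inspected with plain numpy (a sketch; the path is the hypothetical one from the conversion example above):

import numpy as np

npz = np.load("checkpoints/knowledge/adapters.npz")

# Every key should be a flattened MLX-LM LoRA name, e.g.
# model.layers.0.self_attn.q_proj.lora_a
assert all(name.endswith((".lora_a", ".lora_b")) for name in npz.files)

# numpy.savez preserves dtypes, so whatever precision the PEFT
# checkpoint used (e.g. float16) should show up unchanged here.
for name in npz.files:
    print(name, npz[name].shape, npz[name].dtype)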