1 """Unit tests for the PEFT→MLX-LM LoRA converter (Sprint 24, F01).
2
3 Builds synthetic PEFT-shaped inputs (no torch / peft install required —
4 just safetensors + json + numpy) and asserts the converter produces
5 mlx-lm-shaped outputs with the right keys, shape transposes, and
6 config fields.
7
8 End-to-end verification against real PEFT and a real
9 ``mlx_lm.load(..., adapter_path=...)`` call lives in
10 ``tests/integration/test_mlx_converter_e2e.py`` (slow + online +
11 darwin-arm64).
12 """

from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pytest

# safetensors ships in the [hf] extra (not [dev]). The fast lane runs
# without [hf]; skip the whole module when missing rather than fail
# collection. Local + slow-lane runs install [hf] and exercise these.
safetensors_numpy = pytest.importorskip(
    "safetensors.numpy",
    reason="safetensors not installed (install via the [hf] extra)",
)
load_file = safetensors_numpy.load_file
save_file = safetensors_numpy.save_file

from dlm_sway.backends._mlx_convert import (  # noqa: E402 — import-after-skip
    MlxConvertError,
    _extract_layer_index,
    _strip_layer_prefix,
    convert_peft_to_mlx,
)


def _write_synthetic_peft_adapter(
    dst: Path,
    *,
    rank: int = 8,
    alpha: int = 16,
    dropout: float = 0.0,
    num_layers: int = 3,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    in_features: int = 64,
    out_features: int = 64,
    modules_to_save: list[str] | None = None,
) -> None:
    """Produce a minimal but format-correct PEFT adapter directory."""
    dst.mkdir(parents=True, exist_ok=True)
    weights: dict[str, np.ndarray] = {}
    for layer_idx in range(num_layers):
        for module in target_modules:
            base = f"base_model.model.model.layers.{layer_idx}.self_attn.{module}"
            # PEFT shapes: lora_A=(r, in), lora_B=(out, r)
            weights[f"{base}.lora_A.weight"] = (
                np.random.RandomState(layer_idx).randn(rank, in_features).astype(np.float32)
            )
            weights[f"{base}.lora_B.weight"] = (
                np.random.RandomState(layer_idx + 1000).randn(out_features, rank).astype(np.float32)
            )
    save_file(weights, str(dst / "adapter_model.safetensors"))
    config = {
        "peft_type": "LORA",
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "target_modules": list(target_modules),
        "modules_to_save": modules_to_save or [],
        "task_type": "CAUSAL_LM",
        "bias": "none",
    }
    (dst / "adapter_config.json").write_text(json.dumps(config), encoding="utf-8")
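
# Typical use in the tests below (sketch): write a synthetic adapter, convert,
# then inspect the report dict and the files written under ``dst``.
#
#     _write_synthetic_peft_adapter(tmp_path / "peft")
#     report = convert_peft_to_mlx(tmp_path / "peft", tmp_path / "mlx")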


class TestStripLayerPrefix:
    """The helper that turns full attribute paths into layer-relative
    paths for MLX's ``adapter_config.json::keys`` field."""

    def test_typical_decoder_layer_path(self) -> None:
        assert _strip_layer_prefix("model.layers.5.self_attn.q_proj") == "self_attn.q_proj"

    def test_gpt2_style_path(self) -> None:
        assert _strip_layer_prefix("transformer.h.0.attn.c_attn") == "attn.c_attn"

    def test_no_layer_index_returns_input(self) -> None:
        """Embedding-style paths (no numeric segment) pass through."""
        assert _strip_layer_prefix("model.embed_tokens") == "model.embed_tokens"

    def test_extract_layer_index(self) -> None:
        assert _extract_layer_index("model.layers.0.self_attn.q_proj") == 0
        assert _extract_layer_index("model.layers.42.self_attn.q_proj") == 42
        assert _extract_layer_index("model.embed_tokens") is None


class TestConvertPeftToMlxBasic:
    """Happy path: standard PEFT LoRA → MLX adapter."""

    def test_produces_expected_output_files(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)

        report = convert_peft_to_mlx(src, dst)

        assert (dst / "adapters.safetensors").exists()
        assert (dst / "adapter_config.json").exists()
        assert report["rank"] == 8
        assert report["scale"] == pytest.approx(2.0)  # 16 / 8
        assert report["num_keys"] == 12  # 3 layers × 2 modules × 2 (lora_a + lora_b)
        assert report["num_layers"] == 3

    def test_mlx_config_shape(self, tmp_path: Path) -> None:
        """The written ``adapter_config.json`` matches mlx-lm's
        ``load_adapters`` expectations."""
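        # Shape asserted below (what mlx-lm's ``load_adapters`` reads, as far as
        # this test pins it down):
        #   {"fine_tune_type": "lora", "num_layers": 4,
        #    "lora_parameters": {"rank": 16, "scale": 2.0,  # lora_alpha / r
        #                        "dropout": 0.1, "keys": [...]}}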
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, rank=16, alpha=32, dropout=0.1, num_layers=4)

        convert_peft_to_mlx(src, dst)
        cfg = json.loads((dst / "adapter_config.json").read_text(encoding="utf-8"))

        assert cfg["fine_tune_type"] == "lora"
        assert cfg["num_layers"] == 4
        params = cfg["lora_parameters"]
        assert params["rank"] == 16
        assert params["scale"] == pytest.approx(2.0)  # 32 / 16
        assert params["dropout"] == pytest.approx(0.1)
        assert params["keys"] == ["self_attn.q_proj", "self_attn.v_proj"]

    def test_lora_factor_shapes_transposed(self, tmp_path: Path) -> None:
        """PEFT lora_A=(r, in) → MLX lora_a=(in, r); same for lora_B/lora_b."""
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, rank=8, in_features=64, out_features=128, num_layers=1)
        # Sanity-check the synthetic input first.
        peft_w = load_file(str(src / "adapter_model.safetensors"))
        a_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
        b_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"
        assert peft_w[a_key].shape == (8, 64)
        assert peft_w[b_key].shape == (128, 8)

        convert_peft_to_mlx(src, dst)
        mlx_w = load_file(str(dst / "adapters.safetensors"))
        assert "model.layers.0.self_attn.q_proj.lora_a" in mlx_w
        assert "model.layers.0.self_attn.q_proj.lora_b" in mlx_w
        assert mlx_w["model.layers.0.self_attn.q_proj.lora_a"].shape == (64, 8)
        assert mlx_w["model.layers.0.self_attn.q_proj.lora_b"].shape == (8, 128)

    def test_values_preserved_through_transpose(self, tmp_path: Path) -> None:
        """Round-trip the underlying numbers — transpose must be the
        only operation, not a reshape with data loss."""
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, num_layers=1)

        peft_w = load_file(str(src / "adapter_model.safetensors"))
        convert_peft_to_mlx(src, dst)
        mlx_w = load_file(str(dst / "adapters.safetensors"))

        a_in = peft_w["base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"]
        a_out = mlx_w["model.layers.0.self_attn.q_proj.lora_a"]
        np.testing.assert_array_equal(a_in.T, a_out)


class TestConvertPeftToMlxErrors:
    """Structural errors must surface as ``MlxConvertError`` with
    actionable messages, not as cryptic IO / KeyError tracebacks."""

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "empty"
        src.mkdir()
        (src / "adapter_config.json").write_text('{"peft_type": "LORA", "r": 8}')
        with pytest.raises(MlxConvertError, match="missing adapter_model.safetensors"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_missing_config_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "empty"
        src.mkdir()
        save_file({}, str(src / "adapter_model.safetensors"))
        with pytest.raises(MlxConvertError, match="missing adapter_config.json"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_non_lora_peft_type_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "ia3"
        _write_synthetic_peft_adapter(src)
        cfg = json.loads((src / "adapter_config.json").read_text())
        cfg["peft_type"] = "IA3"
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="unsupported PEFT type"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_invalid_rank_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "bad"
        _write_synthetic_peft_adapter(src)
        cfg = json.loads((src / "adapter_config.json").read_text())
        cfg["r"] = 0
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="invalid LoRA rank"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_dst_not_empty_refuses_without_overwrite(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)
        # Pre-create the output to simulate a stale conversion.
        dst.mkdir()
        (dst / "adapters.safetensors").write_bytes(b"old")
        with pytest.raises(MlxConvertError, match="overwrite=True"):
            convert_peft_to_mlx(src, dst)

    def test_dst_overwrite_replaces_existing(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)
        dst.mkdir()
        (dst / "adapters.safetensors").write_bytes(b"old")
        convert_peft_to_mlx(src, dst, overwrite=True)
        # The placeholder bytes must have been replaced with a real safetensors payload.
        assert (dst / "adapters.safetensors").read_bytes() != b"old"

    def test_unexpected_key_prefix_raises(self, tmp_path: Path) -> None:
        """A safetensors file whose keys don't have the
        ``base_model.model.`` PEFT-wrapper prefix shouldn't be silently
        emitted with a wrong MLX path."""
        src = tmp_path / "weird"
        src.mkdir()
        save_file(
            {"some.other.lora_A.weight": np.zeros((8, 64), dtype=np.float32)},
            str(src / "adapter_model.safetensors"),
        )
        cfg = {
            "peft_type": "LORA",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["x"],
        }
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="missing 'base_model.model.' prefix"):
            convert_peft_to_mlx(src, tmp_path / "mlx")


class TestEnsureMlxAdapterAutoConvert:
    """``MLXDifferentialBackend.__init__`` calls ``_ensure_mlx_adapter``
    to upgrade PEFT-shaped adapter dirs to MLX format on the fly. The
    function lives in ``backends/mlx.py`` so it doesn't pull in mlx-lm
    when the path is already MLX-shaped."""
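    # Dispatch behavior pinned down by the tests below (a sketch of the contract,
    # not the implementation):
    #   adapters.safetensors present        -> already MLX: return the dir as-is
    #   adapter_model.safetensors present   -> PEFT: convert into a content-keyed
    #                                          dir under XDG_CACHE_HOME and return it
    #   neither                             -> return as-is; mlx_lm.load reports the error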

    def test_passes_through_when_dir_is_already_mlx_shape(self, tmp_path: Path) -> None:
        """Existing ``adapters.safetensors`` → no conversion, return
        the same path unchanged. (Manual conversions / pre-built MLX
        adapters from other tools must not be re-converted.)"""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        mlx_dir = tmp_path / "mlx"
        mlx_dir.mkdir()
        save_file({}, str(mlx_dir / "adapters.safetensors"))
        (mlx_dir / "adapter_config.json").write_text('{"fine_tune_type":"lora"}')
        out = _ensure_mlx_adapter(mlx_dir)
        assert out == mlx_dir

    def test_auto_converts_peft_dir_into_cache(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A PEFT-shaped dir gets converted into XDG_CACHE_HOME on
        first call; the returned path is the cache dir, not the source."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        cache_root = tmp_path / "cache"
        monkeypatch.setenv("XDG_CACHE_HOME", str(cache_root))

        peft_dir = tmp_path / "peft"
        _write_synthetic_peft_adapter(peft_dir)
        out = _ensure_mlx_adapter(peft_dir)

        assert out != peft_dir
        assert (out / "adapters.safetensors").exists()
        assert (out / "adapter_config.json").exists()
        assert str(out).startswith(str(cache_root))

    def test_repeated_calls_short_circuit_on_cache_hit(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
287 """Same PEFT bytes → same cache hash → second call returns the
288 cached dir without re-converting (touch mtime to detect)."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "cache"))
        peft_dir = tmp_path / "peft"
        _write_synthetic_peft_adapter(peft_dir)

        first = _ensure_mlx_adapter(peft_dir)
        first_mtime = (first / "adapters.safetensors").stat().st_mtime_ns

        # Second call — should NOT rewrite the file.
        second = _ensure_mlx_adapter(peft_dir)
        assert second == first
        assert (second / "adapters.safetensors").stat().st_mtime_ns == first_mtime

    def test_passes_through_unrecognized_dir(self, tmp_path: Path) -> None:
        """A directory with neither shape — let mlx_lm.load surface
        its own error rather than this helper second-guessing."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        empty = tmp_path / "empty"
        empty.mkdir()
        out = _ensure_mlx_adapter(empty)
        assert out == empty


class TestModulesToSave:
    """``modules_to_save`` (e.g. embed_tokens, lm_head) must be skipped
    cleanly with a report entry, not crash the converter."""

    def test_modules_to_save_skipped_and_reported(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, num_layers=1, modules_to_save=["embed_tokens"])
        # Inject a non-LoRA full-weight tensor that simulates
        # PEFT's modules_to_save serialization.
        existing = load_file(str(src / "adapter_model.safetensors"))
        existing["base_model.model.model.embed_tokens.modules_to_save.default.weight"] = np.zeros(
            (100, 64), dtype=np.float32
        )
        save_file(existing, str(src / "adapter_model.safetensors"))

        report = convert_peft_to_mlx(src, dst)
        assert len(report["modules_to_save_skipped"]) == 1
        assert "embed_tokens" in report["modules_to_save_skipped"][0]
        # Real LoRA factors still extracted.
        assert report["num_keys"] == 4  # 1 layer × 2 modules × 2 factors