| 1 | """Unit tests for the PEFT→MLX-LM LoRA converter (Sprint 24, F01). |
| 2 | |
| 3 | Builds synthetic PEFT-shaped inputs (no torch / peft install required — |
| 4 | just safetensors + json + numpy) and asserts the converter produces |
| 5 | mlx-lm-shaped outputs with the right keys, shape transposes, and |
| 6 | config fields. |
| 7 | |
| 8 | End-to-end verification against real PEFT and a real |
| 9 | ``mlx_lm.load(..., adapter_path=...)`` call lives in |
| 10 | ``tests/integration/test_mlx_converter_e2e.py`` (slow + online + |
| 11 | darwin-arm64). |
| 12 | """ |
| 13 | |
| 14 | from __future__ import annotations |
| 15 | |
| 16 | import json |
| 17 | from pathlib import Path |
| 18 | |
| 19 | import numpy as np |
| 20 | import pytest |
| 21 | |
| 22 | # safetensors ships in the [hf] extra (not [dev]). The fast lane runs |
| 23 | # without [hf]; skip the whole module when missing rather than fail |
| 24 | # collection. Local + slow-lane runs install [hf] and exercise these. |
| 25 | safetensors_numpy = pytest.importorskip( |
| 26 | "safetensors.numpy", |
| 27 | reason="safetensors not installed (install via the [hf] extra)", |
| 28 | ) |
| 29 | load_file = safetensors_numpy.load_file |
| 30 | save_file = safetensors_numpy.save_file |
| 31 | |
| 32 | from dlm_sway.backends._mlx_convert import ( # noqa: E402 — import-after-skip |
| 33 | MlxConvertError, |
| 34 | _extract_layer_index, |
| 35 | _strip_layer_prefix, |
| 36 | convert_peft_to_mlx, |
| 37 | ) |
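
# A rough sketch of the conversion contract the tests below pin down, derived
# from the assertions rather than from the converter internals (which may
# differ in detail):
#
#   PEFT key:  base_model.model.model.layers.<i>.self_attn.q_proj.lora_A.weight   shape (r, in)
#   MLX key:   model.layers.<i>.self_attn.q_proj.lora_a                           shape (in, r)
#
# i.e. the "base_model.model." wrapper prefix is dropped, lora_A/lora_B become
# lora_a/lora_b, each factor is transposed, and the MLX config records
# scale = lora_alpha / r (e.g. 16 / 8 == 2.0).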


def _write_synthetic_peft_adapter(
    dst: Path,
    *,
    rank: int = 8,
    alpha: int = 16,
    dropout: float = 0.0,
    num_layers: int = 3,
    target_modules: tuple[str, ...] = ("q_proj", "v_proj"),
    in_features: int = 64,
    out_features: int = 64,
    modules_to_save: list[str] | None = None,
) -> None:
    """Produce a minimal but format-correct PEFT adapter directory."""
    dst.mkdir(parents=True, exist_ok=True)
    weights: dict[str, np.ndarray] = {}
    for layer_idx in range(num_layers):
        for module in target_modules:
            base = f"base_model.model.model.layers.{layer_idx}.self_attn.{module}"
            # PEFT shapes: lora_A=(r, in), lora_B=(out, r)
            weights[f"{base}.lora_A.weight"] = (
                np.random.RandomState(layer_idx).randn(rank, in_features).astype(np.float32)
            )
            weights[f"{base}.lora_B.weight"] = (
                np.random.RandomState(layer_idx + 1000).randn(out_features, rank).astype(np.float32)
            )
    save_file(weights, str(dst / "adapter_model.safetensors"))
    config = {
        "peft_type": "LORA",
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "target_modules": list(target_modules),
        "modules_to_save": modules_to_save or [],
        "task_type": "CAUSAL_LM",
        "bias": "none",
    }
    (dst / "adapter_config.json").write_text(json.dumps(config), encoding="utf-8")

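# With the defaults above, the first key the helper writes is e.g.
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight  -> float32, shape (8, 64)
# which is exactly the PEFT-shaped input the converter tests below load,
# rename, and transpose.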

class TestStripLayerPrefix:
    """The helper that turns full attribute paths into layer-relative
    paths for MLX's ``adapter_config.json::keys`` field."""
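
    # A minimal sketch of the behaviour pinned down below; the real helper in
    # dlm_sway.backends._mlx_convert may well be implemented differently:
    #
    #   def _strip_layer_prefix(path: str) -> str:
    #       parts = path.split(".")
    #       for i, part in enumerate(parts):
    #           if part.isdigit():          # first numeric segment = the layer index
    #               return ".".join(parts[i + 1:])
    #       return path                     # no layer index -> pass through unchanged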

    def test_typical_decoder_layer_path(self) -> None:
        assert _strip_layer_prefix("model.layers.5.self_attn.q_proj") == "self_attn.q_proj"

    def test_gpt2_style_path(self) -> None:
        assert _strip_layer_prefix("transformer.h.0.attn.c_attn") == "attn.c_attn"

    def test_no_layer_index_returns_input(self) -> None:
        """Embedding-style paths (no numeric segment) pass through."""
        assert _strip_layer_prefix("model.embed_tokens") == "model.embed_tokens"

    def test_extract_layer_index(self) -> None:
        assert _extract_layer_index("model.layers.0.self_attn.q_proj") == 0
        assert _extract_layer_index("model.layers.42.self_attn.q_proj") == 42
        assert _extract_layer_index("model.embed_tokens") is None


class TestConvertPeftToMlxBasic:
    """Happy path: standard PEFT LoRA → MLX adapter."""

    def test_produces_expected_output_files(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)

        report = convert_peft_to_mlx(src, dst)

        assert (dst / "adapters.safetensors").exists()
        assert (dst / "adapter_config.json").exists()
        assert report["rank"] == 8
        assert report["scale"] == pytest.approx(2.0)  # 16 / 8
        assert report["num_keys"] == 12  # 3 layers × 2 modules × 2 (lora_a + lora_b)
        assert report["num_layers"] == 3

    def test_mlx_config_shape(self, tmp_path: Path) -> None:
        """The written ``adapter_config.json`` matches mlx-lm's
        ``load_adapters`` expectations."""
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, rank=16, alpha=32, dropout=0.1, num_layers=4)

        convert_peft_to_mlx(src, dst)
        cfg = json.loads((dst / "adapter_config.json").read_text(encoding="utf-8"))

        assert cfg["fine_tune_type"] == "lora"
        assert cfg["num_layers"] == 4
        params = cfg["lora_parameters"]
        assert params["rank"] == 16
        assert params["scale"] == pytest.approx(2.0)  # 32 / 16
        assert params["dropout"] == pytest.approx(0.1)
        assert params["keys"] == ["self_attn.q_proj", "self_attn.v_proj"]

    def test_lora_factor_shapes_transposed(self, tmp_path: Path) -> None:
        """PEFT lora_A=(r, in) → MLX lora_a=(in, r); same for lora_B/lora_b."""
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, rank=8, in_features=64, out_features=128, num_layers=1)
        # Sanity-check the synthetic input first.
        peft_w = load_file(str(src / "adapter_model.safetensors"))
        a_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
        b_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"
        assert peft_w[a_key].shape == (8, 64)
        assert peft_w[b_key].shape == (128, 8)

        convert_peft_to_mlx(src, dst)
        mlx_w = load_file(str(dst / "adapters.safetensors"))
        assert "model.layers.0.self_attn.q_proj.lora_a" in mlx_w
        assert "model.layers.0.self_attn.q_proj.lora_b" in mlx_w
        assert mlx_w["model.layers.0.self_attn.q_proj.lora_a"].shape == (64, 8)
        assert mlx_w["model.layers.0.self_attn.q_proj.lora_b"].shape == (8, 128)

    def test_values_preserved_through_transpose(self, tmp_path: Path) -> None:
        """Round-trip the underlying numbers — transpose must be the
        only operation, not a reshape with data loss."""
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, num_layers=1)

        peft_w = load_file(str(src / "adapter_model.safetensors"))
        convert_peft_to_mlx(src, dst)
        mlx_w = load_file(str(dst / "adapters.safetensors"))

        a_in = peft_w["base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"]
        a_out = mlx_w["model.layers.0.self_attn.q_proj.lora_a"]
        np.testing.assert_array_equal(a_in.T, a_out)


class TestConvertPeftToMlxErrors:
    """Structural errors must surface as ``MlxConvertError`` with
    actionable messages, not as cryptic IO / KeyError tracebacks."""

    def test_missing_safetensors_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "empty"
        src.mkdir()
        (src / "adapter_config.json").write_text('{"peft_type": "LORA", "r": 8}')
        with pytest.raises(MlxConvertError, match="missing adapter_model.safetensors"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_missing_config_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "empty"
        src.mkdir()
        save_file({}, str(src / "adapter_model.safetensors"))
        with pytest.raises(MlxConvertError, match="missing adapter_config.json"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_non_lora_peft_type_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "ia3"
        _write_synthetic_peft_adapter(src)
        cfg = json.loads((src / "adapter_config.json").read_text())
        cfg["peft_type"] = "IA3"
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="unsupported PEFT type"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_invalid_rank_raises(self, tmp_path: Path) -> None:
        src = tmp_path / "bad"
        _write_synthetic_peft_adapter(src)
        cfg = json.loads((src / "adapter_config.json").read_text())
        cfg["r"] = 0
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="invalid LoRA rank"):
            convert_peft_to_mlx(src, tmp_path / "mlx")

    def test_dst_not_empty_refuses_without_overwrite(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)
        # Pre-create the output to simulate a stale conversion.
        dst.mkdir()
        (dst / "adapters.safetensors").write_bytes(b"old")
        with pytest.raises(MlxConvertError, match="overwrite=True"):
            convert_peft_to_mlx(src, dst)

    def test_dst_overwrite_replaces_existing(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src)
        dst.mkdir()
        (dst / "adapters.safetensors").write_bytes(b"old")
        convert_peft_to_mlx(src, dst, overwrite=True)
        # Should not be the placeholder bytes any more.
        assert (dst / "adapters.safetensors").read_bytes() != b"old"

    def test_unexpected_key_prefix_raises(self, tmp_path: Path) -> None:
        """A safetensors file whose keys don't have the
        ``base_model.model.`` PEFT-wrapper prefix shouldn't be silently
        emitted with a wrong MLX path."""
        src = tmp_path / "weird"
        src.mkdir()
        save_file(
            {"some.other.lora_A.weight": np.zeros((8, 64), dtype=np.float32)},
            str(src / "adapter_model.safetensors"),
        )
        cfg = {
            "peft_type": "LORA",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["x"],
        }
        (src / "adapter_config.json").write_text(json.dumps(cfg))
        with pytest.raises(MlxConvertError, match="missing 'base_model.model.' prefix"):
            convert_peft_to_mlx(src, tmp_path / "mlx")


class TestEnsureMlxAdapterAutoConvert:
    """``MLXDifferentialBackend.__init__`` calls ``_ensure_mlx_adapter``
    to upgrade PEFT-shaped adapter dirs to MLX format on the fly. The
    function lives in ``backends/mlx.py`` so it doesn't pull in mlx-lm
    when the path is already MLX-shaped."""
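
    # In short, the behaviour pinned below: a dir that already holds
    # adapters.safetensors is returned untouched; a PEFT-shaped dir is converted
    # once into a content-keyed directory under XDG_CACHE_HOME and that cache
    # path is returned; anything else is passed through for mlx_lm.load to
    # reject on its own terms.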

    def test_passes_through_when_dir_is_already_mlx_shape(self, tmp_path: Path) -> None:
        """Existing ``adapters.safetensors`` → no conversion, return
        the same path unchanged. (Manual conversions / pre-built MLX
        adapters from other tools must not be re-converted.)"""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        mlx_dir = tmp_path / "mlx"
        mlx_dir.mkdir()
        save_file({}, str(mlx_dir / "adapters.safetensors"))
        (mlx_dir / "adapter_config.json").write_text('{"fine_tune_type":"lora"}')
        out = _ensure_mlx_adapter(mlx_dir)
        assert out == mlx_dir

    def test_auto_converts_peft_dir_into_cache(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A PEFT-shaped dir gets converted into XDG_CACHE_HOME on
        first call; the returned path is the cache dir, not the source."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        cache_root = tmp_path / "cache"
        monkeypatch.setenv("XDG_CACHE_HOME", str(cache_root))

        peft_dir = tmp_path / "peft"
        _write_synthetic_peft_adapter(peft_dir)
        out = _ensure_mlx_adapter(peft_dir)

        assert out != peft_dir
        assert (out / "adapters.safetensors").exists()
        assert (out / "adapter_config.json").exists()
        assert str(out).startswith(str(cache_root))

    def test_repeated_calls_short_circuit_on_cache_hit(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
| 287 | """Same PEFT bytes → same cache hash → second call returns the |
| 288 | cached dir without re-converting (touch mtime to detect).""" |
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "cache"))
        peft_dir = tmp_path / "peft"
        _write_synthetic_peft_adapter(peft_dir)

        first = _ensure_mlx_adapter(peft_dir)
        first_mtime = (first / "adapters.safetensors").stat().st_mtime_ns

        # Second call — should NOT rewrite the file.
        second = _ensure_mlx_adapter(peft_dir)
        assert second == first
        assert (second / "adapters.safetensors").stat().st_mtime_ns == first_mtime

    def test_passes_through_unrecognized_dir(self, tmp_path: Path) -> None:
        """A directory with neither shape — let mlx_lm.load surface
        its own error rather than this helper second-guessing."""
        from dlm_sway.backends.mlx import _ensure_mlx_adapter

        empty = tmp_path / "empty"
        empty.mkdir()
        out = _ensure_mlx_adapter(empty)
        assert out == empty


class TestModulesToSave:
    """``modules_to_save`` (e.g. embed_tokens, lm_head) must be skipped
    cleanly with a report entry, not crash the converter."""

    def test_modules_to_save_skipped_and_reported(self, tmp_path: Path) -> None:
        src = tmp_path / "peft"
        dst = tmp_path / "mlx"
        _write_synthetic_peft_adapter(src, num_layers=1, modules_to_save=["embed_tokens"])
        # Inject a non-LoRA full-weight tensor that simulates
        # PEFT's modules_to_save serialization.
        existing = load_file(str(src / "adapter_model.safetensors"))
        existing["base_model.model.model.embed_tokens.modules_to_save.default.weight"] = np.zeros(
            (100, 64), dtype=np.float32
        )
        save_file(existing, str(src / "adapter_model.safetensors"))

        report = convert_peft_to_mlx(src, dst)
        assert len(report["modules_to_save_skipped"]) == 1
        assert "embed_tokens" in report["modules_to_save_skipped"][0]
        # Real LoRA factors still extracted.
        assert report["num_keys"] == 4  # 1 layer × 2 modules × 2 factors