@@ -1,30 +1,44 @@
| 1 | | -"""`stage_mlx_adapter_dir` — audit-08 B3 fix. |
| 1 | +"""`stage_mlx_adapter_dir` — covers the pure staging path. |
| 2 | 2 | |
| 3 | | -Covers the file-layout piece of the MLX backend load path without |
| 4 | | -importing mlx / mlx_lm. The heavy `MlxBackend.load` path stays |
| 5 | | -`# pragma: no cover` — the correctness claim is that given a |
| 6 | | -well-shaped PEFT adapter dir, staging produces the pair mlx_lm |
| 7 | | -requires (`adapters.npz` + `adapter_config.json`). |
| 3 | +After the audit-11 B1 fix, staging no longer copies PEFT's config |
| 4 | +verbatim — it translates it into mlx-lm's schema (`num_layers` + |
| 5 | +`lora_parameters`) and writes `adapters.safetensors` rather than |
| 6 | +`.npz`. The heavy `MlxBackend.load` path stays `# pragma: no cover`; |
| 7 | +the correctness claim here is that given a well-shaped PEFT adapter |
| 8 | +dir + a base HF id whose config is cached locally, staging produces |
| 9 | +the pair mlx_lm requires. |
| 8 | 10 | """ |
| 9 | 11 | |
| 10 | 12 | from __future__ import annotations |
| 11 | 13 | |
| 12 | 14 | import json |
| 13 | 15 | from pathlib import Path |
| 16 | +from typing import Any |
| 14 | 17 | |
| 15 | 18 | import pytest |
| 16 | 19 | import torch |
| 17 | | -from safetensors.torch import save_file |
| 20 | +from safetensors.torch import load_file, save_file |
| 18 | 21 | |
| 22 | +from dlm.inference.backends import mlx_backend |
| 19 | 23 | from dlm.inference.backends.mlx_backend import stage_mlx_adapter_dir |
| 20 | 24 | from dlm.inference.mlx_adapter import MlxConversionError |
| 21 | 25 | |
| 22 | 26 | |
| 23 | | -def _write_peft_adapter(dst: Path) -> None: |
| 27 | +def _write_peft_adapter(dst: Path, *, target_modules: list[str] | None = None) -> None: |
| 24 | 28 | """Write a minimal PEFT-shaped adapter dir.""" |
| 25 | 29 | dst.mkdir(parents=True, exist_ok=True) |
| 30 | + resolved_targets = ["q_proj", "v_proj"] if target_modules is None else target_modules |
| 26 | 31 | (dst / "adapter_config.json").write_text( |
| 27 | | - json.dumps({"peft_type": "LORA", "r": 8, "lora_alpha": 16}) |
| 32 | + json.dumps( |
| 33 | + { |
| 34 | + "peft_type": "LORA", |
| 35 | + "r": 8, |
| 36 | + "lora_alpha": 16, |
| 37 | + "lora_dropout": 0.05, |
| 38 | + "target_modules": resolved_targets, |
| 39 | + "use_dora": False, |
| 40 | + } |
| 41 | + ) |
| 28 | 42 | ) |
| 29 | 43 | tensors = { |
| 30 | 44 | "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 16), |
@@ -33,46 +47,114 @@ def _write_peft_adapter(dst: Path) -> None: |
| 33 | 47 | save_file(tensors, str(dst / "adapter_model.safetensors")) |
| 34 | 48 | |
| 35 | 49 | |
| 50 | +@pytest.fixture |
| 51 | +def stub_num_layers(monkeypatch: pytest.MonkeyPatch) -> None: |
| 52 | + """Bypass the HF cache lookup; pretend every base has 30 layers.""" |
| 53 | + |
| 54 | + def _stub(_hf_id: str) -> int: |
| 55 | + return 30 |
| 56 | + |
| 57 | + monkeypatch.setattr(mlx_backend, "_resolve_base_num_hidden_layers", _stub) |
| 58 | + |
| 59 | + |
| 36 | 60 | class TestStageSuccess: |
| 37 | | - def test_writes_npz_and_config(self, tmp_path: Path) -> None: |
| 61 | + def test_writes_safetensors_and_translated_config( |
| 62 | + self, tmp_path: Path, stub_num_layers: None |
| 63 | + ) -> None: |
| 38 | 64 | src = tmp_path / "peft" |
| 39 | 65 | dst = tmp_path / "mlx" |
| 40 | 66 | _write_peft_adapter(src) |
| 41 | | - staged = stage_mlx_adapter_dir(src, dst) |
| 67 | + staged = stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake") |
| 42 | 68 | |
| 43 | 69 | assert staged == dst |
| 44 | | - assert (dst / "adapters.npz").exists() |
| 70 | + assert (dst / "adapters.safetensors").exists() |
| 45 | 71 | assert (dst / "adapter_config.json").exists() |
| 46 | | - # Config is copied verbatim. |
| 47 | | - original = (src / "adapter_config.json").read_text() |
| 48 | | - assert (dst / "adapter_config.json").read_text() == original |
| 72 | + |
| 73 | + mlx_cfg: dict[str, Any] = json.loads( |
| 74 | + (dst / "adapter_config.json").read_text(encoding="utf-8") |
| 75 | + ) |
| 76 | + assert mlx_cfg["fine_tune_type"] == "lora" |
| 77 | + assert mlx_cfg["num_layers"] == 30 |
| 78 | + lp = mlx_cfg["lora_parameters"] |
| 79 | + assert lp["rank"] == 8 |
| 80 | + assert lp["scale"] == pytest.approx(16 / 8) |
| 81 | + assert lp["dropout"] == pytest.approx(0.05) |
| 82 | + assert lp["keys"] == ["q_proj", "v_proj"] |
| 83 | + |
| 84 | + def test_tensor_keys_match_mlx_layout( |
| 85 | + self, tmp_path: Path, stub_num_layers: None |
| 86 | + ) -> None: |
| 87 | + src = tmp_path / "peft" |
| 88 | + dst = tmp_path / "mlx" |
| 89 | + _write_peft_adapter(src) |
| 90 | + stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake") |
| 91 | + |
| 92 | + # Loaded keys should follow mlx-lm's flattened shape: no |
| 93 | + # `base_model.` prefix, lowercase `lora_a`/`lora_b`, no |
| 94 | + # trailing `.weight`. |
| 95 | + loaded = load_file(str(dst / "adapters.safetensors")) |
| 96 | + for key in loaded: |
| 97 | + assert not key.startswith("base_model.") |
| 98 | + assert ".lora_a" in key or ".lora_b" in key |
| 99 | + assert not key.endswith(".weight") |
| 100 | + |
| 101 | + def test_dora_emits_dora_fine_tune_type( |
| 102 | + self, tmp_path: Path, stub_num_layers: None |
| 103 | + ) -> None: |
| 104 | + src = tmp_path / "peft" |
| 105 | + dst = tmp_path / "mlx" |
| 106 | + _write_peft_adapter(src) |
| 107 | + cfg = json.loads((src / "adapter_config.json").read_text()) |
| 108 | + cfg["use_dora"] = True |
| 109 | + (src / "adapter_config.json").write_text(json.dumps(cfg)) |
| 110 | + stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake") |
| 111 | + |
| 112 | + mlx_cfg = json.loads((dst / "adapter_config.json").read_text(encoding="utf-8")) |
| 113 | + assert mlx_cfg["fine_tune_type"] == "dora" |
| 49 | 114 | |
| 50 | 115 | |
| 51 | 116 | class TestStagePreflight: |
| 52 | 117 | def test_missing_adapter_config_refused(self, tmp_path: Path) -> None: |
| 53 | | - """N10: reject a dir that doesn't look like a PEFT adapter.""" |
| 54 | 118 | src = tmp_path / "broken" |
| 55 | 119 | src.mkdir() |
| 56 | | - # tensor only, no config |
| 57 | 120 | save_file( |
| 58 | 121 | {"base_model.model.a.lora_A.weight": torch.zeros(2, 2)}, |
| 59 | 122 | str(src / "adapter_model.safetensors"), |
| 60 | 123 | ) |
| 61 | 124 | with pytest.raises(MlxConversionError, match="not a PEFT adapter dir"): |
| 62 | | - stage_mlx_adapter_dir(src, tmp_path / "out") |
| 125 | + stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake") |
| 63 | 126 | |
| 64 | 127 | def test_missing_safetensors_refused(self, tmp_path: Path) -> None: |
| 65 | | - """Merged-model dir (has config, no adapter_model.safetensors).""" |
| 66 | 128 | src = tmp_path / "merged" |
| 67 | 129 | src.mkdir() |
| 68 | 130 | (src / "adapter_config.json").write_text("{}") |
| 69 | 131 | with pytest.raises(MlxConversionError, match="no adapter_model.safetensors"): |
| 70 | | - stage_mlx_adapter_dir(src, tmp_path / "out") |
| 132 | + stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake") |
| 71 | 133 | |
| 72 | | - def test_dst_dir_created_if_missing(self, tmp_path: Path) -> None: |
| 134 | + def test_dst_dir_created_if_missing( |
| 135 | + self, tmp_path: Path, stub_num_layers: None |
| 136 | + ) -> None: |
| 73 | 137 | src = tmp_path / "peft" |
| 74 | 138 | _write_peft_adapter(src) |
| 75 | | - # nested non-existent dst |
| 76 | 139 | dst = tmp_path / "nested" / "dst" |
| 77 | | - stage_mlx_adapter_dir(src, dst) |
| 140 | + stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake") |
| 78 | 141 | assert dst.is_dir() |
| 142 | + |
| 143 | + def test_peft_config_missing_r_refused( |
| 144 | + self, tmp_path: Path, stub_num_layers: None |
| 145 | + ) -> None: |
| 146 | + src = tmp_path / "peft" |
| 147 | + _write_peft_adapter(src) |
| 148 | + cfg = json.loads((src / "adapter_config.json").read_text()) |
| 149 | + cfg.pop("r") |
| 150 | + (src / "adapter_config.json").write_text(json.dumps(cfg)) |
| 151 | + with pytest.raises(MlxConversionError, match="'r'"): |
| 152 | + stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake") |
| 153 | + |
| 154 | + def test_peft_config_empty_target_modules_refused( |
| 155 | + self, tmp_path: Path, stub_num_layers: None |
| 156 | + ) -> None: |
| 157 | + src = tmp_path / "peft" |
| 158 | + _write_peft_adapter(src, target_modules=[]) |
| 159 | + with pytest.raises(MlxConversionError, match="target_modules"): |
| 160 | + stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake") |