tenseleyflow/documentlanguagemodel / fbf82a6


test(inference): MLX stage adapter tests cover translated config schema (audit-11 B1)

Authored by espadonne
SHA: fbf82a6c3c22523d502f4d95efdc4da8a1408466
Parents: 66db773
Tree: 9d2eaea

1 changed file

M  tests/unit/inference/test_mlx_stage_adapter_dir.py  (+105 -23)

tests/unit/inference/test_mlx_stage_adapter_dir.py (modified)
@@ -1,30 +1,44 @@
-"""`stage_mlx_adapter_dir` — audit-08 B3 fix.
+"""`stage_mlx_adapter_dir` — covers the pure staging path.
 
-Covers the file-layout piece of the MLX backend load path without
-importing mlx / mlx_lm. The heavy `MlxBackend.load` path stays
-`# pragma: no cover` — the correctness claim is that given a
-well-shaped PEFT adapter dir, staging produces the pair mlx_lm
-requires (`adapters.npz` + `adapter_config.json`).
+After the audit-11 B1 fix, staging no longer copies PEFT's config
+verbatim — it translates it into mlx-lm's schema (`num_layers` +
+`lora_parameters`) and writes `adapters.safetensors` rather than
+`.npz`. The heavy `MlxBackend.load` path stays `# pragma: no cover`;
+the correctness claim here is that given a well-shaped PEFT adapter
+dir + a base HF id whose config is cached locally, staging produces
+the pair mlx_lm requires.
 """
 
 from __future__ import annotations
 
 import json
 from pathlib import Path
+from typing import Any
 
 import pytest
 import torch
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 
+from dlm.inference.backends import mlx_backend
 from dlm.inference.backends.mlx_backend import stage_mlx_adapter_dir
 from dlm.inference.mlx_adapter import MlxConversionError
 
 
-def _write_peft_adapter(dst: Path) -> None:
+def _write_peft_adapter(dst: Path, *, target_modules: list[str] | None = None) -> None:
     """Write a minimal PEFT-shaped adapter dir."""
     dst.mkdir(parents=True, exist_ok=True)
+    resolved_targets = ["q_proj", "v_proj"] if target_modules is None else target_modules
     (dst / "adapter_config.json").write_text(
-        json.dumps({"peft_type": "LORA", "r": 8, "lora_alpha": 16})
+        json.dumps(
+            {
+                "peft_type": "LORA",
+                "r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "target_modules": resolved_targets,
+                "use_dora": False,
+            }
+        )
     )
     tensors = {
         "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 16),
@@ -33,46 +47,114 @@ def _write_peft_adapter(dst: Path) -> None:
     save_file(tensors, str(dst / "adapter_model.safetensors"))
 
 
+@pytest.fixture
+def stub_num_layers(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Bypass the HF cache lookup; pretend every base has 30 layers."""
+
+    def _stub(_hf_id: str) -> int:
+        return 30
+
+    monkeypatch.setattr(mlx_backend, "_resolve_base_num_hidden_layers", _stub)
+
+
 class TestStageSuccess:
-    def test_writes_npz_and_config(self, tmp_path: Path) -> None:
+    def test_writes_safetensors_and_translated_config(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
         src = tmp_path / "peft"
         dst = tmp_path / "mlx"
         _write_peft_adapter(src)
-        staged = stage_mlx_adapter_dir(src, dst)
+        staged = stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake")
 
         assert staged == dst
-        assert (dst / "adapters.npz").exists()
+        assert (dst / "adapters.safetensors").exists()
         assert (dst / "adapter_config.json").exists()
-        # Config is copied verbatim.
-        original = (src / "adapter_config.json").read_text()
-        assert (dst / "adapter_config.json").read_text() == original
+
+        mlx_cfg: dict[str, Any] = json.loads(
+            (dst / "adapter_config.json").read_text(encoding="utf-8")
+        )
+        assert mlx_cfg["fine_tune_type"] == "lora"
+        assert mlx_cfg["num_layers"] == 30
+        lp = mlx_cfg["lora_parameters"]
+        assert lp["rank"] == 8
+        assert lp["scale"] == pytest.approx(16 / 8)
+        assert lp["dropout"] == pytest.approx(0.05)
+        assert lp["keys"] == ["q_proj", "v_proj"]
+
+    def test_tensor_keys_match_mlx_layout(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
+        src = tmp_path / "peft"
+        dst = tmp_path / "mlx"
+        _write_peft_adapter(src)
+        stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake")
+
+        # Loaded keys should follow mlx-lm's flattened shape: no
+        # `base_model.` prefix, lowercase `lora_a`/`lora_b`, no
+        # trailing `.weight`.
+        loaded = load_file(str(dst / "adapters.safetensors"))
+        for key in loaded:
+            assert not key.startswith("base_model.")
+            assert ".lora_a" in key or ".lora_b" in key
+            assert not key.endswith(".weight")
+
+    def test_dora_emits_dora_fine_tune_type(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
+        src = tmp_path / "peft"
+        dst = tmp_path / "mlx"
+        _write_peft_adapter(src)
+        cfg = json.loads((src / "adapter_config.json").read_text())
+        cfg["use_dora"] = True
+        (src / "adapter_config.json").write_text(json.dumps(cfg))
+        stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake")
+
+        mlx_cfg = json.loads((dst / "adapter_config.json").read_text(encoding="utf-8"))
+        assert mlx_cfg["fine_tune_type"] == "dora"
 
 
 class TestStagePreflight:
     def test_missing_adapter_config_refused(self, tmp_path: Path) -> None:
-        """N10: reject a dir that doesn't look like a PEFT adapter."""
         src = tmp_path / "broken"
         src.mkdir()
-        # tensor only, no config
         save_file(
             {"base_model.model.a.lora_A.weight": torch.zeros(2, 2)},
             str(src / "adapter_model.safetensors"),
         )
         with pytest.raises(MlxConversionError, match="not a PEFT adapter dir"):
-            stage_mlx_adapter_dir(src, tmp_path / "out")
+            stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake")
 
     def test_missing_safetensors_refused(self, tmp_path: Path) -> None:
-        """Merged-model dir (has config, no adapter_model.safetensors)."""
         src = tmp_path / "merged"
         src.mkdir()
         (src / "adapter_config.json").write_text("{}")
         with pytest.raises(MlxConversionError, match="no adapter_model.safetensors"):
-            stage_mlx_adapter_dir(src, tmp_path / "out")
+            stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake")
 
-    def test_dst_dir_created_if_missing(self, tmp_path: Path) -> None:
+    def test_dst_dir_created_if_missing(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
         src = tmp_path / "peft"
         _write_peft_adapter(src)
-        # nested non-existent dst
         dst = tmp_path / "nested" / "dst"
-        stage_mlx_adapter_dir(src, dst)
+        stage_mlx_adapter_dir(src, dst, base_hf_id="org/fake")
         assert dst.is_dir()
+
+    def test_peft_config_missing_r_refused(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
+        src = tmp_path / "peft"
+        _write_peft_adapter(src)
+        cfg = json.loads((src / "adapter_config.json").read_text())
+        cfg.pop("r")
+        (src / "adapter_config.json").write_text(json.dumps(cfg))
+        with pytest.raises(MlxConversionError, match="'r'"):
+            stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake")
+
+    def test_peft_config_empty_target_modules_refused(
+        self, tmp_path: Path, stub_num_layers: None
+    ) -> None:
+        src = tmp_path / "peft"
+        _write_peft_adapter(src, target_modules=[])
+        with pytest.raises(MlxConversionError, match="target_modules"):
+            stage_mlx_adapter_dir(src, tmp_path / "out", base_hf_id="org/fake")
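For context on what the new assertions pin down, here is a minimal sketch of the PEFT-to-mlx-lm config translation these tests exercise. The helper name translate_peft_config and its handling of defaults are assumptions for illustration, not the repository's implementation; only the asserted output fields (fine_tune_type, num_layers, lora_parameters) come from the diff.

from typing import Any


def translate_peft_config(peft_cfg: dict[str, Any], num_layers: int) -> dict[str, Any]:
    # Illustrative only: output field names mirror the test assertions above.
    rank = peft_cfg["r"]  # staging refuses a config without "r"
    keys = peft_cfg.get("target_modules") or []
    if not keys:
        raise ValueError("target_modules must be non-empty")
    return {
        "fine_tune_type": "dora" if peft_cfg.get("use_dora") else "lora",
        "num_layers": num_layers,  # resolved from the base model's HF config
        "lora_parameters": {
            "rank": rank,
            "scale": peft_cfg.get("lora_alpha", rank) / rank,  # 16 / 8 == 2.0 in the fixture
            "dropout": peft_cfg.get("lora_dropout", 0.0),
            "keys": keys,  # e.g. ["q_proj", "v_proj"]
        },
    }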
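Likewise, a sketch of the tensor-key rewrite that test_tensor_keys_match_mlx_layout checks. The exact prefix handling is an assumption; the tests only require that staged keys have no base_model. prefix, use lowercase lora_a/lora_b, and drop the trailing .weight.

def to_mlx_key(peft_key: str) -> str:
    # Hypothetical rename satisfying the three asserted properties above.
    key = peft_key.removeprefix("base_model.model.")
    key = key.replace(".lora_A.weight", ".lora_a")
    key = key.replace(".lora_B.weight", ".lora_b")
    return key


# "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
#   -> "model.layers.0.self_attn.q_proj.lora_a"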