1 """S24 — end-to-end PEFT → MLX adapter conversion (darwin-arm64-only).
2
3 Closes the F01 audit gap: ``dlm train`` writes a PEFT-shaped adapter,
4 ``MLXDifferentialBackend`` is pointed at it, the backend auto-converts
5 into the user's cache, ``mlx_lm.load`` consumes the result, scoring
6 returns finite logprobs.
7
8 This is the **prove-the-value** test the sprint file calls out — every
9 other layer of testing (synthetic-input unit tests, CLI smoke tests)
10 is upstream of this. If this passes locally on darwin-arm64, the
11 headline ``.dlm → MLX`` flow works.
12
13 Skips cleanly on:
14 - non-darwin (mlx is Apple Silicon only)
15 - non-arm64
16 - ``mlx_lm`` not installed (the ``[mlx]`` extra is optional)
17 - ``peft`` / ``transformers`` not installed (the ``[hf]`` extra needed
18 to *build* the source PEFT adapter)
19 """

from __future__ import annotations

import math
import platform
import sys
from pathlib import Path

import numpy as np
import pytest

pytestmark = [pytest.mark.slow, pytest.mark.online]


# Default to the unquantized MLX repo because the 4-bit variant has
# slipped into a gated/auth state on HF Hub. Either repo's adapter
# slot works for the converter — the test only cares that mlx-lm
# loads our converted ``adapters.safetensors``.
_MODEL_ID = "mlx-community/SmolLM2-135M-Instruct"


def _platform_supports_mlx() -> bool:
    return sys.platform == "darwin" and platform.machine() == "arm64"


def _build_random_peft_lora(base_dir: Path, out_dir: Path) -> None:
    """Same deterministic LoRA the HF integration tests use, shipped
    here because we don't want to import from another test file."""
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(0)
    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
    cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base, cfg)
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                param.copy_(torch.randn_like(param) * 0.05)
    peft_model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))


@pytest.fixture(scope="module")
def peft_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
    if not _platform_supports_mlx():
        pytest.skip("MLX requires darwin-arm64")
    pytest.importorskip("peft", reason="needs the [hf] extra to build a PEFT adapter")
    out = tmp_path_factory.mktemp("peft-for-mlx-convert")
    _build_random_peft_lora(tiny_model_dir, out)
    return out


@pytest.fixture(scope="module")
def mlx_backend(peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory):
    """Point the MLX backend at a PEFT-shaped adapter dir; the backend
    auto-converts into a tmp cache (XDG_CACHE_HOME is redirected so we
    don't pollute the user's real cache)."""
    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")

    # Redirect the cache so this test doesn't write to the user's
    # ~/.cache/dlm-sway/. Each fixture invocation gets a fresh dir.
    import os

    cache_root = tmp_path_factory.mktemp("mlx-convert-cache")
    prev = os.environ.get("XDG_CACHE_HOME")
    os.environ["XDG_CACHE_HOME"] = str(cache_root)
    try:
        from dlm_sway.backends.mlx import MLXDifferentialBackend
        from dlm_sway.core.model import ModelSpec

        backend = MLXDifferentialBackend(
            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
            adapter_path=peft_adapter,
        )
        yield backend, cache_root
        backend.close()
    finally:
        if prev is None:
            os.environ.pop("XDG_CACHE_HOME", None)
        else:
            os.environ["XDG_CACHE_HOME"] = prev


def test_auto_conversion_writes_to_xdg_cache(mlx_backend) -> None:
    """The backend's __init__ must have populated the cache dir with
    an MLX-format adapter — proves the auto-convert path fired."""
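    # Expected cache layout after __init__ (the converted directory's name is
    # an implementation detail; "<converted-adapter-dir>" below is illustrative):
    #
    #   $XDG_CACHE_HOME/dlm-sway/mlx-converted/<converted-adapter-dir>/
    #       adapters.safetensors
    #       adapter_config.json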
    _backend, cache_root = mlx_backend
    converted = list((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
    assert len(converted) == 1, f"expected exactly one cached MLX adapter dir, got {converted}"
    cache_dir = converted[0]
    assert (cache_dir / "adapters.safetensors").exists()
    assert (cache_dir / "adapter_config.json").exists()


def test_next_token_dist_returns_finite_topk_via_converted_adapter(mlx_backend) -> None:
    """The converted adapter, loaded via mlx_lm and scored via the MLX
    backend, must produce finite, well-ordered top-k logprobs."""
    backend, _ = mlx_backend
    with backend.as_finetuned() as ft:
        d = ft.next_token_dist("The capital of France is", top_k=32)
    assert d.token_ids.shape == (32,)
    assert d.logprobs.shape == (32,)
    assert np.all(np.isfinite(d.logprobs))
    assert np.all(np.diff(d.logprobs) <= 1e-7)  # descending


def test_logprob_of_finite_via_converted_adapter(mlx_backend) -> None:
    backend, _ = mlx_backend
    with backend.as_finetuned() as ft:
        lp = ft.logprob_of("The capital of France is", " Paris")
    assert math.isfinite(lp)
    assert lp < 0.0


def test_repeat_load_skips_reconvert(
    peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory
) -> None:
    """A second backend instance pointed at the same PEFT adapter must
    short-circuit on the cache and NOT rewrite the converted file."""
    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")

    import os

    cache_root = tmp_path_factory.mktemp("mlx-convert-cache-2")
    prev = os.environ.get("XDG_CACHE_HOME")
    os.environ["XDG_CACHE_HOME"] = str(cache_root)
    try:
        from dlm_sway.backends.mlx import MLXDifferentialBackend
        from dlm_sway.core.model import ModelSpec

        b1 = MLXDifferentialBackend(
            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
            adapter_path=peft_adapter,
        )
        cache_dir = next((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
        first_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
        b1.close()

        b2 = MLXDifferentialBackend(
            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
            adapter_path=peft_adapter,
        )
        second_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
        b2.close()

        assert second_mtime == first_mtime, (
            "second backend init re-wrote the cached MLX adapter — cache short-circuit is broken"
        )
    finally:
        if prev is None:
            os.environ.pop("XDG_CACHE_HOME", None)
        else:
            os.environ["XDG_CACHE_HOME"] = prev