@@ -0,0 +1,198 @@
| 1 | +"""S24 — end-to-end PEFT → MLX adapter conversion (darwin-arm64-only). |
| 2 | + |
| 3 | +Closes the F01 audit gap: ``dlm train`` writes a PEFT-shaped adapter, |
| 4 | +``MLXDifferentialBackend`` is pointed at it, the backend auto-converts |
| 5 | +into the user's cache, ``mlx_lm.load`` consumes the result, scoring |
| 6 | +returns finite logprobs. |
| 7 | + |
| 8 | +This is the **prove-the-value** test the sprint file calls out — every |
| 9 | +other layer of testing (synthetic-input unit tests, CLI smoke tests) |
| 10 | +is upstream of this. If this passes locally on darwin-arm64, the |
| 11 | +headline ``.dlm → MLX`` flow works. |
| 12 | + |
| 13 | +Skips cleanly on: |
| 14 | +- non-darwin (mlx is Apple Silicon only) |
| 15 | +- non-arm64 |
| 16 | +- ``mlx_lm`` not installed (the ``[mlx]`` extra is optional) |
| 17 | +- ``peft`` / ``transformers`` not installed (the ``[hf]`` extra needed |
| 18 | + to *build* the source PEFT adapter) |
| 19 | +""" |
+
+from __future__ import annotations
+
+import math
+import os
+import platform
+import sys
+from contextlib import contextmanager
+from pathlib import Path
+
+import numpy as np
+import pytest
+
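+# slow: loads a real model into mlx; online: the first run pulls the
+# base model from HF Hub.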
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+# Default to the unquantized MLX repo because the 4-bit variant has
+# slipped into a gated/auth state on HF Hub. Either repo's adapter
+# slot works for the converter — the test only cares that mlx-lm
+# loads our converted ``adapters.safetensors``.
+_MODEL_ID = "mlx-community/SmolLM2-135M-Instruct"
+
+
+def _platform_supports_mlx() -> bool:
+    return sys.platform == "darwin" and platform.machine() == "arm64"
+
+
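+@contextmanager
+def _redirected_xdg_cache(tmp_path_factory: pytest.TempPathFactory, label: str):
+    """Point XDG_CACHE_HOME at a fresh tmp dir so converted adapters
+    never land in the user's real ~/.cache/dlm-sway/, restoring the
+    previous value on exit. Shared by the module fixture and the
+    re-convert test below; each invocation gets a fresh dir."""
+    cache_root = tmp_path_factory.mktemp(label)
+    prev = os.environ.get("XDG_CACHE_HOME")
+    os.environ["XDG_CACHE_HOME"] = str(cache_root)
+    try:
+        yield cache_root
+    finally:
+        if prev is None:
+            os.environ.pop("XDG_CACHE_HOME", None)
+        else:
+            os.environ["XDG_CACHE_HOME"] = prev
+
+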
+def _build_random_peft_lora(base_dir: Path, out_dir: Path) -> None:
+    """Build the same deterministic LoRA the HF integration tests use,
+    duplicated here so we don't import from another test file."""
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
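+    # Seed torch so both PEFT's random lora_A init and the lora_B noise
+    # below are reproducible across runs.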
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
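+    # PEFT zero-initializes lora_B, which makes a fresh adapter a no-op
+    # (delta = B @ A == 0). Give lora_B small random values so the
+    # adapter measurably shifts the base model's logits.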
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def peft_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    if not _platform_supports_mlx():
+        pytest.skip("MLX requires darwin-arm64")
+    pytest.importorskip("peft", reason="needs the [hf] extra to build a PEFT adapter")
+    out = tmp_path_factory.mktemp("peft-for-mlx-convert")
+    _build_random_peft_lora(tiny_model_dir, out)
+    return out
+
+
+@pytest.fixture(scope="module")
+def mlx_backend(peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory):
+    """Point the MLX backend at a PEFT-shaped adapter dir; the backend
+    auto-converts into a tmp cache (XDG_CACHE_HOME redirected via
+    _redirected_xdg_cache so we don't pollute the user's real cache)."""
+    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")
+
+    with _redirected_xdg_cache(tmp_path_factory, "mlx-convert-cache") as cache_root:
+        from dlm_sway.backends.mlx import MLXDifferentialBackend
+        from dlm_sway.core.model import ModelSpec
+
+        backend = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        # close() runs in a finally so the backend is torn down even if
+        # the yield raises during fixture finalization.
+        try:
+            yield backend, cache_root
+        finally:
+            backend.close()
+
+
+def test_auto_conversion_writes_to_xdg_cache(mlx_backend) -> None:
+    """The backend's __init__ must have populated the cache dir with
+    an MLX-format adapter — proves the auto-convert path fired."""
+    _backend, cache_root = mlx_backend
+    converted = list((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
+    assert len(converted) == 1, f"expected exactly one cached MLX adapter dir, got {converted}"
+    cache_dir = converted[0]
+    assert (cache_dir / "adapters.safetensors").exists()
+    assert (cache_dir / "adapter_config.json").exists()
+
+
+def test_next_token_dist_returns_finite_topk_via_converted_adapter(mlx_backend) -> None:
+    """The converted adapter, loaded via mlx_lm + scored via the MLX
+    backend, must produce finite, well-ordered top-k logprobs."""
+    backend, _ = mlx_backend
+    with backend.as_finetuned() as ft:
+        d = ft.next_token_dist("The capital of France is", top_k=32)
+        assert d.token_ids.shape == (32,)
+        assert d.logprobs.shape == (32,)
+        assert np.all(np.isfinite(d.logprobs))
+        # Logprobs come back sorted best-first; the tiny slack absorbs
+        # float noise in exact ties.
+        assert np.all(np.diff(d.logprobs) <= 1e-7)
+
+
+def test_logprob_of_finite_via_converted_adapter(mlx_backend) -> None:
+    backend, _ = mlx_backend
+    with backend.as_finetuned() as ft:
+        lp = ft.logprob_of("The capital of France is", " Paris")
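+        # " Paris" should be likely but not certain here, so its
+        # logprob must be finite and strictly negative.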
+        assert math.isfinite(lp)
+        assert lp < 0.0
+
+
+def test_repeat_load_skips_reconvert(
+    peft_adapter: Path, tmp_path_factory: pytest.TempPathFactory
+) -> None:
+    """A second backend instance against the same PEFT adapter must
+    short-circuit on the cache and NOT rewrite the converted file."""
+    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")
+
+    with _redirected_xdg_cache(tmp_path_factory, "mlx-convert-cache-2") as cache_root:
+        from dlm_sway.backends.mlx import MLXDifferentialBackend
+        from dlm_sway.core.model import ModelSpec
+
+        b1 = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        cache_dir = next((cache_root / "dlm-sway" / "mlx-converted").glob("*"))
+        # st_mtime_ns has nanosecond resolution, so even a fast,
+        # byte-identical rewrite should show up as a changed timestamp.
+        first_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
+        b1.close()
+
+        b2 = MLXDifferentialBackend(
+            base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+            adapter_path=peft_adapter,
+        )
+        second_mtime = (cache_dir / "adapters.safetensors").stat().st_mtime_ns
+        b2.close()
+
+        assert second_mtime == first_mtime, (
+            "second backend init re-wrote the cached MLX adapter — cache short-circuit is broken"
+        )