tests/mlx_smoke: darwin-arm64 MLX backend smoke + reproducible adapter fixture builder (C5)

- SHA: 8917f1fcdef4ed155e01fe204c8adb512c979c26
- Parents: c6f6de7
- Tree: ee47ced

| Status | File | + | - |
|---|---|---|---|
| A | tests/fixtures/build_mlx_adapter.py | 84 | 0 |
| A | tests/integration/test_mlx_smoke.py | 93 | 0 |

tests/fixtures/build_mlx_adapter.py (added)

```python
| 1 | +"""Reproducible MLX adapter fixture for the smoke test. | |
| 2 | + | |
| 3 | +Run this once on darwin-arm64 with the ``[mlx]`` extra installed to | |
| 4 | +generate a tiny LoRA adapter under | |
| 5 | +``tests/fixtures/mlx_adapter_smollm2_135m/``. The integration test | |
| 6 | +at ``tests/integration/test_mlx_smoke.py`` uses the directory if it | |
| 7 | +exists, else skips with a pointer here. | |
| 8 | + | |
| 9 | +This script is deliberately *not* vendored as a binary in the repo — | |
| 10 | +regenerating it from scratch stays reproducible, keeps the repo small, | |
| 11 | +and lets us bump the base model without re-checking binaries in. | |
| 12 | + | |
| 13 | +Usage: | |
| 14 | + | |
| 15 | + # Prerequisites: darwin-arm64 + ``uv pip install -e ".[mlx]"`` | |
| 16 | + python tests/fixtures/build_mlx_adapter.py | |
| 17 | +""" | |
from __future__ import annotations

import sys
from pathlib import Path


def main() -> int:
    if sys.platform != "darwin":
        print("build_mlx_adapter.py requires darwin-arm64", file=sys.stderr)
        return 1

    try:
        import mlx.core as mx
        from mlx_lm import load
        from mlx_lm.tuner.utils import linear_to_lora_layers
    except ImportError as exc:
        print(
            f"mlx / mlx_lm not importable: {exc}\n"
            "Install the [mlx] extra: uv pip install -e '.[mlx]'",
            file=sys.stderr,
        )
        return 1

    model_id = "mlx-community/SmolLM2-135M-Instruct-4bit"
    out_dir = Path(__file__).parent / "mlx_adapter_smollm2_135m"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {model_id} (will download on first run)…")
    model, _tokenizer = load(model_id)

    # Apply a tiny LoRA shim over the attention projections and
    # randomize the A/B weights deterministically. Freeze the base
    # model first so only the LoRA weights are trainable (and saved).
    model.freeze()
    lora_config = {
        "rank": 4,
        "alpha": 8,
        "dropout": 0.0,
        "scale": 10.0,
        # linear_to_lora_layers matches the dotted module path within
        # each transformer block, not the bare attribute name.
        "keys": ["self_attn.q_proj", "self_attn.v_proj"],
    }
    linear_to_lora_layers(model, num_layers=2, config=lora_config)
    # Seed-scale lora_B so the adapter actually changes outputs.
    # ``trainable_parameters()`` returns a *nested* dict, so flatten it
    # to dotted names first; otherwise the "lora_b" match never fires.
    from mlx.utils import tree_flatten

    mx.random.seed(0)
    params = dict(tree_flatten(model.trainable_parameters()))
    for name, arr in params.items():
        if "lora_b" in name.lower():
            params[name] = mx.random.normal(arr.shape) * 0.05

    # Save adapters to the mlx_lm convention: <dir>/adapters.safetensors
    # plus an adapter_config.json stub mlx_lm.load can find. Use
    # mx.save_safetensors here; safetensors' numpy writer cannot
    # serialize mx.array values directly.
    import json

    mx.save_safetensors(
        str(out_dir / "adapters.safetensors"),
        {k: v.astype(mx.float16) for k, v in params.items()},
    )
    (out_dir / "adapter_config.json").write_text(json.dumps(lora_config, indent=2))
    print(f"Wrote adapter fixture to {out_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
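
A successful run leaves `adapters.safetensors` and `adapter_config.json` under `tests/fixtures/mlx_adapter_smollm2_135m/`. As a quick sanity check (not part of this commit, and assuming only the public `mlx_lm` Python API, where `load` accepts an `adapter_path` and `generate` takes a prompt plus `max_tokens`), the fixture can be exercised directly:

```python
# Illustrative sanity check for the generated fixture (not in the commit):
# reload the base model with the freshly built adapter and confirm the
# forward pass runs end to end.
from mlx_lm import generate, load

model, tokenizer = load(
    "mlx-community/SmolLM2-135M-Instruct-4bit",
    adapter_path="tests/fixtures/mlx_adapter_smollm2_135m",
)
# Any non-empty string means the adapter weights loaded and generation works.
print(generate(model, tokenizer, prompt="The capital of France is", max_tokens=8))
```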

tests/integration/test_mlx_smoke.py (added)

```python
"""C5: MLX backend smoke test (darwin-arm64-only).

Exercises ``MLXDifferentialBackend.next_token_dist`` and ``generate``
on a real mlx_lm-loaded model with a small LoRA adapter attached.
The adapter is built once via ``tests/fixtures/build_mlx_adapter.py``
(see that script's docstring); this test skips when the fixture
directory is absent rather than building it on the fly, since fixture
builds take ~30s and re-running the test should be cheap.

Skips on:
- non-darwin platforms (mlx is Apple Silicon only)
- non-arm64 architectures
- missing ``mlx_lm`` import
- missing fixture directory
"""

from __future__ import annotations

import math
import platform
import sys
from pathlib import Path

import numpy as np
import pytest

pytestmark = [pytest.mark.slow, pytest.mark.online]

_FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "mlx_adapter_smollm2_135m"
_MODEL_ID = "mlx-community/SmolLM2-135M-Instruct-4bit"


def _platform_supports_mlx() -> bool:
    return sys.platform == "darwin" and platform.machine() == "arm64"


@pytest.fixture(scope="module")
def mlx_backend():
    if not _platform_supports_mlx():
        pytest.skip("MLX requires darwin-arm64")
    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")

    if not _FIXTURE_DIR.exists():
        pytest.skip(
            f"missing MLX adapter fixture at {_FIXTURE_DIR} — generate it via "
            f"`python tests/fixtures/build_mlx_adapter.py`"
        )

    from dlm_sway.backends.mlx import MLXDifferentialBackend
    from dlm_sway.core.model import ModelSpec

    backend = MLXDifferentialBackend(
        base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
        adapter_path=_FIXTURE_DIR,
    )
    yield backend
    backend.close()


def test_next_token_dist_returns_finite_topk(mlx_backend) -> None:
    with mlx_backend.as_base() as b:
        d = b.next_token_dist("The capital of France is", top_k=32)
    assert d.token_ids.shape == (32,)
    assert d.logprobs.shape == (32,)
    assert np.all(np.isfinite(d.logprobs))
    # Top-k must be sorted in descending probability order.
    assert np.all(np.diff(d.logprobs) <= 1e-7)


def test_adapter_changes_distribution(mlx_backend) -> None:
    """The whole point of the differential backend: base ≠ ft."""
    prompt = "The adapter does"
    with mlx_backend.as_base() as b:
        base_dist = b.next_token_dist(prompt, top_k=32)
    with mlx_backend.as_finetuned() as f:
        ft_dist = f.next_token_dist(prompt, top_k=32)
    same_ids = np.array_equal(base_dist.token_ids, ft_dist.token_ids)
    if same_ids:
        assert not np.allclose(base_dist.logprobs, ft_dist.logprobs, atol=1e-5)


def test_logprob_of_finite(mlx_backend) -> None:
    with mlx_backend.as_base() as b:
        lp = b.logprob_of("The capital of France is", " Paris")
    assert math.isfinite(lp)
    assert lp < 0.0


def test_generate_returns_nonempty_string(mlx_backend) -> None:
    with mlx_backend.as_base() as b:
        out = b.generate("Hello", max_new_tokens=8, seed=0)
    assert isinstance(out, str)
    assert len(out) > 0
```
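
Putting the two files together, the intended workflow on an Apple Silicon machine follows the docstrings: install the extra, build the fixture once, then run the smoke test. (If the repo's pytest config deselects the `slow`/`online` marks by default, which this diff does not show, an explicit `-m` override would be needed.)

```sh
uv pip install -e ".[mlx]"
python tests/fixtures/build_mlx_adapter.py
pytest tests/integration/test_mlx_smoke.py -v
```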