| 1 | """C5: MLX backend smoke test (darwin-arm64-only). |
| 2 | |
| 3 | Exercises ``MLXDifferentialBackend.next_token_dist`` and ``generate`` |
| 4 | on a real mlx_lm-loaded model with a small LoRA adapter attached. |
| 5 | The adapter is built once via ``tests/fixtures/build_mlx_adapter.py`` |
| 6 | (see that script's docstring); this test skips when the fixture |
| 7 | directory is absent rather than building it on the fly, since fixture |
| 8 | builds take ~30s and re-running the test should be cheap. |
| 9 | |
| 10 | Skips on: |
| 11 | - non-darwin platforms (mlx is Apple Silicon only) |
| 12 | - non-arm64 architectures |
| 13 | - missing ``mlx_lm`` import |
| 14 | - missing fixture directory |
| 15 | """ |
| 16 | |
| 17 | from __future__ import annotations |
| 18 | |
| 19 | import math |
| 20 | import platform |
| 21 | import sys |
| 22 | from pathlib import Path |
| 23 | |
| 24 | import numpy as np |
| 25 | import pytest |
| 26 | |
| 27 | pytestmark = [pytest.mark.slow, pytest.mark.online] |
| 28 | |
| 29 | _FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "mlx_adapter_smollm2_135m" |
| 30 | _MODEL_ID = "mlx-community/SmolLM2-135M-Instruct-4bit" |
| 31 | |
| 32 | |
| 33 | def _platform_supports_mlx() -> bool: |
| 34 | return sys.platform == "darwin" and platform.machine() == "arm64" |
| 35 | |
| 36 | |
@pytest.fixture(scope="module")
def mlx_backend():
    """Module-scoped differential backend with the fixture LoRA adapter attached.

    Skips (never fails) when any precondition is missing: non darwin-arm64
    platform, no ``mlx_lm`` install, or an absent pre-built adapter fixture
    (see the module docstring for why the fixture is not built on the fly).
    The backend is closed once every test in the module has run.
    """
    if not _platform_supports_mlx():
        pytest.skip("MLX requires darwin-arm64")
    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")

    if not _FIXTURE_DIR.exists():
        pytest.skip(
            f"missing MLX adapter fixture at {_FIXTURE_DIR} — generate it via "
            # Plain string, not f-string (F541): no placeholders in this piece.
            "`python tests/fixtures/build_mlx_adapter.py`"
        )

    # Imported lazily so collection succeeds on platforms where the
    # [mlx] extra (and hence this backend module) cannot be imported.
    from dlm_sway.backends.mlx import MLXDifferentialBackend
    from dlm_sway.core.model import ModelSpec

    backend = MLXDifferentialBackend(
        base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
        adapter_path=_FIXTURE_DIR,
    )
    yield backend
    backend.close()
| 58 | |
| 59 | |
def test_next_token_dist_returns_finite_topk(mlx_backend) -> None:
    """Top-k distribution has the requested shape, finite values, sorted order."""
    with mlx_backend.as_base() as backend:
        dist = backend.next_token_dist("The capital of France is", top_k=32)
    assert dist.token_ids.shape == (32,)
    assert dist.logprobs.shape == (32,)
    assert np.all(np.isfinite(dist.logprobs))
    # Top-k must be sorted in descending probability order.
    assert np.all(np.diff(dist.logprobs) <= 1e-7)
| 68 | |
| 69 | |
def test_adapter_changes_distribution(mlx_backend) -> None:
    """The whole point of the differential backend: base ≠ ft."""
    prompt = "The adapter does"
    with mlx_backend.as_base() as base:
        base_dist = base.next_token_dist(prompt, top_k=32)
    with mlx_backend.as_finetuned() as tuned:
        ft_dist = tuned.next_token_dist(prompt, top_k=32)
    # Differing top-k token ids already prove the distributions diverge;
    # only when the id sets coincide must the logprobs themselves differ.
    if np.array_equal(base_dist.token_ids, ft_dist.token_ids):
        assert not np.allclose(base_dist.logprobs, ft_dist.logprobs, atol=1e-5)
| 80 | |
| 81 | |
def test_logprob_of_finite(mlx_backend) -> None:
    """Continuation scoring yields a finite, strictly negative log-probability."""
    with mlx_backend.as_base() as backend:
        score = backend.logprob_of("The capital of France is", " Paris")
    assert math.isfinite(score)
    assert score < 0.0
| 87 | |
| 88 | |
def test_generate_returns_nonempty_string(mlx_backend) -> None:
    """A tiny seeded completion comes back as a non-empty ``str``."""
    with mlx_backend.as_base() as backend:
        text = backend.generate("Hello", max_new_tokens=8, seed=0)
    assert isinstance(text, str)
    assert len(text) > 0