1 """C5: MLX backend smoke test (darwin-arm64-only).
2
3 Exercises ``MLXDifferentialBackend.next_token_dist`` and ``generate``
4 on a real mlx_lm-loaded model with a small LoRA adapter attached.
5 The adapter is built once via ``tests/fixtures/build_mlx_adapter.py``
6 (see that script's docstring); this test skips when the fixture
7 directory is absent rather than building it on the fly, since fixture
8 builds take ~30s and re-running the test should be cheap.
9
10 Skips on:
11 - non-darwin platforms (mlx is Apple Silicon only)
12 - non-arm64 architectures
13 - missing ``mlx_lm`` import
14 - missing fixture directory
15 """
16
17 from __future__ import annotations
18
19 import math
20 import platform
21 import sys
22 from pathlib import Path
23
24 import numpy as np
25 import pytest
26
27 pytestmark = [pytest.mark.slow, pytest.mark.online]
28
29 _FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "mlx_adapter_smollm2_135m"
30 _MODEL_ID = "mlx-community/SmolLM2-135M-Instruct-4bit"
31
32
33 def _platform_supports_mlx() -> bool:
34 return sys.platform == "darwin" and platform.machine() == "arm64"
35
36
@pytest.fixture(scope="module")
def mlx_backend():
    """Module-scoped MLX differential backend with the LoRA fixture attached.

    Skips (rather than fails) when any prerequisite is missing: a
    darwin-arm64 host, the ``mlx_lm`` package, or the pre-built adapter
    fixture directory. The backend is closed at module teardown.
    """
    if not _platform_supports_mlx():
        pytest.skip("MLX requires darwin-arm64")
    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")

    if not _FIXTURE_DIR.exists():
        pytest.skip(
            f"missing MLX adapter fixture at {_FIXTURE_DIR} — generate it via "
            f"`python tests/fixtures/build_mlx_adapter.py`"
        )

    # Imported lazily: these pull in mlx, which is unimportable off-platform.
    from dlm_sway.backends.mlx import MLXDifferentialBackend
    from dlm_sway.core.model import ModelSpec

    spec = ModelSpec(base=_MODEL_ID, kind="mlx")
    backend = MLXDifferentialBackend(base_spec=spec, adapter_path=_FIXTURE_DIR)
    yield backend
    backend.close()
58
59
def test_next_token_dist_returns_finite_topk(mlx_backend) -> None:
    """Top-k distribution has the requested shape, finite values, sorted order."""
    with mlx_backend.as_base() as base:
        dist = base.next_token_dist("The capital of France is", top_k=32)
        assert dist.token_ids.shape == (32,)
        assert dist.logprobs.shape == (32,)
        assert np.isfinite(dist.logprobs).all()
        # Descending probability order: successive logprob deltas must be
        # non-positive (small epsilon tolerates float noise on exact ties).
        deltas = np.diff(dist.logprobs)
        assert (deltas <= 1e-7).all()
68
69
def test_adapter_changes_distribution(mlx_backend) -> None:
    """The whole point of the differential backend: base ≠ ft."""
    prompt = "The adapter does"
    with mlx_backend.as_base() as base:
        base_dist = base.next_token_dist(prompt, top_k=32)
    with mlx_backend.as_finetuned() as tuned:
        ft_dist = tuned.next_token_dist(prompt, top_k=32)
    # Distinct top-k token ids already prove the distributions differ;
    # only when the ids coincide must the logprobs themselves diverge.
    if np.array_equal(base_dist.token_ids, ft_dist.token_ids):
        assert not np.allclose(base_dist.logprobs, ft_dist.logprobs, atol=1e-5)
80
81
def test_logprob_of_finite(mlx_backend) -> None:
    """``logprob_of`` returns a finite, strictly negative log-probability."""
    with mlx_backend.as_base() as base:
        score = base.logprob_of("The capital of France is", " Paris")
        assert math.isfinite(score)
        assert score < 0.0
87
88
def test_generate_returns_nonempty_string(mlx_backend) -> None:
    """Seeded ``generate`` produces a non-empty ``str``."""
    with mlx_backend.as_base() as base:
        text = base.generate("Hello", max_new_tokens=8, seed=0)
        assert isinstance(text, str)
        assert text  # non-empty
93 assert len(out) > 0