tenseleyflow/sway / 8917f1f


tests/mlx_smoke: darwin-arm64 MLX backend smoke + reproducible adapter fixture builder (C5)

Authored by espadonne
SHA 8917f1fcdef4ed155e01fe204c8adb512c979c26
Parents c6f6de7
Tree ee47ced

2 changed files

Status  File                                 +   -
A       tests/fixtures/build_mlx_adapter.py  87  0
A       tests/integration/test_mlx_smoke.py  97  0
tests/fixtures/build_mlx_adapter.py (added)
@@ -0,0 +1,87 @@
+"""Reproducible MLX adapter fixture for the smoke test.
+
+Run this once on darwin-arm64 with the ``[mlx]`` extra installed to
+generate a tiny LoRA adapter under
+``tests/fixtures/mlx_adapter_smollm2_135m/``. The integration test
+at ``tests/integration/test_mlx_smoke.py`` uses the directory if it
+exists and otherwise skips with a pointer here.
+
+The adapter artifact is deliberately *not* vendored in the repo:
+regenerating it from this script stays reproducible, keeps the repo
+small, and lets us bump the base model without checking binaries in.
+
+Usage:
+
+    # Prerequisites: darwin-arm64 + ``uv pip install -e ".[mlx]"``
+    python tests/fixtures/build_mlx_adapter.py
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+from pathlib import Path
+
+
+def main() -> int:
+    if sys.platform != "darwin":
+        print("build_mlx_adapter.py requires darwin-arm64", file=sys.stderr)
+        return 1
+
+    try:
+        import mlx.core as mx
+        from mlx.utils import tree_flatten
+        from mlx_lm import load
+        from mlx_lm.tuner.utils import linear_to_lora_layers
+    except ImportError as exc:
+        print(
+            f"mlx / mlx_lm not importable: {exc}\n"
+            "Install the [mlx] extra: uv pip install -e '.[mlx]'",
+            file=sys.stderr,
+        )
+        return 1
+
+    model_id = "mlx-community/SmolLM2-135M-Instruct-4bit"
+    out_dir = Path(__file__).parent / "mlx_adapter_smollm2_135m"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Loading {model_id} (will download on first run)…")
+    model, _tokenizer = load(model_id)
+
+    # Apply a tiny LoRA shim over the attention projections of the
+    # first two layers. linear_to_lora_layers leaves lora_b zeroed,
+    # so the fresh adapter is a no-op until we perturb it below.
+    lora_config = {
+        "rank": 4,
+        "alpha": 8,
+        "dropout": 0.0,
+        "scale": 10.0,
+        "keys": ["q_proj", "v_proj"],
+    }
+    linear_to_lora_layers(model, num_layers=2, config=lora_config)
+
+    # trainable_parameters() returns a nested dict; flatten it to
+    # dotted-path names so the lora_b matrices can be picked out.
+    mx.random.seed(0)
+    params = dict(tree_flatten(model.trainable_parameters()))
+    for name, arr in params.items():
+        if "lora_b" in name.lower():
+            # Seed-scale lora_B so the adapter actually changes outputs.
+            params[name] = mx.random.normal(arr.shape) * 0.05
+
+    # Save adapters to the mlx_lm convention: <dir>/adapters.safetensors
+    # plus an adapter_config.json in the shape mlx_lm's adapter loading
+    # expects (num_layers + lora_parameters).
+    mx.save_safetensors(
+        str(out_dir / "adapters.safetensors"),
+        {k: v.astype(mx.float16) for k, v in params.items()},
+    )
+    (out_dir / "adapter_config.json").write_text(
+        json.dumps({"num_layers": 2, "lora_parameters": lora_config}, indent=2)
+    )
+    print(f"Wrote adapter fixture to {out_dir}")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
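
A quick sanity check of the generated fixture (a sketch, not part of this commit; it assumes mlx_lm's ``load(path, adapter_path=...)`` signature and that the config stub above matches mlx_lm's adapter-loading convention):

    # Hypothetical check: load the base model with the freshly built
    # adapter attached and generate a few tokens through it.
    from mlx_lm import generate, load

    model, tokenizer = load(
        "mlx-community/SmolLM2-135M-Instruct-4bit",
        adapter_path="tests/fixtures/mlx_adapter_smollm2_135m",
    )
    print(generate(model, tokenizer, prompt="The capital of France is", max_tokens=8))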
tests/integration/test_mlx_smoke.py (added)
@@ -0,0 +1,97 @@
+"""C5: MLX backend smoke test (darwin-arm64-only).
+
+Exercises ``MLXDifferentialBackend.next_token_dist`` and ``generate``
+on a real mlx_lm-loaded model with a small LoRA adapter attached.
+The adapter is built once via ``tests/fixtures/build_mlx_adapter.py``
+(see that script's docstring); this test skips when the fixture
+directory is absent rather than building it on the fly, since fixture
+builds take ~30s and re-running the test should be cheap.
+
+Skips on:
+- non-darwin platforms (mlx is Apple Silicon only)
+- non-arm64 architectures
+- missing ``mlx_lm`` import
+- missing fixture directory
+"""
+
+from __future__ import annotations
+
+import math
+import platform
+import sys
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+_FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "mlx_adapter_smollm2_135m"
+_MODEL_ID = "mlx-community/SmolLM2-135M-Instruct-4bit"
+
+
+def _platform_supports_mlx() -> bool:
+    return sys.platform == "darwin" and platform.machine() == "arm64"
+
+
+@pytest.fixture(scope="module")
+def mlx_backend():
+    if not _platform_supports_mlx():
+        pytest.skip("MLX requires darwin-arm64")
+    pytest.importorskip("mlx_lm", reason="install the [mlx] extra to run MLX tests")
+
+    if not _FIXTURE_DIR.exists():
+        pytest.skip(
+            f"missing MLX adapter fixture at {_FIXTURE_DIR} — generate it via "
+            "`python tests/fixtures/build_mlx_adapter.py`"
+        )
+
+    from dlm_sway.backends.mlx import MLXDifferentialBackend
+    from dlm_sway.core.model import ModelSpec
+
+    backend = MLXDifferentialBackend(
+        base_spec=ModelSpec(base=_MODEL_ID, kind="mlx"),
+        adapter_path=_FIXTURE_DIR,
+    )
+    yield backend
+    backend.close()
+
+
+def test_next_token_dist_returns_finite_topk(mlx_backend) -> None:
+    with mlx_backend.as_base() as b:
+        d = b.next_token_dist("The capital of France is", top_k=32)
+    assert d.token_ids.shape == (32,)
+    assert d.logprobs.shape == (32,)
+    assert np.all(np.isfinite(d.logprobs))
+    # Top-k must come back sorted in descending probability order;
+    # the small slack absorbs float rounding between near-equal logprobs.
+    assert np.all(np.diff(d.logprobs) <= 1e-7)
+
+
+def test_adapter_changes_distribution(mlx_backend) -> None:
+    """The whole point of the differential backend: base ≠ ft."""
+    prompt = "The adapter does"
+    with mlx_backend.as_base() as b:
+        base_dist = b.next_token_dist(prompt, top_k=32)
+    with mlx_backend.as_finetuned() as f:
+        ft_dist = f.next_token_dist(prompt, top_k=32)
+    # If the top-k token ids differ at all, the adapter has already
+    # visibly shifted the distribution; only when they match do the
+    # logprobs need to be compared directly.
+    same_ids = np.array_equal(base_dist.token_ids, ft_dist.token_ids)
+    if same_ids:
+        assert not np.allclose(base_dist.logprobs, ft_dist.logprobs, atol=1e-5)
+
+
+def test_logprob_of_finite(mlx_backend) -> None:
+    with mlx_backend.as_base() as b:
+        lp = b.logprob_of("The capital of France is", " Paris")
+    assert math.isfinite(lp)
+    assert lp < 0.0
+
+
+def test_generate_returns_nonempty_string(mlx_backend) -> None:
+    with mlx_backend.as_base() as b:
+        out = b.generate("Hello", max_new_tokens=8, seed=0)
+    assert isinstance(out, str)
+    assert len(out) > 0
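
For reference, the backend surface these tests exercise reduces to a small protocol. The sketch below is inferred from the calls in the tests (``as_base``/``as_finetuned`` as context managers, plus ``next_token_dist``, ``logprob_of``, and ``generate``); it is an illustration, not the actual dlm_sway implementation:

    # Interface sketch inferred from the tests above; the real
    # MLXDifferentialBackend lives in dlm_sway.backends.mlx and may differ.
    from contextlib import AbstractContextManager
    from dataclasses import dataclass
    from typing import Protocol

    import numpy as np

    @dataclass
    class TokenDist:
        token_ids: np.ndarray  # shape (top_k,), token ids
        logprobs: np.ndarray   # shape (top_k,), sorted non-increasing

    class ModelView(Protocol):
        def next_token_dist(self, prompt: str, top_k: int) -> TokenDist: ...
        def logprob_of(self, prompt: str, continuation: str) -> float: ...
        def generate(self, prompt: str, max_new_tokens: int, seed: int) -> str: ...

    class DifferentialBackend(Protocol):
        def as_base(self) -> AbstractContextManager[ModelView]: ...
        def as_finetuned(self) -> AbstractContextManager[ModelView]: ...
        def close(self) -> None: ...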