1 """Reproducible MLX adapter fixture for the smoke test.
2
3 Run this once on darwin-arm64 with the ``[mlx]`` extra installed to
4 generate a tiny LoRA adapter under
5 ``tests/fixtures/mlx_adapter_smollm2_135m/``. The integration test
6 at ``tests/integration/test_mlx_smoke.py`` uses the directory if it
7 exists, else skips with a pointer here.
8
9 This script is deliberately *not* vendored as a binary in the repo —
10 regenerating it from scratch stays reproducible, keeps the repo small,
11 and lets us bump the base model without re-checking binaries in.
12
13 Usage:
14
15 # Prerequisites: darwin-arm64 + ``uv pip install -e ".[mlx]"``
16 python tests/fixtures/build_mlx_adapter.py
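
The smoke test then consumes the fixture roughly like this (a sketch
only; the exact call lives in ``test_mlx_smoke.py``):

    from mlx_lm import load

    model, tokenizer = load(
        "mlx-community/SmolLM2-135M-Instruct-4bit",
        adapter_path="tests/fixtures/mlx_adapter_smollm2_135m",
    )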
17 """

from __future__ import annotations

import json
import platform
import sys
from pathlib import Path


def main() -> int:
    if sys.platform != "darwin" or platform.machine() != "arm64":
        print("build_mlx_adapter.py requires darwin-arm64", file=sys.stderr)
        return 1

    try:
        import mlx.core as mx
        from mlx.utils import tree_flatten
        from mlx_lm import load
        from mlx_lm.tuner.utils import linear_to_lora_layers
    except ImportError as exc:
        print(
            f"mlx / mlx_lm not importable: {exc}\n"
            "Install the [mlx] extra: uv pip install -e '.[mlx]'",
            file=sys.stderr,
        )
        return 1

    model_id = "mlx-community/SmolLM2-135M-Instruct-4bit"
    out_dir = Path(__file__).parent / "mlx_adapter_smollm2_135m"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {model_id} (will download on first run)…")
    model, _tokenizer = load(model_id)

    # Freeze the base weights first (mirroring mlx_lm's LoRA training flow) so
    # that only the LoRA matrices show up in trainable_parameters() below.
    model.freeze()

    # Apply a tiny LoRA shim over the attention projections and
    # randomize the A/B weights deterministically.
    lora_config = {
        "rank": 4,
        "alpha": 8,
        "dropout": 0.0,
        "scale": 10.0,
        # Full module paths within each layer, which is what
        # linear_to_lora_layers matches against.
        "keys": ["self_attn.q_proj", "self_attn.v_proj"],
    }
    linear_to_lora_layers(model, num_layers=2, config=lora_config)

    # Seed-scale lora_b so the adapter actually changes outputs.
    # trainable_parameters() returns a nested tree, so flatten it first to get
    # dotted names like "...self_attn.q_proj.lora_b" to match against.
    mx.random.seed(0)
    params = dict(tree_flatten(model.trainable_parameters()))
    for name, arr in params.items():
        if "lora_b" in name.lower():
            params[name] = mx.random.normal(arr.shape) * 0.05

    # Save to the mlx_lm adapter convention: <dir>/adapters.safetensors plus
    # an adapter_config.json shaped like the one mlx_lm's LoRA training writes
    # ("fine_tune_type", "num_layers", "lora_parameters"), so mlx_lm can load
    # the directory back as an adapter.
    mx.save_safetensors(
        str(out_dir / "adapters.safetensors"),
        {k: v.astype(mx.float16) for k, v in params.items()},
    )
    adapter_config = {
        "fine_tune_type": "lora",
        "num_layers": 2,
        "lora_parameters": lora_config,
    }
    (out_dir / "adapter_config.json").write_text(json.dumps(adapter_config, indent=2))
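
    # Quick sanity check (optional): reload the safetensors file and confirm
    # it is non-empty and contains only LoRA tensors before declaring success.
    reloaded = mx.load(str(out_dir / "adapters.safetensors"))
    assert reloaded and all("lora_" in k for k in reloaded), "unexpected tensor names"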
    print(f"Wrote adapter fixture to {out_dir}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())