feat(hardware): report MLX inference availability in doctor + pyproject mlx extra
- SHA: eb05e739a7e8c59dd7da7e01fcaff0a98cce52f1
- Parent: fc20ed1
- Tree: 45ade83
| Status | File | + | - |
|---|---|---|---|
| M | docs/cli/reference.md | 1 | 0 |
| M | pyproject.toml | 7 | 0 |
| M | src/dlm/hardware/capabilities.py | 15 | 0 |
| M | src/dlm/hardware/render.py | 1 | 0 |
| M | tests/unit/hardware/test_capabilities.py | 49 | 0 |
`docs/cli/reference.md` (modified):

```diff
@@ -80,6 +80,7 @@ dlm prompt <path> [query] [--max-tokens N] [--temp F] [--top-p F]
 | `--temp F` | 0.7 | Temperature. `0.0` = greedy decoding (deterministic). |
 | `--top-p F` | None | Top-p sampling. |
 | `--adapter NAME` | None | Select a named adapter from `training.adapters`. Required on multi-adapter documents; rejected on single-adapter ones. |
+| `--backend {auto,pytorch,mlx}` | `auto` | Inference backend. `auto` picks MLX on Apple Silicon when the `mlx` extra is installed (`uv sync --extra mlx`), else PyTorch. |
 | `--verbose` | false | Print resolved `InferencePlan` on stderr. |
 
 Query is the CLI positional argument. Omit to read from stdin.
```
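For context, an invocation with the new flag would look like the following; the document path and query string are placeholders, not part of the commit:

```console
$ dlm prompt notes.md "Summarize the open questions" --backend mlx
$ dlm prompt notes.md "Summarize the open questions" --verbose  # --backend defaults to auto
```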
`pyproject.toml` (modified):

```diff
@@ -48,6 +48,13 @@ dependencies = [
 cuda = [
     "bitsandbytes>=0.43",
 ]
+# Apple Silicon only (Sprint 21). `mlx` + `mlx-lm` wheels are darwin-arm64
+# exclusives; env markers keep `uv sync --extra mlx` a no-op on non-Apple
+# hosts so wheel resolution doesn't fail for Linux/CUDA contributors.
+mlx = [
+    "mlx>=0.18; sys_platform == 'darwin' and platform_machine == 'arm64'",
+    "mlx-lm>=0.19; sys_platform == 'darwin' and platform_machine == 'arm64'",
+]
 
 [project.scripts]
 dlm = "dlm.cli.app:main"
```
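A minimal sketch of why the extra is inert off-platform: the same PEP 508 markers, evaluated with `packaging` (not a dependency this commit adds; it is assumed here purely for illustration):

```python
from packaging.markers import Marker

marker = Marker("sys_platform == 'darwin' and platform_machine == 'arm64'")

# Evaluated against the running interpreter's environment:
print(marker.evaluate())  # True only on an Apple Silicon Mac

# Evaluated against an explicit environment, e.g. a Linux x86_64 CI runner:
print(marker.evaluate({"sys_platform": "linux", "platform_machine": "x86_64"}))  # False
```

When every requirement in the extra evaluates to False, resolvers such as uv install nothing for it, which is exactly the no-op behavior the comment describes.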
`src/dlm/hardware/capabilities.py` (modified):

```diff
@@ -54,6 +54,7 @@ class Capabilities:
     has_xformers: bool
     has_bitsandbytes: bool
     has_triton: bool
+    has_mlx: bool
     torch_version: str
     cuda_version: str | None
     rocm_version: str | None
@@ -92,6 +93,7 @@ def probe() -> Capabilities:
         has_xformers=_module_available("xformers"),
         has_bitsandbytes=_module_available("bitsandbytes") and backend == Backend.CUDA,
         has_triton=_module_available("triton"),
+        has_mlx=_has_mlx_inference(backend),
         torch_version=str(torch.__version__),
         cuda_version=_cuda_version(backend, torch),
         rocm_version=_rocm_version(torch),
@@ -182,6 +184,19 @@ def _module_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None
 
 
+def _has_mlx_inference(backend: Backend) -> bool:
+    """True iff MLX inference is runnable on this host.
+
+    Sprint 21: MLX is darwin-arm64 only. Off-platform installs of `mlx`
+    via pip would be a packaging mistake, but we still gate on backend
+    to avoid reporting True for a misconfigured CUDA box that happens
+    to have an mlx dist lying around.
+    """
+    if backend != Backend.MPS:
+        return False
+    return _module_available("mlx") and _module_available("mlx_lm")
+
+
 def _cuda_version(backend: Backend, torch: object) -> str | None:
     if backend != Backend.CUDA:
         return None
```
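A sketch of how `has_mlx` could feed the CLI's `--backend auto` resolution; `resolve_backend` is hypothetical and not part of this commit:

```python
from dlm.hardware.capabilities import probe


def resolve_backend(requested: str) -> str:
    """Map a --backend value ("auto", "pytorch", "mlx") to a concrete backend."""
    if requested != "auto":
        return requested  # an explicit choice wins
    # probe() only reports has_mlx=True on an MPS host with both `mlx` and
    # `mlx_lm` importable, so auto never selects MLX off Apple Silicon.
    return "mlx" if probe().has_mlx else "pytorch"
```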
`src/dlm/hardware/render.py` (modified):

```diff
@@ -25,6 +25,7 @@ def render_text(result: DoctorResult) -> str:
     lines.append(f"FlashAttention: {_bool(caps.has_flash_attention)}")
     lines.append(f"xFormers: {_bool(caps.has_xformers)}")
     lines.append(f"Triton: {_bool(caps.has_triton)}")
+    lines.append(f"MLX inference: {_bool(caps.has_mlx)}")
     lines.append(f"CPU cores: {caps.cpu_cores}")
     lines.append(f"RAM: {caps.ram_gb:.1f} GB")
     lines.append(f"Determinism: {caps.determinism_class}")
```
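Illustratively, the relevant slice of `dlm doctor` text output on an Apple Silicon host with the extra installed would read roughly as below; the surrounding values and the exact `_bool` spelling (yes/no vs. true/false) come from helpers outside this diff, so treat them as placeholders:

```text
FlashAttention: no
xFormers: no
Triton: no
MLX inference: yes
CPU cores: 10
RAM: 32.0 GB
```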
`tests/unit/hardware/test_capabilities.py` (modified):

```diff
@@ -55,6 +55,55 @@ class TestProbeMps:
         assert caps.has_flash_attention is False
 
 
+class TestMlxAvailability:
+    def test_non_mps_never_reports_mlx(self) -> None:
+        # Off-Apple hosts: has_mlx is False regardless of dist metadata
+        # (Sprint 21). The probe won't consult importlib on CUDA/CPU.
+        with force_cuda(sm=(8, 0)):
+            caps = probe()
+            assert caps.has_mlx is False
+        with force_cpu():
+            caps = probe()
+            assert caps.has_mlx is False
+
+    def test_mps_reports_mlx_when_both_modules_installed(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        # Simulate both mlx + mlx_lm available.
+        from dlm.hardware import capabilities as caps_mod
+
+        real_avail = caps_mod._module_available
+
+        def fake_available(name: str) -> bool:
+            if name in ("mlx", "mlx_lm"):
+                return True
+            return real_avail(name)
+
+        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
+        with force_mps():
+            caps = probe()
+        assert caps.has_mlx is True
+
+    def test_mps_reports_no_mlx_when_mlx_lm_missing(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        from dlm.hardware import capabilities as caps_mod
+
+        real_avail = caps_mod._module_available
+
+        def fake_available(name: str) -> bool:
+            if name == "mlx":
+                return True
+            if name == "mlx_lm":
+                return False
+            return real_avail(name)
+
+        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
+        with force_mps():
+            caps = probe()
+        assert caps.has_mlx is False
+
+
 class TestProbeCpu:
     def test_cpu_advisory_determinism(self) -> None:
         with force_cpu():
```
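The `force_*` helpers the tests lean on are imported from elsewhere in the suite; their implementation isn't in this diff. A plausible minimal shape, assuming the probe reads its backend from a patchable detection hook (the `_detect_backend` name is invented for this sketch):

```python
from contextlib import contextmanager
from unittest import mock

from dlm.hardware.capabilities import Backend


@contextmanager
def force_mps():
    # Hypothetical: patch whatever probe() uses to classify the host, so the
    # test is hermetic with respect to the machine it runs on.
    with mock.patch(
        "dlm.hardware.capabilities._detect_backend", return_value=Backend.MPS
    ):
        yield
```

Run just the new cases with `pytest tests/unit/hardware/test_capabilities.py -k Mlx`.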