tenseleyflow/documentlanguagemodel / eb05e73

Browse files

feat(hardware): report MLX inference availability in doctor + pyproject mlx extra

Authored by espadonne
SHA
eb05e739a7e8c59dd7da7e01fcaff0a98cce52f1
Parents
fc20ed1
Tree
45ade83

5 changed files

Status | File | Lines added | Lines removed
M docs/cli/reference.md 1 0
M pyproject.toml 7 0
M src/dlm/hardware/capabilities.py 15 0
M src/dlm/hardware/render.py 1 0
M tests/unit/hardware/test_capabilities.py 49 0
docs/cli/reference.mdmodified
@@ -80,6 +80,7 @@ dlm prompt <path> [query] [--max-tokens N] [--temp F] [--top-p F]
8080
 | `--temp F` | 0.7 | Temperature. `0.0` = greedy decoding (deterministic). |
8181
 | `--top-p F` | None | Top-p sampling. |
8282
 | `--adapter NAME` | None | Select a named adapter from `training.adapters`. Required on multi-adapter documents; rejected on single-adapter ones. |
83
+| `--backend {auto,pytorch,mlx}` | `auto` | Inference backend. `auto` picks MLX on Apple Silicon (when `uv sync --extra mlx` is installed), else PyTorch. |
8384
 | `--verbose` | false | Print resolved `InferencePlan` on stderr. |
8485
 
8586
 Query is the CLI positional argument. Omit to read from stdin.
pyproject.tomlmodified
@@ -48,6 +48,13 @@ dependencies = [
4848
 cuda = [
4949
     "bitsandbytes>=0.43",
5050
 ]
51
+# Apple Silicon only (Sprint 21). `mlx` + `mlx-lm` wheels are darwin-arm64
52
+# exclusives; env markers keep `uv sync --extra mlx` a no-op on non-Apple
53
+# hosts so wheel resolution doesn't fail for Linux/CUDA contributors.
54
+mlx = [
55
+    "mlx>=0.18; sys_platform == 'darwin' and platform_machine == 'arm64'",
56
+    "mlx-lm>=0.19; sys_platform == 'darwin' and platform_machine == 'arm64'",
57
+]
5158
 
5259
 [project.scripts]
5360
 dlm = "dlm.cli.app:main"
src/dlm/hardware/capabilities.pymodified
@@ -54,6 +54,7 @@ class Capabilities:
5454
     has_xformers: bool
5555
     has_bitsandbytes: bool
5656
     has_triton: bool
57
+    has_mlx: bool
5758
     torch_version: str
5859
     cuda_version: str | None
5960
     rocm_version: str | None
@@ -92,6 +93,7 @@ def probe() -> Capabilities:
9293
         has_xformers=_module_available("xformers"),
9394
         has_bitsandbytes=_module_available("bitsandbytes") and backend == Backend.CUDA,
9495
         has_triton=_module_available("triton"),
96
+        has_mlx=_has_mlx_inference(backend),
9597
         torch_version=str(torch.__version__),
9698
         cuda_version=_cuda_version(backend, torch),
9799
         rocm_version=_rocm_version(torch),
@@ -182,6 +184,19 @@ def _module_available(name: str) -> bool:
182184
     return importlib.util.find_spec(name) is not None
183185
 
184186
 
187
def _has_mlx_inference(backend: Backend) -> bool:
    """Return True when MLX inference is actually runnable on this host.

    Sprint 21: MLX ships darwin-arm64 wheels only, so availability is
    gated on the active backend being MPS rather than on installed
    distributions alone.  This keeps a misconfigured CUDA/CPU box that
    happens to have an ``mlx`` dist lying around from reporting True.
    Both ``mlx`` and ``mlx_lm`` must be importable for inference.
    """
    on_apple_backend = backend == Backend.MPS
    return (
        on_apple_backend
        and _module_available("mlx")
        and _module_available("mlx_lm")
    )
198
+
199
+
185200
 def _cuda_version(backend: Backend, torch: object) -> str | None:
186201
     if backend != Backend.CUDA:
187202
         return None
src/dlm/hardware/render.pymodified
@@ -25,6 +25,7 @@ def render_text(result: DoctorResult) -> str:
2525
     lines.append(f"FlashAttention: {_bool(caps.has_flash_attention)}")
2626
     lines.append(f"xFormers:       {_bool(caps.has_xformers)}")
2727
     lines.append(f"Triton:         {_bool(caps.has_triton)}")
28
+    lines.append(f"MLX inference:  {_bool(caps.has_mlx)}")
2829
     lines.append(f"CPU cores:      {caps.cpu_cores}")
2930
     lines.append(f"RAM:            {caps.ram_gb:.1f} GB")
3031
     lines.append(f"Determinism:    {caps.determinism_class}")
tests/unit/hardware/test_capabilities.pymodified
@@ -55,6 +55,55 @@ class TestProbeMps:
5555
         assert caps.has_flash_attention is False
5656
 
5757
 
58
class TestMlxAvailability:
    """Sprint 21: `has_mlx` is reported only on MPS with both MLX modules."""

    def test_non_mps_never_reports_mlx(self) -> None:
        # Off-Apple hosts report has_mlx False regardless of dist
        # metadata: the probe short-circuits before consulting
        # importlib on CUDA/CPU backends.
        with force_cuda(sm=(8, 0)):
            assert probe().has_mlx is False
        with force_cpu():
            assert probe().has_mlx is False

    def test_mps_reports_mlx_when_both_modules_installed(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Pretend both mlx and mlx_lm are importable on an MPS host.
        from dlm.hardware import capabilities as caps_mod

        original = caps_mod._module_available
        monkeypatch.setattr(
            caps_mod,
            "_module_available",
            lambda name: True if name in {"mlx", "mlx_lm"} else original(name),
        )
        with force_mps():
            assert probe().has_mlx is True

    def test_mps_reports_no_mlx_when_mlx_lm_missing(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # mlx alone is not enough: inference also requires mlx_lm.
        from dlm.hardware import capabilities as caps_mod

        original = caps_mod._module_available
        overrides = {"mlx": True, "mlx_lm": False}

        def fake_available(name: str) -> bool:
            return overrides.get(name, original(name))

        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
        with force_mps():
            assert probe().has_mlx is False
105
+
106
+
58107
 class TestProbeCpu:
59108
     def test_cpu_advisory_determinism(self) -> None:
60109
         with force_cpu():