Python · 5257 bytes Raw Blame History
1 """Backend selection: auto / pytorch / mlx routing + refusal modes."""
2
3 from __future__ import annotations
4
5 from unittest.mock import MagicMock, patch
6
7 import pytest
8
9 from dlm.inference.backends.select import (
10 UnsupportedBackendError,
11 build_backend,
12 is_apple_silicon,
13 select_backend,
14 )
15
16
@pytest.fixture(autouse=True)
def _reset_mlx_state(request: pytest.FixtureRequest) -> None:
    """Audit-08 N9: ensure no test's `mlx_available` patch leaks into the next.

    Tests in this module patch `mlx_available` / `is_apple_silicon` inside
    `with` blocks, which undo themselves on exit. The remaining leak path is
    a patch activated with `patch(...).start()` and never stopped. This
    fixture registers `patch.stopall` as a per-test finalizer, so any such
    patch is undone when the test finishes. That makes the no-leak guarantee
    enforced rather than merely documented.
    """
    # `patch.stopall` only stops patches activated via `.start()`;
    # context-manager patches have already been undone and are unaffected.
    request.addfinalizer(patch.stopall)
27
28
class TestSelectBackendExplicit:
    """An explicitly named backend is honoured, or refused with a clear error."""

    def test_pytorch_always_selected(self) -> None:
        # Naming pytorch outright must not even consult the mlx probe.
        with patch("dlm.inference.backends.select.mlx_available") as probe:
            chosen = select_backend("pytorch")
        assert chosen == "pytorch"
        probe.assert_not_called()

    def test_mlx_on_apple_silicon_with_mlx_installed(self) -> None:
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=True),
            patch("dlm.inference.backends.select.mlx_available", return_value=True),
        ):
            chosen = select_backend("mlx")
        assert chosen == "mlx"

    def test_mlx_off_platform_raises_clear_error(self) -> None:
        # Non-Apple host: the refusal must mention the platform requirement.
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=False),
            pytest.raises(UnsupportedBackendError, match="Apple Silicon"),
        ):
            select_backend("mlx")

    def test_mlx_apple_silicon_without_extra_raises(self) -> None:
        # Right hardware but the packages are missing: the refusal must
        # point at the install extra instead.
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=True),
            patch("dlm.inference.backends.select.mlx_available", return_value=False),
            pytest.raises(UnsupportedBackendError, match="mlx extra"),
        ):
            select_backend("mlx")
55
56
class TestSelectBackendAuto:
    """`auto` picks mlx when it is usable and otherwise falls back to pytorch."""

    def test_auto_picks_mlx_when_available(self) -> None:
        with patch("dlm.inference.backends.select.mlx_available", return_value=True):
            assert select_backend("auto") == "mlx"

    def test_auto_falls_back_to_pytorch_when_mlx_absent(self) -> None:
        with patch("dlm.inference.backends.select.mlx_available", return_value=False):
            assert select_backend("auto") == "pytorch"

    def test_auto_on_non_darwin_never_imports_mlx(self) -> None:
        # `mlx_available()` short-circuits on is_apple_silicon=False,
        # so auto on Linux gracefully lands on pytorch without probing
        # the mlx/mlx_lm modules. Spy on `find_spec` so the "without
        # probing" half of that claim is asserted, not just the result.
        from dlm.inference.backends import select as sel

        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=False),
            patch.object(sel.importlib.util, "find_spec") as m_find,
        ):
            assert select_backend("auto") == "pytorch"
        m_find.assert_not_called()
72
73
class TestBuildBackend:
    """`build_backend` maps a backend name onto its implementation class."""

    def test_pytorch_returns_pytorch_backend(self) -> None:
        from dlm.inference.backends.pytorch_backend import PyTorchBackend

        assert isinstance(build_backend("pytorch", MagicMock()), PyTorchBackend)

    def test_mlx_returns_mlx_backend(self) -> None:
        from dlm.inference.backends.mlx_backend import MlxBackend

        assert isinstance(build_backend("mlx", MagicMock()), MlxBackend)

    def test_unknown_backend_raises(self) -> None:
        # Names outside the supported set are rejected with ValueError.
        placeholder = MagicMock()
        with pytest.raises(ValueError, match="unknown backend"):
            build_backend("haskell", placeholder)  # type: ignore[arg-type]
90
91
class TestMlxAvailableDoesNotImportMlx:
    """`mlx_available` probes packages via `find_spec`, never by importing them."""

    def test_mlx_available_off_platform_short_circuits(self) -> None:
        # On non-darwin, mlx_available returns False without calling
        # importlib.util.find_spec — guaranteed by the early return.
        from dlm.inference.backends import select as sel

        with (
            patch.object(sel, "is_apple_silicon", return_value=False),
            patch.object(sel.importlib.util, "find_spec") as m_find,
        ):
            assert sel.mlx_available() is False
            m_find.assert_not_called()

    def test_mlx_available_checks_both_packages_on_apple_silicon(self) -> None:
        from dlm.inference.backends import select as sel

        with (
            patch.object(sel, "is_apple_silicon", return_value=True),
            patch.object(
                sel.importlib.util, "find_spec", side_effect=[object(), object()]
            ) as m_find,
        ):
            assert sel.mlx_available() is True
            assert m_find.call_count == 2

    @pytest.mark.parametrize(
        "specs",
        [(object(), None), (None, object()), (None, None)],
        ids=["second-missing", "first-missing", "both-missing"],
    )
    def test_mlx_available_false_when_any_package_missing(
        self, specs: tuple[object | None, object | None]
    ) -> None:
        # A partial install (only one of the two probed packages resolvable)
        # must not count as available. A `None` spec means "not installed".
        # Short-circuiting implementations may stop after the first `None`,
        # so the side_effect list is never over-consumed.
        from dlm.inference.backends import select as sel

        with (
            patch.object(sel, "is_apple_silicon", return_value=True),
            patch.object(sel.importlib.util, "find_spec", side_effect=list(specs)),
        ):
            assert sel.mlx_available() is False
116
117
class TestPlatformHelper:
    """`is_apple_silicon` must require BOTH a darwin platform and an arm64 CPU."""

    def test_is_apple_silicon_true_only_for_darwin_arm64(self) -> None:
        with (
            patch("dlm.inference.backends.select.sys.platform", "darwin"),
            patch("dlm.inference.backends.select.platform.machine", return_value="arm64"),
        ):
            assert is_apple_silicon() is True

    @pytest.mark.parametrize(
        ("sys_platform", "machine"),
        [
            ("linux", "x86_64"),
            # linux/x86_64 fails both halves of the check, so on its own it
            # cannot catch a regression that drops one of them. Each half
            # must also be insufficient in isolation:
            ("darwin", "x86_64"),  # Intel Mac (or Rosetta-translated Python)
            ("linux", "arm64"),
            ("linux", "aarch64"),
        ],
    )
    def test_is_apple_silicon_false_for_other_hosts(
        self, sys_platform: str, machine: str
    ) -> None:
        with (
            patch("dlm.inference.backends.select.sys.platform", sys_platform),
            patch("dlm.inference.backends.select.platform.machine", return_value=machine),
        ):
            assert is_apple_silicon() is False
131 assert is_apple_silicon() is False