| 1 |
"""Backend selection: auto / pytorch / mlx routing + refusal modes.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
from unittest.mock import MagicMock, patch |
| 6 |
|
| 7 |
import pytest |
| 8 |
|
| 9 |
from dlm.inference.backends.select import ( |
| 10 |
UnsupportedBackendError, |
| 11 |
build_backend, |
| 12 |
is_apple_silicon, |
| 13 |
select_backend, |
| 14 |
) |
| 15 |
|
| 16 |
|
| 17 |
@pytest.fixture(autouse=True)
def _reset_mlx_state() -> None:
    """Audit-08 N9: autouse anchor for the "no leaked mlx patch" contract.

    This fixture performs no setup and no teardown of its own. Isolation
    comes from the tests in this module, which apply their
    `mlx_available` / `is_apple_silicon` patches inside `with` blocks so
    the patches are undone on exit. The hook is kept autouse so the
    guarantee is written down in one place, and so there is an obvious
    home for real reset logic should a future test patch without a
    context manager.
    """
    return None
| 27 |
|
| 28 |
|
| 29 |
class TestSelectBackendExplicit:
    """Explicit `pytorch` / `mlx` requests, including both refusal paths."""

    def test_pytorch_always_selected(self) -> None:
        """An explicit pytorch request must not consult the mlx probe."""
        with patch("dlm.inference.backends.select.mlx_available") as mlx_probe:
            chosen = select_backend("pytorch")

        assert chosen == "pytorch"
        mlx_probe.assert_not_called()

    def test_mlx_on_apple_silicon_with_mlx_installed(self) -> None:
        """mlx is honoured when the host check and the mlx probe both pass."""
        on_apple = patch("dlm.inference.backends.select.is_apple_silicon", return_value=True)
        mlx_present = patch("dlm.inference.backends.select.mlx_available", return_value=True)

        with on_apple, mlx_present:
            assert select_backend("mlx") == "mlx"

    def test_mlx_off_platform_raises_clear_error(self) -> None:
        """Requesting mlx on a non-Apple-Silicon host is refused by name."""
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=False),
            pytest.raises(UnsupportedBackendError, match="Apple Silicon"),
        ):
            select_backend("mlx")

    def test_mlx_apple_silicon_without_extra_raises(self) -> None:
        """Right host but mlx unavailable: the error mentions the mlx extra."""
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=True),
            patch("dlm.inference.backends.select.mlx_available", return_value=False),
            pytest.raises(UnsupportedBackendError, match="mlx extra"),
        ):
            select_backend("mlx")
| 55 |
|
| 56 |
|
| 57 |
class TestSelectBackendAuto:
    """`auto` routing: mlx when the probe succeeds, pytorch otherwise."""

    def test_auto_picks_mlx_when_available(self) -> None:
        """A passing mlx probe makes `auto` resolve to mlx."""
        with patch("dlm.inference.backends.select.mlx_available", return_value=True):
            assert select_backend("auto") == "mlx"

    def test_auto_falls_back_to_pytorch_when_mlx_absent(self) -> None:
        """A failing mlx probe makes `auto` resolve to pytorch."""
        with patch("dlm.inference.backends.select.mlx_available", return_value=False):
            assert select_backend("auto") == "pytorch"

    def test_auto_on_non_darwin_never_imports_mlx(self) -> None:
        """Off Apple Silicon, `auto` lands on pytorch without probing mlx.

        `mlx_available()` is expected to short-circuit when
        `is_apple_silicon()` is False. Besides checking the selected
        backend, this spies on `importlib.util.find_spec` (the lookup the
        select module uses) and asserts that neither `mlx` nor `mlx_lm`
        was ever looked up, so the test verifies what its name promises.
        """
        from dlm.inference.backends import select as sel

        # `wraps=` keeps the real lookup working for any unrelated probe,
        # so the spy observes calls without changing selection behaviour.
        real_find_spec = sel.importlib.util.find_spec
        with (
            patch("dlm.inference.backends.select.is_apple_silicon", return_value=False),
            patch.object(sel.importlib.util, "find_spec", wraps=real_find_spec) as m_find,
        ):
            assert select_backend("auto") == "pytorch"

        # The module name may be passed positionally or as `name=`.
        probed = [
            call.args[0] if call.args else call.kwargs.get("name")
            for call in m_find.call_args_list
        ]
        mlx_probes = [
            name for name in probed if str(name).split(".")[0] in {"mlx", "mlx_lm"}
        ]
        assert mlx_probes == []
| 72 |
|
| 73 |
|
| 74 |
class TestBuildBackend:
    """`build_backend` maps a backend name to its concrete backend class."""

    def test_pytorch_returns_pytorch_backend(self) -> None:
        """The name "pytorch" yields a `PyTorchBackend` instance."""
        from dlm.inference.backends.pytorch_backend import PyTorchBackend

        stand_in = MagicMock()
        built = build_backend("pytorch", stand_in)
        assert isinstance(built, PyTorchBackend)

    def test_mlx_returns_mlx_backend(self) -> None:
        """The name "mlx" yields an `MlxBackend` instance."""
        from dlm.inference.backends.mlx_backend import MlxBackend

        stand_in = MagicMock()
        built = build_backend("mlx", stand_in)
        assert isinstance(built, MlxBackend)

    def test_unknown_backend_raises(self) -> None:
        """An unrecognised name is rejected with a `ValueError`."""
        stand_in = MagicMock()
        with pytest.raises(ValueError, match="unknown backend"):
            build_backend("haskell", stand_in)  # type: ignore[arg-type]
| 90 |
|
| 91 |
|
| 92 |
class TestMlxAvailableDoesNotImportMlx:
    """`mlx_available` probes via `find_spec`, and only on Apple Silicon."""

    def test_mlx_available_off_platform_short_circuits(self) -> None:
        """Off Apple Silicon the probe returns False without any lookup."""
        from dlm.inference.backends import select as sel

        # The early return must fire before importlib.util.find_spec is
        # ever reached, so the spy should record zero calls.
        with (
            patch.object(sel, "is_apple_silicon", return_value=False),
            patch.object(sel.importlib.util, "find_spec") as spec_lookup,
        ):
            available = sel.mlx_available()

        assert available is False
        spec_lookup.assert_not_called()

    def test_mlx_available_checks_both_packages_on_apple_silicon(self) -> None:
        """On Apple Silicon two lookups happen, and two hits mean True."""
        from dlm.inference.backends import select as sel

        with (
            patch.object(sel, "is_apple_silicon", return_value=True),
            patch.object(
                sel.importlib.util, "find_spec", side_effect=[object(), object()]
            ) as spec_lookup,
        ):
            available = sel.mlx_available()

        assert available is True
        assert spec_lookup.call_count == 2
| 116 |
|
| 117 |
|
| 118 |
class TestPlatformHelper:
    """`is_apple_silicon` against patched `sys.platform` / `platform.machine`."""

    def test_is_apple_silicon_true_only_for_darwin_arm64(self) -> None:
        """A darwin platform with an arm64 machine is reported as Apple Silicon."""
        as_darwin = patch("dlm.inference.backends.select.sys.platform", "darwin")
        as_arm64 = patch("dlm.inference.backends.select.platform.machine", return_value="arm64")

        with as_darwin, as_arm64:
            detected = is_apple_silicon()

        assert detected is True

    def test_is_apple_silicon_false_for_other_hosts(self) -> None:
        """A linux platform with an x86_64 machine is not Apple Silicon."""
        as_linux = patch("dlm.inference.backends.select.sys.platform", "linux")
        as_x86_64 = patch("dlm.inference.backends.select.platform.machine", return_value="x86_64")

        with as_linux, as_x86_64:
            detected = is_apple_silicon()

        assert detected is False