"""Capabilities probe — per-backend fields, determinism class, telemetry posture."""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import patch

import pytest

from dlm.hardware.backend import Backend
from dlm.hardware.capabilities import _accelerate_version, _rocm_arch_supports_bf16, probe
from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps, force_rocm
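
# The hardware_mocks fixtures patch torch's device discovery so probe() sees a
# synthetic host. As used in this file (all arguments appear to be optional):
# force_cuda(sm=..., vram_gb=..., device_name=...), force_rocm(hip_version=...),
# force_mps(), and force_cpu().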


class TestProbeCuda:
    def test_cuda_caps_populate_device_fields(self) -> None:
        with force_cuda(sm=(8, 9), vram_gb=24.0, device_name="NVIDIA RTX 4090"):
            caps = probe()
            assert caps.backend == Backend.CUDA
            assert caps.device_name == "NVIDIA RTX 4090"
            assert caps.sm == (8, 9)
            assert caps.vram_gb is not None
            assert 23.5 <= caps.vram_gb <= 24.5
            assert caps.unified_memory_gb is None
            assert caps.supports_bf16 is True
            assert caps.supports_fp16 is True
            assert caps.determinism_class == "strong"

    def test_cuda_pre_ampere_no_bf16(self) -> None:
        with force_cuda(sm=(7, 5)):
            caps = probe()
            assert caps.supports_bf16 is False
            # flash_attn gated on SM>=8.0 regardless of package availability
            assert caps.has_flash_attention is False

    def test_cuda_sm_probe_failure_yields_unknown_sm(self) -> None:
        with force_cuda():
            with patch("torch.cuda.get_device_capability", side_effect=RuntimeError("boom")):
                caps = probe()
                assert caps.sm is None

    def test_cuda_vram_probe_failure_yields_unknown_vram(self) -> None:
        with force_cuda():
            with patch("torch.cuda.mem_get_info", side_effect=RuntimeError("boom")):
                caps = probe()
                assert caps.vram_gb is None

    def test_cuda_flash_attention_true_when_package_and_sm_supported(self) -> None:
        with (
            patch("dlm.hardware.capabilities._module_available", lambda name: name == "flash_attn"),
            force_cuda(sm=(8, 0)),
        ):
            caps = probe()
            assert caps.has_flash_attention is True
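

# A natural companion to the test above (a sketch under the same assumption the
# class encodes, namely that has_flash_attention requires both the flash_attn
# package and SM >= 8.0): with the package patched away, a supported SM alone
# should not enable the capability.
def test_cuda_flash_attention_false_when_package_missing() -> None:
    with (
        patch("dlm.hardware.capabilities._module_available", lambda name: False),
        force_cuda(sm=(8, 0)),
    ):
        caps = probe()
        assert caps.has_flash_attention is False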


class TestProbeRocm:
    def test_rocm_reports_hip_version(self) -> None:
        with force_rocm(hip_version="6.0"):
            caps = probe()
            assert caps.backend == Backend.ROCM
            assert caps.rocm_version == "6.0"
            assert caps.cuda_version is None
            assert caps.determinism_class == "best-effort"
            assert caps.has_flash_attention is False

    def test_rocm_arch_probe_failure_yields_unknown_arch(self) -> None:
        with force_rocm():
            with patch("torch.cuda.get_device_properties", side_effect=RuntimeError("boom")):
                caps = probe()
                assert caps.rocm_arch is None

    def test_rocm_arch_probe_missing_name_yields_unknown_arch(self) -> None:
        with force_rocm():
            with patch(
                "torch.cuda.get_device_properties", return_value=SimpleNamespace(name="AMD")
            ):
                caps = probe()
                assert caps.rocm_arch is None
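

# The two arch tests above imply rocm_arch is recovered from the device
# properties' name (a "gfx" token); a bare "AMD" carries none, hence None.
# A hypothetical positive-path sketch; the exact name format and parsing rule
# are assumptions, not confirmed by this file:
#
#     with force_rocm(), patch(
#         "torch.cuda.get_device_properties",
#         return_value=SimpleNamespace(name="AMD Instinct MI250X gfx90a"),
#     ):
#         caps = probe()
#         # expected: caps.rocm_arch == "gfx90a" under the assumed parser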


class TestProbeMps:
    def test_mps_caps(self) -> None:
        with force_mps():
            caps = probe()
            assert caps.backend == Backend.MPS
            assert caps.supports_bf16 is False  # conservative default
            assert caps.supports_fp16 is True
            assert caps.unified_memory_gb is not None
            assert caps.vram_gb is None
            assert caps.determinism_class == "best-effort"
            assert caps.has_flash_attention is False

    def test_mps_never_reports_flash_attention(self) -> None:
        with (
            patch("dlm.hardware.capabilities._module_available", lambda name: name == "flash_attn"),
            force_mps(),
        ):
            caps = probe()
            assert caps.has_flash_attention is False
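

# A plausible source for unified_memory_gb on Apple silicon (a sketch only;
# the actual probe may read sysctl hw.memsize or use another mechanism):
#
#     import psutil
#     unified_memory_gb = psutil.virtual_memory().total / 1024**3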


class TestMlxAvailability:
    def test_non_mps_never_reports_mlx(self) -> None:
        # Off-Apple hosts: has_mlx is False regardless of dist metadata
        # (Sprint 21). The probe won't consult importlib on CUDA/CPU.
        with force_cuda(sm=(8, 0)):
            caps = probe()
            assert caps.has_mlx is False
        with force_cpu():
            caps = probe()
            assert caps.has_mlx is False

    def test_mps_reports_mlx_when_both_modules_installed(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Simulate both mlx + mlx_lm available.
        from dlm.hardware import capabilities as caps_mod

        real_avail = caps_mod._module_available

        def fake_available(name: str) -> bool:
            if name in ("mlx", "mlx_lm"):
                return True
            return real_avail(name)

        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
        with force_mps():
            caps = probe()
            assert caps.has_mlx is True

    def test_mps_reports_no_mlx_when_mlx_lm_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        from dlm.hardware import capabilities as caps_mod

        real_avail = caps_mod._module_available

        def fake_available(name: str) -> bool:
            if name == "mlx":
                return True
            if name == "mlx_lm":
                return False
            return real_avail(name)

        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
        with force_mps():
            caps = probe()
            assert caps.has_mlx is False
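

# For context, a plausible shape of the _module_available hook patched in the
# two tests above (a sketch, not necessarily the dlm implementation):
#
#     import importlib.util
#
#     def _module_available(name: str) -> bool:
#         return importlib.util.find_spec(name) is not None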


class TestProbeCpu:
    def test_cpu_advisory_determinism(self) -> None:
        with force_cpu():
            caps = probe()
            assert caps.backend == Backend.CPU
            assert caps.determinism_class == "advisory"
            assert caps.vram_gb is None
            assert caps.unified_memory_gb is None
            assert caps.supports_bf16 is False
            assert caps.has_flash_attention is False


class TestTelemetryPosture:
    def test_reports_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("HF_HUB_DISABLE_TELEMETRY", "1")
        monkeypatch.setenv("DO_NOT_TRACK", "1")
        with force_cpu():
            caps = probe()
            assert caps.telemetry_posture["HF_HUB_DISABLE_TELEMETRY"] == "1"
            assert caps.telemetry_posture["DO_NOT_TRACK"] == "1"
            assert "wandb_installed" in caps.telemetry_posture
            assert "python" in caps.telemetry_posture

    def test_unset_vars_show_placeholder(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.delenv("HF_HUB_DISABLE_TELEMETRY", raising=False)
        monkeypatch.delenv("DO_NOT_TRACK", raising=False)
        monkeypatch.delenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", raising=False)
        with force_cpu():
            caps = probe()
            assert caps.telemetry_posture["HF_HUB_DISABLE_TELEMETRY"] == "<unset>"
            assert caps.telemetry_posture["DO_NOT_TRACK"] == "<unset>"
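

# Shape of the telemetry_posture mapping as exercised above (values below are
# illustrative; "<unset>" is the placeholder for absent environment variables,
# and the value formats of the informational keys are assumptions):
#
#     {
#         "HF_HUB_DISABLE_TELEMETRY": "1",
#         "DO_NOT_TRACK": "<unset>",
#         "TRANSFORMERS_NO_ADVISORY_WARNINGS": "<unset>",
#         "wandb_installed": "False",
#         "python": "3.12.1",
#     }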


class TestCoverageEdges:
    def test_rocm_arch_none_is_not_bf16_capable(self) -> None:
        assert _rocm_arch_supports_bf16(None) is False

    def test_accelerate_version_missing_returns_none(self) -> None:
        from importlib.metadata import PackageNotFoundError

        with patch("importlib.metadata.version", side_effect=PackageNotFoundError):
            assert _accelerate_version() is None
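

# A plausible shape of the helper exercised above (a sketch consistent with
# the patch target; the real dlm implementation may differ):
#
#     import importlib.metadata
#
#     def _accelerate_version() -> str | None:
#         try:
#             return importlib.metadata.version("accelerate")
#         except importlib.metadata.PackageNotFoundError:
#             return None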