1 """Capabilities probe — per-backend fields, determinism class, telemetry posture."""
2
3 from __future__ import annotations
4
5 from types import SimpleNamespace
6 from unittest.mock import patch
7
8 import pytest
9
10 from dlm.hardware.backend import Backend
11 from dlm.hardware.capabilities import _accelerate_version, _rocm_arch_supports_bf16, probe
12 from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps, force_rocm
13
14
class TestProbeCuda:
    """Probe behaviour on NVIDIA hardware (mocked via force_cuda)."""

    def test_cuda_caps_populate_device_fields(self) -> None:
        # An Ada-class card (SM 8.9) should surface every device field.
        with force_cuda(sm=(8, 9), vram_gb=24.0, device_name="NVIDIA RTX 4090"):
            result = probe()
            assert result.backend == Backend.CUDA
            assert result.device_name == "NVIDIA RTX 4090"
            assert result.sm == (8, 9)
            assert result.vram_gb is not None
            assert 23.5 <= result.vram_gb <= 24.5
            # Discrete GPU: dedicated VRAM, no unified-memory figure.
            assert result.unified_memory_gb is None
            assert result.supports_bf16 is True
            assert result.supports_fp16 is True
            assert result.determinism_class == "strong"

    def test_cuda_pre_ampere_no_bf16(self) -> None:
        # Turing (SM 7.5) predates bf16; flash_attn is gated on SM >= 8.0
        # regardless of package availability.
        with force_cuda(sm=(7, 5)):
            result = probe()
            assert result.supports_bf16 is False
            assert result.has_flash_attention is False

    def test_cuda_sm_probe_failure_yields_unknown_sm(self) -> None:
        # A failing capability query degrades to sm=None, not an exception.
        with force_cuda(), patch(
            "torch.cuda.get_device_capability", side_effect=RuntimeError("boom")
        ):
            assert probe().sm is None

    def test_cuda_vram_probe_failure_yields_unknown_vram(self) -> None:
        # Likewise, a failing memory query degrades to vram_gb=None.
        with force_cuda(), patch(
            "torch.cuda.mem_get_info", side_effect=RuntimeError("boom")
        ):
            assert probe().vram_gb is None

    def test_cuda_flash_attention_true_when_package_and_sm_supported(self) -> None:
        # Reported only when the package exists *and* SM >= 8.0.
        with patch(
            "dlm.hardware.capabilities._module_available",
            lambda name: name == "flash_attn",
        ):
            with force_cuda(sm=(8, 0)):
                assert probe().has_flash_attention is True
55
56
class TestProbeRocm:
    """Probe behaviour on AMD hardware (mocked via force_rocm)."""

    def test_rocm_reports_hip_version(self) -> None:
        with force_rocm(hip_version="6.0"):
            result = probe()
            assert result.backend == Backend.ROCM
            assert result.rocm_version == "6.0"
            # HIP builds must not masquerade as CUDA.
            assert result.cuda_version is None
            assert result.determinism_class == "best-effort"
            assert result.has_flash_attention is False

    def test_rocm_arch_probe_failure_yields_unknown_arch(self) -> None:
        # A failing device-properties query degrades to rocm_arch=None.
        with force_rocm(), patch(
            "torch.cuda.get_device_properties", side_effect=RuntimeError("boom")
        ):
            assert probe().rocm_arch is None

    def test_rocm_arch_probe_missing_name_yields_unknown_arch(self) -> None:
        # A bare device name with no recognizable arch token also falls
        # back to rocm_arch=None.
        fake_props = SimpleNamespace(name="AMD")
        with force_rocm(), patch(
            "torch.cuda.get_device_properties", return_value=fake_props
        ):
            assert probe().rocm_arch is None
80
81
class TestProbeMps:
    """Probe behaviour on Apple Silicon (mocked via force_mps)."""

    def test_mps_caps(self) -> None:
        with force_mps():
            result = probe()
            assert result.backend == Backend.MPS
            # bf16 is reported conservatively as unsupported on MPS.
            assert result.supports_bf16 is False
            assert result.supports_fp16 is True
            # Apple Silicon exposes unified memory, not dedicated VRAM.
            assert result.unified_memory_gb is not None
            assert result.vram_gb is None
            assert result.determinism_class == "best-effort"
            assert result.has_flash_attention is False

    def test_mps_never_reports_flash_attention(self) -> None:
        # Even with the flash_attn package "installed", MPS never claims it.
        with patch(
            "dlm.hardware.capabilities._module_available",
            lambda name: name == "flash_attn",
        ):
            with force_mps():
                assert probe().has_flash_attention is False
101
102
class TestMlxAvailability:
    """has_mlx is an MPS-only capability requiring both mlx and mlx_lm."""

    def test_non_mps_never_reports_mlx(self) -> None:
        # Off-Apple hosts: has_mlx is False regardless of dist metadata
        # (Sprint 21). The probe won't consult importlib on CUDA/CPU.
        for ctx in (force_cuda(sm=(8, 0)), force_cpu()):
            with ctx:
                assert probe().has_mlx is False

    def test_mps_reports_mlx_when_both_modules_installed(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Pretend both mlx and mlx_lm are importable; everything else
        # falls through to the real availability check.
        from dlm.hardware import capabilities as caps_mod

        original = caps_mod._module_available
        monkeypatch.setattr(
            caps_mod,
            "_module_available",
            lambda name: True if name in ("mlx", "mlx_lm") else original(name),
        )
        with force_mps():
            assert probe().has_mlx is True

    def test_mps_reports_no_mlx_when_mlx_lm_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # mlx alone is not enough: mlx_lm must also be present.
        from dlm.hardware import capabilities as caps_mod

        original = caps_mod._module_available

        def fake_available(name: str) -> bool:
            if name == "mlx":
                return True
            if name == "mlx_lm":
                return False
            return original(name)

        monkeypatch.setattr(caps_mod, "_module_available", fake_available)
        with force_mps():
            assert probe().has_mlx is False
148
149
class TestProbeCpu:
    """Probe behaviour on a plain CPU host."""

    def test_cpu_advisory_determinism(self) -> None:
        with force_cpu():
            result = probe()
            assert result.backend == Backend.CPU
            # CPU offers only advisory determinism and no accelerator fields.
            assert result.determinism_class == "advisory"
            assert result.vram_gb is None
            assert result.unified_memory_gb is None
            assert result.supports_bf16 is False
            assert result.has_flash_attention is False
160
161
class TestTelemetryPosture:
    """Telemetry posture mirrors the relevant environment variables."""

    def test_reports_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
        for var in ("HF_HUB_DISABLE_TELEMETRY", "DO_NOT_TRACK"):
            monkeypatch.setenv(var, "1")
        with force_cpu():
            posture = probe().telemetry_posture
            assert posture["HF_HUB_DISABLE_TELEMETRY"] == "1"
            assert posture["DO_NOT_TRACK"] == "1"
            # Non-env diagnostics ride along in the same mapping.
            assert "wandb_installed" in posture
            assert "python" in posture

    def test_unset_vars_show_placeholder(self, monkeypatch: pytest.MonkeyPatch) -> None:
        for var in (
            "HF_HUB_DISABLE_TELEMETRY",
            "DO_NOT_TRACK",
            "TRANSFORMERS_NO_ADVISORY_WARNINGS",
        ):
            monkeypatch.delenv(var, raising=False)
        with force_cpu():
            posture = probe().telemetry_posture
            assert posture["HF_HUB_DISABLE_TELEMETRY"] == "<unset>"
            assert posture["DO_NOT_TRACK"] == "<unset>"
181
182
class TestCoverageEdges:
    """Direct unit coverage for small capability helpers."""

    def test_rocm_arch_none_is_not_bf16_capable(self) -> None:
        # An unknown architecture must never be assumed bf16-capable.
        assert _rocm_arch_supports_bf16(None) is False

    def test_accelerate_version_missing_returns_none(self) -> None:
        # An uninstalled accelerate dist resolves to None, not an error.
        from importlib.metadata import PackageNotFoundError

        with patch("importlib.metadata.version", side_effect=PackageNotFoundError):
            assert _accelerate_version() is None