Python · 2319 bytes Raw Blame History
1 """Verify hardware_mocks flip torch attributes as advertised."""
2
3 from __future__ import annotations
4
5 import torch
6
7 from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps, force_rocm
8
9
class TestForceCuda:
    """Checks on what torch reports while ``force_cuda`` is active."""

    def test_reports_cuda_available(self) -> None:
        """CUDA reads as available with the requested capability; MPS and HIP read as absent."""
        requested_sm = (8, 6)
        with force_cuda(sm=requested_sm, vram_gb=16.0):
            cuda_flag = torch.cuda.is_available()
            assert cuda_flag is True
            capability = torch.cuda.get_device_capability()
            assert capability == requested_sm
            mps_flag = torch.backends.mps.is_available()
            assert mps_flag is False
            hip_version = torch.version.hip
            assert hip_version is None

    def test_mem_get_info_reports_requested_vram(self) -> None:
        """``mem_get_info`` returns the requested VRAM, in bytes, as both free and total."""
        bytes_per_gib = 1024**3
        with force_cuda(vram_gb=12.0):
            mem_info = torch.cuda.mem_get_info()
            free_bytes, total_bytes = mem_info
            assert free_bytes == int(12.0 * bytes_per_gib)
            assert total_bytes == free_bytes

    def test_restores_state_after_context(self) -> None:
        """CUDA availability after the context matches what it was before entering."""
        # Snapshot the real value first so the check works on any host.
        before = torch.cuda.is_available()
        with force_cuda():
            inside = torch.cuda.is_available()
            assert inside is True
        after = torch.cuda.is_available()
        assert after is before
29
30
class TestForceMps:
    """Checks on what torch reports while ``force_mps`` is active."""

    def test_reports_mps_available(self) -> None:
        """MPS reads as available and built; CUDA and HIP read as absent."""
        with force_mps():
            mps_available = torch.backends.mps.is_available()
            assert mps_available is True
            mps_built = torch.backends.mps.is_built()
            assert mps_built is True
            cuda_available = torch.cuda.is_available()
            assert cuda_available is False
            hip_version = torch.version.hip
            assert hip_version is None
38
39
class TestForceCpu:
    """Checks on what torch reports while ``force_cpu`` is active."""

    def test_reports_nothing_available(self) -> None:
        """Neither CUDA nor MPS reads as available, and no HIP version is set."""
        with force_cpu():
            cuda_available = torch.cuda.is_available()
            assert cuda_available is False
            mps_available = torch.backends.mps.is_available()
            assert mps_available is False
            hip_version = torch.version.hip
            assert hip_version is None
46
47
class TestForceRocm:
    """Checks on what torch reports while ``force_rocm`` is active."""

    def test_reports_hip_version_and_no_mps(self) -> None:
        """The requested HIP version is exposed, CUDA reads as available, MPS does not."""
        with force_rocm(hip_version="6.0"):
            # HIP impersonates CUDA on ROCm, so torch.cuda.is_available()
            # is expected to be True here.
            cuda_available = torch.cuda.is_available()
            assert cuda_available is True
            hip_version = torch.version.hip
            assert hip_version == "6.0"
            mps_available = torch.backends.mps.is_available()
            assert mps_available is False
55
56
class TestNesting:
    """Checks how the mock contexts behave when one is nested inside another."""

    def test_inner_context_overrides_outer(self) -> None:
        """The inner context wins while active; the outer one applies again after it exits."""
        with force_cpu():
            outer_before = torch.cuda.is_available()
            assert outer_before is False

            with force_cuda():
                inner = torch.cuda.is_available()
                assert inner is True

            # The inner context has exited, so the outer force_cpu state
            # should be in effect again.
            outer_after = torch.cuda.is_available()
            assert outer_after is False