| 1 |
"""Verify hardware_mocks flip torch attributes as advertised.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
import torch |
| 6 |
|
| 7 |
from tests.fixtures.hardware_mocks import force_cpu, force_cuda, force_mps, force_rocm |
| 8 |
|
| 9 |
|
| 10 |
class TestForceCuda:
    """force_cuda should make torch report exactly one CUDA device."""

    def test_reports_cuda_available(self) -> None:
        """Inside the context, CUDA is visible with the requested capability
        while MPS and ROCm stay hidden."""
        with force_cuda(sm=(8, 6), vram_gb=16.0):
            assert torch.cuda.is_available() is True
            assert torch.cuda.get_device_capability() == (8, 6)
            assert torch.backends.mps.is_available() is False
            assert torch.version.hip is None

    def test_mem_get_info_reports_requested_vram(self) -> None:
        """mem_get_info mirrors the vram_gb argument, with free == total."""
        with force_cuda(vram_gb=12.0):
            free_bytes, total_bytes = torch.cuda.mem_get_info()
            expected = int(12.0 * (1024**3))
            assert free_bytes == expected
            assert total_bytes == free_bytes

    def test_restores_state_after_context(self) -> None:
        """Exiting the context returns cuda.is_available to its prior value."""
        before = torch.cuda.is_available()
        with force_cuda():
            assert torch.cuda.is_available() is True
        assert torch.cuda.is_available() is before
| 29 |
|
| 30 |
|
| 31 |
class TestForceMps:
    """force_mps should expose an MPS backend and suppress CUDA/ROCm."""

    def test_reports_mps_available(self) -> None:
        """MPS reads as both available and built; CUDA and HIP are absent."""
        with force_mps():
            assert torch.backends.mps.is_available() is True
            assert torch.backends.mps.is_built() is True
            assert torch.cuda.is_available() is False
            assert torch.version.hip is None
| 38 |
|
| 39 |
|
| 40 |
class TestForceCpu:
    """force_cpu should hide every accelerator backend."""

    def test_reports_nothing_available(self) -> None:
        """Neither CUDA, MPS, nor HIP is visible inside the context."""
        with force_cpu():
            assert torch.cuda.is_available() is False
            assert torch.backends.mps.is_available() is False
            assert torch.version.hip is None
| 46 |
|
| 47 |
|
| 48 |
class TestForceRocm:
    """force_rocm should present a HIP runtime behind the cuda namespace."""

    def test_reports_hip_version_and_no_mps(self) -> None:
        """ROCm impersonates CUDA, so cuda.is_available() is True while
        torch.version.hip carries the requested version and MPS stays off."""
        with force_rocm(hip_version="6.0"):
            assert torch.cuda.is_available() is True
            assert torch.version.hip == "6.0"
            assert torch.backends.mps.is_available() is False
| 55 |
|
| 56 |
|
| 57 |
class TestNesting:
    """Nested mock contexts should stack: inner wins, outer comes back."""

    def test_inner_context_overrides_outer(self) -> None:
        """An inner force_cuda shadows an outer force_cpu, and the outer
        state is restored as soon as the inner context exits."""
        with force_cpu():
            assert torch.cuda.is_available() is False
            with force_cuda():
                assert torch.cuda.is_available() is True
            # Inner context gone; the outer force_cpu state applies again.
            assert torch.cuda.is_available() is False