Python · 4633 bytes Raw Blame History
1 """Hardware-capability mocks for doctor and planner tests.
2
3 Each context manager flips a consistent set of `torch.*` attributes so
4 code paths gated on `torch.cuda.is_available()`, `get_device_capability()`,
5 `mem_get_info()`, MPS availability, or `torch.version.hip` behave as if
6 the requested backend is present — without real hardware.
7
8 `torch` is imported inside each function so merely collecting the module
9 never touches torch state.
10 """
11
12 from __future__ import annotations
13
14 import contextlib
15 from collections.abc import Iterator
16 from unittest.mock import patch
17
18
@contextlib.contextmanager
def force_cuda(
    sm: tuple[int, int] = (8, 0),
    vram_gb: float = 24.0,
    device_name: str = "NVIDIA GeForce RTX 4090",
) -> Iterator[None]:
    """Pretend a CUDA GPU with `sm` compute capability and `vram_gb` free.

    `torch.cuda.mem_get_info()` returns (free, total) in bytes; we report
    the same value for both to make arithmetic simple.

    `torch.cuda.get_device_properties()` is patched as well, returning a
    namespace whose `name`, `major`/`minor` and `total_memory` agree with
    the other patched calls. Code that inspects device properties
    therefore never reaches the real driver.

    Args:
        sm: (major, minor) compute capability reported by
            `get_device_capability()` and by the properties namespace.
        vram_gb: Memory size in GiB, reported as both free and total.
        device_name: Value returned by `get_device_name()` and `props.name`.
    """
    from types import SimpleNamespace

    import torch

    free_bytes = int(vram_gb * (1024**3))
    total_bytes = free_bytes
    # Stand-in for torch's device-properties object; only the fields the
    # mock can keep consistent with the other patches are provided.
    device_props = SimpleNamespace(
        name=device_name,
        major=sm[0],
        minor=sm[1],
        total_memory=total_bytes,
    )
    patches = [
        patch.object(torch.cuda, "is_available", return_value=True),
        patch.object(torch.cuda, "device_count", return_value=1),
        patch.object(torch.cuda, "get_device_name", return_value=device_name),
        patch.object(torch.cuda, "get_device_capability", return_value=sm),
        patch.object(torch.cuda, "get_device_properties", return_value=device_props),
        patch.object(torch.cuda, "mem_get_info", return_value=(free_bytes, total_bytes)),
        # The ROCm marker must be unset (None) on a "real" CUDA box.
        patch.object(torch.version, "hip", None),
        patch.object(torch.backends.mps, "is_available", return_value=False),
        patch.object(torch.backends.mps, "is_built", return_value=False),
    ]
    # ExitStack unwinds every already-started patch if a later one fails,
    # and restores all of them when the caller's `with` block exits.
    with contextlib.ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        yield
49
50
@contextlib.contextmanager
def force_rocm(
    vram_gb: float = 16.0,
    device_name: str = "AMD Radeon RX 7900 XTX",
    hip_version: str = "6.0",
    sm: tuple[int, int] = (11, 0),  # HIP compute capability (RDNA3 ≈ 11.0.3)
    gcn_arch_name: str = "gfx1100",
) -> Iterator[None]:
    """Pretend a ROCm GPU. `torch.version.hip` is the distinguishing mark.

    `gcn_arch_name` (Sprint 22) is the AMD arch string — `gfx90a`
    (MI200), `gfx942` (MI300), `gfx1100` (RDNA3), `gfx1030` (RDNA2),
    etc. The bf16 + FlashAttention probes allowlist against this
    string rather than the CUDA-style SM tuple.

    The GPU is exposed through the `torch.cuda` API, as on a ROCm build
    of torch. `torch.cuda.get_device_properties()` returns a namespace
    whose `major`/`minor` match `sm`, so properties and
    `get_device_capability()` describe the same device.

    Args:
        vram_gb: Memory size in GiB, reported as both free and total.
        device_name: Value returned by `get_device_name()` and `props.name`.
        hip_version: Value installed as `torch.version.hip`.
        sm: (major, minor) capability reported by `get_device_capability()`.
        gcn_arch_name: Value exposed as `props.gcnArchName`.
    """
    from types import SimpleNamespace

    import torch

    free_bytes = int(vram_gb * (1024**3))
    total_bytes = free_bytes
    # Stand-in for torch's device-properties object, kept consistent with
    # the capability / memory / name patches below.
    device_props = SimpleNamespace(
        name=device_name,
        gcnArchName=gcn_arch_name,
        major=sm[0],
        minor=sm[1],
        total_memory=total_bytes,
    )
    patches = [
        patch.object(torch.cuda, "is_available", return_value=True),
        patch.object(torch.cuda, "device_count", return_value=1),
        patch.object(torch.cuda, "get_device_name", return_value=device_name),
        patch.object(torch.cuda, "get_device_capability", return_value=sm),
        patch.object(torch.cuda, "get_device_properties", return_value=device_props),
        patch.object(torch.cuda, "mem_get_info", return_value=(free_bytes, total_bytes)),
        patch.object(torch.version, "hip", hip_version),
        patch.object(torch.backends.mps, "is_available", return_value=False),
        patch.object(torch.backends.mps, "is_built", return_value=False),
    ]
    # ExitStack unwinds every already-started patch if a later one fails,
    # and restores all of them when the caller's `with` block exits.
    with contextlib.ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        yield
92
93
@contextlib.contextmanager
def force_mps() -> Iterator[None]:
    """Pretend Apple Silicon (MPS backend available, no CUDA).

    `torch.cuda.device_count()` is pinned to 0 so a CUDA-equipped test
    host cannot leak its real GPU count while `is_available()` reports
    False.
    """
    import torch

    patches = [
        patch.object(torch.cuda, "is_available", return_value=False),
        patch.object(torch.cuda, "device_count", return_value=0),
        patch.object(torch.version, "hip", None),
        patch.object(torch.backends.mps, "is_available", return_value=True),
        patch.object(torch.backends.mps, "is_built", return_value=True),
    ]
    # ExitStack restores every patch when the caller's `with` block exits.
    with contextlib.ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        yield
109
110
@contextlib.contextmanager
def force_cpu() -> Iterator[None]:
    """Pretend CPU-only (no CUDA, no ROCm, no MPS).

    `torch.cuda.device_count()` is pinned to 0 so a CUDA-equipped test
    host cannot leak its real GPU count while `is_available()` reports
    False.
    """
    import torch

    patches = [
        patch.object(torch.cuda, "is_available", return_value=False),
        patch.object(torch.cuda, "device_count", return_value=0),
        patch.object(torch.version, "hip", None),
        patch.object(torch.backends.mps, "is_available", return_value=False),
        patch.object(torch.backends.mps, "is_built", return_value=False),
    ]
    # ExitStack restores every patch when the caller's `with` block exits.
    with contextlib.ExitStack() as stack:
        for p in patches:
            stack.enter_context(p)
        yield