Python · 3357 bytes Raw Blame History
1 """`--gpus` flag parsing + resolution."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from dlm.train.distributed.gpus import GpuSpec, UnsupportedGpuSpecError, parse_gpus
8
9
class TestParseGpus:
    """`parse_gpus` turns the raw `--gpus` string into a `GpuSpec`.

    Multi-input cases are parametrized so each input is reported as its own
    test; a failure on one value no longer masks the others.
    """

    def test_none_raises_empty(self) -> None:
        # `None` is outside the annotated argument type (hence the ignore),
        # but it must still be rejected with the same "empty" error.
        with pytest.raises(UnsupportedGpuSpecError, match="empty"):
            parse_gpus(None)  # type: ignore[arg-type]

    @pytest.mark.parametrize("value", ["all", "ALL", "All"])
    def test_all_case_insensitive(self, value: str) -> None:
        """Every casing of `all` parses to the value-less `all` spec."""
        assert parse_gpus(value) == GpuSpec(kind="all", value=None)

    @pytest.mark.parametrize(
        ("raw", "expected"),
        [("2", 2), (" 4 ", 4)],
        ids=["plain", "padded"],
    )
    def test_integer_count(self, raw: str, expected: int) -> None:
        """A bare integer (surrounding whitespace allowed) is a device count."""
        assert parse_gpus(raw) == GpuSpec(kind="count", value=expected)

    @pytest.mark.parametrize(
        ("raw", "expected"),
        [("0,1", (0, 1)), ("0, 1, 3", (0, 1, 3))],
        ids=["compact", "spaced"],
    )
    def test_comma_list(self, raw: str, expected: tuple[int, ...]) -> None:
        """A comma-separated list parses to a tuple of explicit device ids."""
        assert parse_gpus(raw) == GpuSpec(kind="list", value=expected)

    @pytest.mark.parametrize("raw", ["", " "], ids=["empty", "whitespace"])
    def test_empty_raises(self, raw: str) -> None:
        """Empty or whitespace-only input is rejected as empty."""
        with pytest.raises(UnsupportedGpuSpecError, match="empty"):
            parse_gpus(raw)

    def test_negative_list_rejected(self) -> None:
        with pytest.raises(UnsupportedGpuSpecError, match="negative"):
            parse_gpus("0,-1")

    def test_non_integer_list_rejected(self) -> None:
        with pytest.raises(UnsupportedGpuSpecError, match="non-integer"):
            parse_gpus("0,foo,1")

    def test_empty_comma_list_rejected(self) -> None:
        # Only separators and blanks: the list form has no ids at all.
        with pytest.raises(UnsupportedGpuSpecError, match="is empty"):
            parse_gpus(", ,")

    def test_malformed_scalar_rejected(self) -> None:
        # A single token that is neither `all` nor an integer.
        with pytest.raises(UnsupportedGpuSpecError, match="not `all`"):
            parse_gpus("xyz")
49
50
class TestResolveGpuSpec:
    """`GpuSpec.resolve` maps a parsed spec onto concrete device ids.

    Happy paths check the returned id tuple; failure paths check that
    `UnsupportedGpuSpecError` is raised with the expected message fragment.
    """

    def test_list_returns_requested_ids(self) -> None:
        requested = (0, 2)
        explicit = GpuSpec(kind="list", value=requested)
        assert explicit.resolve(device_count=4) == requested

    def test_all_returns_full_range(self) -> None:
        everything = GpuSpec(kind="all", value=None)
        assert everything.resolve(device_count=3) == tuple(range(3))

    def test_count_returns_prefix(self) -> None:
        # A count selects the first N visible ids.
        first_two = GpuSpec(kind="count", value=2)
        assert first_two.resolve(device_count=4) == tuple(range(2))

    def test_count_exceeding_visible_raises(self) -> None:
        too_many = GpuSpec(kind="count", value=4)
        with pytest.raises(UnsupportedGpuSpecError, match="exceeds"):
            too_many.resolve(device_count=2)

    def test_count_zero_rejected(self) -> None:
        nothing = GpuSpec(kind="count", value=0)
        with pytest.raises(UnsupportedGpuSpecError, match=">= 1"):
            nothing.resolve(device_count=4)

    def test_list_out_of_range_raises(self) -> None:
        # Id 5 does not exist when only two devices are visible.
        beyond = GpuSpec(kind="list", value=(0, 5))
        with pytest.raises(UnsupportedGpuSpecError, match="out-of-range"):
            beyond.resolve(device_count=2)

    def test_list_duplicate_raises(self) -> None:
        repeated = GpuSpec(kind="list", value=(0, 1, 1))
        with pytest.raises(UnsupportedGpuSpecError, match="duplicate"):
            repeated.resolve(device_count=4)

    def test_no_visible_devices_raises(self) -> None:
        everything = GpuSpec(kind="all", value=None)
        with pytest.raises(UnsupportedGpuSpecError, match="at least 1"):
            everything.resolve(device_count=0)