| 1 |
"""`--gpus` flag parsing + resolution.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
import pytest |
| 6 |
|
| 7 |
from dlm.train.distributed.gpus import GpuSpec, UnsupportedGpuSpecError, parse_gpus |
| 8 |
|
| 9 |
|
| 10 |
class TestParseGpus:
    """`parse_gpus` turns the raw `--gpus` string into a `GpuSpec` or rejects it."""

    def test_none_raises_empty(self) -> None:
        # `None` is outside the annotated argument type (hence the ignore), but it
        # must still be rejected with the same "empty" error as a blank string.
        with pytest.raises(UnsupportedGpuSpecError, match="empty"):
            parse_gpus(None)  # type: ignore[arg-type]

    @pytest.mark.parametrize("value", ["all", "ALL", "All"])
    def test_all_case_insensitive(self, value: str) -> None:
        # Parametrized instead of looped so that a failure names the exact
        # spelling and every spelling is still checked after one fails.
        assert parse_gpus(value) == GpuSpec(kind="all", value=None)

    @pytest.mark.parametrize(("raw", "expected"), [("2", 2), (" 4 ", 4)])
    def test_integer_count(self, raw: str, expected: int) -> None:
        # A bare integer is a device count; surrounding whitespace is tolerated.
        assert parse_gpus(raw) == GpuSpec(kind="count", value=expected)

    @pytest.mark.parametrize(
        ("raw", "expected"),
        [("0,1", (0, 1)), ("0, 1, 3", (0, 1, 3))],
    )
    def test_comma_list(self, raw: str, expected: tuple[int, ...]) -> None:
        # Comma-separated ids become an ordered tuple; spaces after commas are fine.
        assert parse_gpus(raw) == GpuSpec(kind="list", value=expected)

    @pytest.mark.parametrize("raw", ["", " "])
    def test_empty_raises(self, raw: str) -> None:
        # Both the empty string and whitespace-only input count as "empty".
        with pytest.raises(UnsupportedGpuSpecError, match="empty"):
            parse_gpus(raw)

    def test_negative_list_rejected(self) -> None:
        # Negative ids are rejected at parse time, before any device lookup.
        with pytest.raises(UnsupportedGpuSpecError, match="negative"):
            parse_gpus("0,-1")

    def test_non_integer_list_rejected(self) -> None:
        # A single non-numeric entry poisons the whole list.
        with pytest.raises(UnsupportedGpuSpecError, match="non-integer"):
            parse_gpus("0,foo,1")

    def test_empty_comma_list_rejected(self) -> None:
        # Only separators and whitespace: a list shape with no ids in it.
        with pytest.raises(UnsupportedGpuSpecError, match="is empty"):
            parse_gpus(", ,")

    def test_malformed_scalar_rejected(self) -> None:
        # A scalar that is neither `all` nor an integer is rejected.
        with pytest.raises(UnsupportedGpuSpecError, match="not `all`"):
            parse_gpus("xyz")
| 49 |
|
| 50 |
|
| 51 |
class TestResolveGpuSpec:
    """`GpuSpec.resolve` maps a parsed spec onto concrete device ids."""

    def test_list_returns_requested_ids(self) -> None:
        """An explicit id list comes back as-is when every id is visible."""
        requested = (0, 2)
        resolved = GpuSpec(kind="list", value=requested).resolve(device_count=4)
        assert resolved == requested

    def test_all_returns_full_range(self) -> None:
        """`all` expands to every visible device id, starting from 0."""
        resolved = GpuSpec(kind="all", value=None).resolve(device_count=3)
        assert resolved == tuple(range(3))

    def test_count_returns_prefix(self) -> None:
        """A count selects the lowest-numbered ids."""
        resolved = GpuSpec(kind="count", value=2).resolve(device_count=4)
        assert resolved == (0, 1)

    def test_count_exceeding_visible_raises(self) -> None:
        """Asking for more devices than are visible is an error."""
        too_many = GpuSpec(kind="count", value=4)
        with pytest.raises(UnsupportedGpuSpecError, match="exceeds"):
            too_many.resolve(device_count=2)

    def test_count_zero_rejected(self) -> None:
        """A zero count is constructible but rejected when resolved."""
        # The spec is built outside `pytest.raises`, so construction itself
        # must succeed; only `resolve` is expected to fail.
        zero = GpuSpec(kind="count", value=0)
        with pytest.raises(UnsupportedGpuSpecError, match=">= 1"):
            zero.resolve(device_count=4)

    def test_list_out_of_range_raises(self) -> None:
        """An id at or beyond the visible device count is rejected."""
        beyond = GpuSpec(kind="list", value=(0, 5))
        with pytest.raises(UnsupportedGpuSpecError, match="out-of-range"):
            beyond.resolve(device_count=2)

    def test_list_duplicate_raises(self) -> None:
        """Repeated ids are caught at resolve time, not at construction."""
        repeated = GpuSpec(kind="list", value=(0, 1, 1))
        with pytest.raises(UnsupportedGpuSpecError, match="duplicate"):
            repeated.resolve(device_count=4)

    def test_no_visible_devices_raises(self) -> None:
        """Resolving anything against zero visible devices fails."""
        everything = GpuSpec(kind="all", value=None)
        with pytest.raises(UnsupportedGpuSpecError, match="at least 1"):
            everything.resolve(device_count=0)