Python · 5051 bytes Raw Blame History
"""ROCm training-plan resolution — Sprint 22.

Covers:

- Arch-aware bf16 probe (`gfx90a` / `gfx942` / `gfx1100`–`gfx1102` → bf16;
  `gfx1030` RDNA2 / `gfx906` Vega20 / `gfx908` CDNA1 → no bf16, fp16
  fallback).
- FlashAttention gating: bf16-capable arch + `flash_attn` importable
  → FA2 enabled; bf16-incapable arch always SDPA regardless of
  package presence.
- QLoRA refusal: permanent; the message carries no "Sprint 22" pointer
  and does not promise future support.
- `TrainingPlan` picks correct precision + attention on each arch.
"""
13
14 from __future__ import annotations
15
16 from unittest.mock import patch
17
18 import pytest
19
20 from dlm.doc.schema import TrainingConfig
21 from dlm.hardware.backend import Backend
22 from dlm.hardware.capabilities import probe
23 from dlm.hardware.plan import resolve
24 from dlm.hardware.refusals import ResolutionError, check_refusals
25 from tests.fixtures.hardware_mocks import force_rocm
26
27
def _cfg(**overrides: object) -> TrainingConfig:
    """Return a `TrainingConfig` built from LoRA defaults.

    Keyword arguments are layered on top of the defaults, so any key passed
    in `overrides` replaces the default value of the same name.
    """
    defaults: dict[str, object] = {
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
    }
    merged = {**defaults, **overrides}
    return TrainingConfig(**merged)  # type: ignore[arg-type]
32
33
class TestRocmBf16Matrix:
    """bf16 support reported by `probe()` per ROCm arch, and the plan precision that follows."""

    @pytest.mark.parametrize(
        ("arch", "expected"),
        [
            ("gfx90a", True),  # MI200
            ("gfx942", True),  # MI300
            ("gfx1100", True),  # RDNA3 7900 XTX
            ("gfx1101", True),  # RDNA3 7800 XT
            ("gfx1102", True),  # RDNA3 7700 XT
            ("gfx1030", False),  # RDNA2 6900 XT — no bf16
            ("gfx906", False),  # Vega20 — no bf16
            ("gfx908", False),  # CDNA1 MI100 — no bf16
        ],
    )
    def test_bf16_by_arch(self, arch: str, expected: bool) -> None:
        """The probe echoes the forced arch and flags bf16 support as expected."""
        with force_rocm(gcn_arch_name=arch):
            detected = probe()
            assert (detected.backend, detected.rocm_arch) == (Backend.ROCM, arch)
            assert detected.supports_bf16 is expected

    def test_bf16_plan_picks_bf16_on_rdna3(self) -> None:
        """A plan resolved against a forced `gfx1100` device uses bf16."""
        with force_rocm(gcn_arch_name="gfx1100"):
            detected = probe()
            resolved = resolve(
                _cfg(),
                detected,
                base_params=135_000_000,
                seq_len=512,
            )
            assert resolved.precision == "bf16"

    def test_bf16_plan_falls_back_to_fp16_on_rdna2(self) -> None:
        """A plan resolved against a forced `gfx1030` device uses fp16."""
        with force_rocm(gcn_arch_name="gfx1030"):
            detected = probe()
            resolved = resolve(
                _cfg(),
                detected,
                base_params=135_000_000,
                seq_len=512,
            )
            assert resolved.precision == "fp16"

    def test_gcn_arch_strips_xnack_suffix(self) -> None:
        """ROCm sometimes appends `:sramecc+:xnack-` — we match the bare arch."""
        decorated_name = "gfx90a:sramecc+:xnack-"
        with force_rocm(gcn_arch_name=decorated_name):
            detected = probe()
            # Only the part before the first feature flag should survive.
            assert detected.rocm_arch == "gfx90a"
            assert detected.supports_bf16 is True
73
74
class TestRocmFlashAttention:
    """FlashAttention gating as reported by `probe()` under a forced ROCm arch."""

    @staticmethod
    def _probe_flash_attention(arch: str, *, flash_attn_installed: bool) -> bool:
        """Probe `arch` with module availability patched; return `has_flash_attention`.

        When `flash_attn_installed` is true only `flash_attn` is reported as
        available; otherwise no module is.
        """

        def fake_module_available(name: str) -> bool:
            return flash_attn_installed and name == "flash_attn"

        with (
            patch(
                "dlm.hardware.capabilities._module_available",
                fake_module_available,
            ),
            force_rocm(gcn_arch_name=arch),
        ):
            return probe().has_flash_attention

    def test_fa2_enabled_when_package_present_and_arch_ok(self) -> None:
        """`gfx90a` with `flash_attn` available reports FlashAttention."""
        enabled = self._probe_flash_attention("gfx90a", flash_attn_installed=True)
        assert enabled is True

    def test_fa2_disabled_when_package_absent(self) -> None:
        """`gfx1100` with no modules available reports no FlashAttention."""
        enabled = self._probe_flash_attention("gfx1100", flash_attn_installed=False)
        assert enabled is False

    def test_fa2_disabled_on_rdna2_even_with_package(self) -> None:
        """bf16-incapable arch never gets FA2, regardless of `flash_attn` install."""
        enabled = self._probe_flash_attention("gfx1030", flash_attn_installed=True)
        assert enabled is False
106
107
class TestRocmQLoRARefusal:
    """Refusal checks for QLoRA vs. plain LoRA adapters on a forced ROCm backend."""

    def test_qlora_on_rocm_refuses_with_bitsandbytes_explanation(self) -> None:
        """QLoRA raises `ResolutionError` whose message names bitsandbytes and ROCm."""
        with force_rocm():
            rocm_caps = probe()
            with pytest.raises(ResolutionError) as excinfo:
                check_refusals(
                    _cfg(adapter="qlora"),
                    rocm_caps,
                    base_params=1_500_000_000,
                )
        message = str(excinfo.value)
        for required in ("bitsandbytes", "ROCm"):
            assert required in message
        # Sprint 22 landed; the message must not promise future work.
        assert "Sprint 22" not in message
        assert "wait for" not in message.lower()

    def test_lora_on_rocm_passes(self) -> None:
        """Plain LoRA goes through the refusal check without raising."""
        with force_rocm():
            rocm_caps = probe()
            # Completing this call without an exception is the assertion.
            check_refusals(
                _cfg(adapter="lora"),
                rocm_caps,
                base_params=1_500_000_000,
            )
126
127
class TestRocmPlanIntegration:
    """Probe a forced RDNA3 device and resolve a complete training plan from it."""

    def test_tier2_happy_path_rdna3(self) -> None:
        """`gfx1100` with 24 GB resolves to bf16 + SDPA without QLoRA.

        Module availability is patched to "nothing installed" so the outcome
        does not depend on whether `flash_attn` happens to be importable on
        the machine running the suite. Previously the attention assertion had
        to accept two different values for that reason.
        """
        with (
            patch("dlm.hardware.capabilities._module_available", lambda name: False),
            force_rocm(gcn_arch_name="gfx1100", vram_gb=24.0),
        ):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048)
        # Guard the precondition: the patch must actually hide flash_attn.
        assert caps.has_flash_attention is False
        assert plan.precision == "bf16"
        # With flash_attn absent, SDPA is the expected fallback.
        assert plan.attn_implementation == "sdpa"
        assert plan.use_qlora is False