"""ROCm training-plan resolution — Sprint 22.

Covers:

- Arch-aware bf16 probe (`gfx90a` / `gfx942` / `gfx1100` → bf16,
  `gfx1030` RDNA2 / `gfx906` Vega20 → fp16).
- FlashAttention gating: bf16-capable arch + `flash_attn` importable
  → FA2 enabled; bf16-incapable arch always SDPA regardless of
  package presence.
- QLoRA refusal: permanent on ROCm (bitsandbytes), and the message must
  not promise future work (no "Sprint 22" pointer).
- `TrainingPlan` picks correct precision + attention on each arch.
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from dlm.doc.schema import TrainingConfig
from dlm.hardware.backend import Backend
from dlm.hardware.capabilities import probe
from dlm.hardware.plan import resolve
from dlm.hardware.refusals import ResolutionError, check_refusals
from tests.fixtures.hardware_mocks import force_rocm


def _cfg(**overrides: object) -> TrainingConfig:
    base = {"adapter": "lora", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
    base.update(overrides)
    return TrainingConfig(**base)  # type: ignore[arg-type]


class TestRocmBf16Matrix:
    @pytest.mark.parametrize(
        ("arch", "expected"),
        [
            ("gfx90a", True),  # MI200
            ("gfx942", True),  # MI300
            ("gfx1100", True),  # RDNA3 Navi 31 — 7900 XTX
            ("gfx1101", True),  # RDNA3 Navi 32 — 7800 XT / 7700 XT
            ("gfx1102", True),  # RDNA3 Navi 33 — 7600
            ("gfx1030", False),  # RDNA2 6900 XT — no bf16
            ("gfx906", False),  # Vega20 — no bf16
            ("gfx908", False),  # CDNA1 MI100 — no bf16
        ],
    )
    def test_bf16_by_arch(self, arch: str, expected: bool) -> None:
        with force_rocm(gcn_arch_name=arch):
            caps = probe()
            assert caps.backend == Backend.ROCM
            assert caps.rocm_arch == arch
            assert caps.supports_bf16 is expected

    def test_bf16_plan_picks_bf16_on_rdna3(self) -> None:
        with force_rocm(gcn_arch_name="gfx1100"):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=135_000_000, seq_len=512)
            assert plan.precision == "bf16"

    def test_bf16_plan_falls_back_to_fp16_on_rdna2(self) -> None:
        with force_rocm(gcn_arch_name="gfx1030"):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=135_000_000, seq_len=512)
            assert plan.precision == "fp16"

    def test_gcn_arch_strips_xnack_suffix(self) -> None:
        """ROCm sometimes appends `:sramecc+:xnack-` — we match the bare arch."""
        with force_rocm(gcn_arch_name="gfx90a:sramecc+:xnack-"):
            caps = probe()
            assert caps.rocm_arch == "gfx90a"
            assert caps.supports_bf16 is True
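
    # Hedged companion to the suffix test above: assuming the probe strips
    # any ":"-delimited feature suffix (not only the literal
    # ":sramecc+:xnack-" string), a lone ":xnack+" should normalize too.
    # The arch/test name here is illustrative, not a pinned contract.
    def test_gcn_arch_strips_single_feature_suffix(self) -> None:
        with force_rocm(gcn_arch_name="gfx942:xnack+"):
            caps = probe()
            assert caps.rocm_arch == "gfx942"
            assert caps.supports_bf16 is True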


class TestRocmFlashAttention:
    def test_fa2_enabled_when_package_present_and_arch_ok(self) -> None:
        with (
            patch(
                "dlm.hardware.capabilities._module_available",
                lambda name: name == "flash_attn",
            ),
            force_rocm(gcn_arch_name="gfx90a"),
        ):
            caps = probe()
            assert caps.has_flash_attention is True

    def test_fa2_disabled_when_package_absent(self) -> None:
        with (
            patch("dlm.hardware.capabilities._module_available", lambda name: False),
            force_rocm(gcn_arch_name="gfx1100"),
        ):
            caps = probe()
            assert caps.has_flash_attention is False

    def test_fa2_disabled_on_rdna2_even_with_package(self) -> None:
        """bf16-incapable arch never gets FA2, regardless of `flash_attn` install."""
        with (
            patch(
                "dlm.hardware.capabilities._module_available",
                lambda name: name == "flash_attn",
            ),
            force_rocm(gcn_arch_name="gfx1030"),
        ):
            caps = probe()
            assert caps.has_flash_attention is False
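
    # Plan-level sketch of the gating described in the module docstring
    # ("`TrainingPlan` picks correct precision + attention on each arch"):
    # assuming resolve() surfaces caps.has_flash_attention as
    # attn_implementation, a patched flash_attn on a bf16-capable arch
    # should yield FA2 end to end.
    def test_fa2_flows_into_plan_when_available(self) -> None:
        with (
            patch(
                "dlm.hardware.capabilities._module_available",
                lambda name: name == "flash_attn",
            ),
            force_rocm(gcn_arch_name="gfx90a"),
        ):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=135_000_000, seq_len=512)
            assert plan.attn_implementation == "flash_attention_2"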


class TestRocmQLoRARefusal:
    def test_qlora_on_rocm_refuses_with_bitsandbytes_explanation(self) -> None:
        with force_rocm():
            caps = probe()
            with pytest.raises(ResolutionError) as exc_info:
                check_refusals(_cfg(adapter="qlora"), caps, base_params=1_500_000_000)
            msg = str(exc_info.value)
            assert "bitsandbytes" in msg
            assert "ROCm" in msg
            # Sprint 22 landed; the message must not promise future work.
            assert "Sprint 22" not in msg
            assert "wait for" not in msg.lower()

    def test_lora_on_rocm_passes(self) -> None:
        with force_rocm():
            caps = probe()
            # Should not raise.
            check_refusals(_cfg(adapter="lora"), caps, base_params=1_500_000_000)


class TestRocmPlanIntegration:
    def test_tier2_happy_path_rdna3(self) -> None:
        with force_rocm(gcn_arch_name="gfx1100", vram_gb=24.0):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048)
            assert plan.precision == "bf16"
            # Attention depends on whether flash_attn is importable in the
            # test environment, so accept either implementation here.
            assert plan.attn_implementation in ("sdpa", "flash_attention_2")
            assert plan.use_qlora is False
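
    # Hedged mirror of the happy path: on RDNA2 the module docstring promises
    # fp16 and SDPA regardless of flash_attn, so this pins both (assuming the
    # plan follows the probed capabilities one-to-one; vram_gb=16.0 matches a
    # 6900 XT but is illustrative).
    def test_tier2_fallback_path_rdna2(self) -> None:
        with force_rocm(gcn_arch_name="gfx1030", vram_gb=16.0):
            caps = probe()
            plan = resolve(_cfg(), caps, base_params=1_500_000_000, seq_len=2048)
            assert plan.precision == "fp16"
            assert plan.attn_implementation == "sdpa"
            assert plan.use_qlora is False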