@@ -0,0 +1,63 @@ |
| 1 | +"""ROCm training smoke (Sprint 22). |
| 2 | + |
| 3 | +Verifies the doctor→plan→trainer pipeline actually runs on a ROCm |
| 4 | +host without the refusal matrix blocking LoRA. Uses the tiny-model |
| 5 | +session fixture and runs for a single step; the smoke is that |
| 6 | +`run_training` returns a result rather than raising. |
| 7 | + |
| 8 | +Skipped unless: |
| 9 | +- `torch.version.hip` is truthy at runtime (real ROCm torch build) |
| 10 | +- `DLM_ENABLE_ROCM_SMOKE=1` in the environment (opt-in even on a |
| 11 | + ROCm host so local `pytest -m slow` runs stay CPU/CUDA-only) |
| 12 | + |
| 13 | +CI: no default runner exists; expected to be run on a self-hosted |
| 14 | +ROCm box via a scheduled workflow. Documented in |
| 15 | +`docs/hardware/rocm.md`. |
| 16 | +""" |
| 17 | + |
| 18 | +from __future__ import annotations |
| 19 | + |
| 20 | +import os |
| 21 | +from typing import TYPE_CHECKING |
| 22 | + |
| 23 | +import pytest |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + from tests.fixtures.trained_store import TrainedStoreHandle |
| 27 | + |
| 28 | + |
| 29 | +def _rocm_host() -> bool: |
| 30 | + try: |
| 31 | + import torch |
| 32 | + except ImportError: # pragma: no cover |
| 33 | + return False |
| 34 | + return bool(getattr(torch.version, "hip", None)) |
| 35 | + |
| 36 | + |
# Module-wide gating: slow-marked, and skipped unless both a real ROCm
# torch build is present AND the operator explicitly opted in via env.
pytestmark = [
    pytest.mark.slow,
    pytest.mark.skipif(
        not _rocm_host(),
        reason="requires a ROCm PyTorch build",
    ),
    pytest.mark.skipif(
        os.environ.get("DLM_ENABLE_ROCM_SMOKE") != "1",
        reason="set DLM_ENABLE_ROCM_SMOKE=1 to opt in to the ROCm smoke on a real host",
    ),
]
| 45 | + |
| 46 | + |
def test_rocm_lora_smoke_runs(  # pragma: no cover - gpu+rocm path
    trained_store: TrainedStoreHandle,
) -> None:
    """One-step LoRA train on ROCm — no refusal, produces an adapter version bump."""
    # The session-scoped `trained_store` fixture already ran its single
    # training pass during setup, so merely reaching this body on a ROCm
    # host without a refusal *is* the smoke signal.  What remains is to
    # confirm that a committed adapter version actually landed on disk.
    adapter_dir = trained_store.store.resolve_current_adapter()
    assert adapter_dir is not None, (
        "trained_store fixture produced no adapter on ROCm — "
        "LoRA path likely blocked by a refusal that shouldn't fire"
    )
    weights_file = adapter_dir / "adapter_model.safetensors"
    assert weights_file.exists(), (
        "ROCm LoRA wrote the pointer but not the adapter weights"
    )
| 63 | + ) |