tenseleyflow/documentlanguagemodel / 9614cd5

feat(hardware): world_size-aware resolve + multi-GPU refusal matrix

Authored by espadonne
SHA: 9614cd5875cd28b6cbdc922dde192e54febf9d34
Parent: 181e542
Tree: c8f6b8f

2 changed files

Status  File                          +   -
M       src/dlm/hardware/plan.py      21  2
M       src/dlm/hardware/refusals.py  54  0
src/dlm/hardware/plan.py  (modified)

@@ -24,7 +24,7 @@ from dlm.doc.schema import TrainingConfig
 from dlm.hardware.backend import Backend
 from dlm.hardware.capabilities import Capabilities
 from dlm.hardware.memory import estimate_peak_vram_gb, estimate_step_seconds
-from dlm.hardware.refusals import check_refusals
+from dlm.hardware.refusals import check_multi_gpu_refusals, check_refusals
 
 AttnImpl = Literal["flash_attention_2", "sdpa", "eager"]
 Precision = Literal["bf16", "fp16"]
@@ -39,6 +39,10 @@ class TrainingPlan:
     """Resolved training plan for the current host.
 
     Fields mirror the knobs the trainer (Sprint 09) actually consumes.
+    `world_size` (Sprint 23) is the number of data-parallel ranks; 1
+    on single-GPU / single-process paths. `effective_batch_size`
+    already folds `world_size` in, so users reading the plan don't
+    have to do the multiplication themselves.
     """
 
     precision: Precision
@@ -49,6 +53,7 @@ class TrainingPlan:
     grad_accum: int
     effective_batch_size: int
     gradient_checkpointing: bool
+    world_size: int
     est_peak_vram_gb: float
     est_step_seconds: float
     reason: str
@@ -70,6 +75,7 @@ def resolve(
     force: bool = False,
     phase: Phase = "sft",
     num_adapters: int = 1,
+    world_size: int = 1,
 ) -> TrainingPlan:
     """Produce a concrete plan from a frontmatter config + host caps.
 
@@ -82,7 +88,19 @@ def resolve(
     `num_adapters` lets multi-adapter callers surface the count so
     F28 (multi-adapter QLoRA VRAM refusal) can fire before training
     starts. Single-adapter docs keep the default.
+
+    `world_size` (Sprint 23) is the number of data-parallel ranks.
+    It multiplies the reported `effective_batch_size` (each rank
+    processes a micro-batch independently) and scales the per-rank
+    step-time estimate down: more GPUs means less wall-clock time
+    per global step, up to communication overhead. `world_size > 1`
+    triggers the multi-GPU refusal matrix (MPS/CPU refusal,
+    heterogeneous CUDA refusal).
     """
+    if world_size < 1:
+        raise ValueError(f"world_size must be >= 1, got {world_size}")
+    if world_size > 1:
+        check_multi_gpu_refusals(caps, world_size)
     check_refusals(
         training, caps, base_params, force=force, num_adapters=num_adapters
     )
@@ -132,8 +150,9 @@ def resolve(
         quant_compute_dtype=quant_dtype,
         micro_batch_size=micro_batch,
         grad_accum=grad_accum,
-        effective_batch_size=micro_batch * grad_accum,
+        effective_batch_size=micro_batch * grad_accum * world_size,
         gradient_checkpointing=gradient_checkpointing,
+        world_size=world_size,
         est_peak_vram_gb=round(est_peak, 2),
         est_step_seconds=round(est_step, 2),
         reason=reason,
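
For readers wiring this up, a minimal sketch of a multi-GPU call under the new `resolve` signature. The leading positional arguments are inferred from the `check_refusals` call inside `resolve`; `Capabilities.detect()`, the `TrainingConfig()` defaults, and the parameter-count value are illustrative assumptions, not part of this diff.

```python
# Hypothetical caller sketch. Only resolve(..., world_size=...) and the
# TrainingPlan fields come from this commit; the setup lines are assumed.
from dlm.doc.schema import TrainingConfig
from dlm.hardware.capabilities import Capabilities
from dlm.hardware.plan import resolve

training = TrainingConfig()     # field values elided; defaults assumed
caps = Capabilities.detect()    # assumed host-probe constructor
plan = resolve(
    training,
    caps,
    base_params=7_000_000_000,  # base model parameter count (assumed value)
    world_size=4,               # new in this commit: four data-parallel ranks
)

# The plan folds world_size into the batch math, per the docstring:
assert plan.effective_batch_size == plan.micro_batch_size * plan.grad_accum * 4
assert plan.world_size == 4
```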
src/dlm/hardware/refusals.py  (modified)

@@ -99,6 +99,60 @@ def check_refusals(
         )
 
 
+def check_multi_gpu_refusals(caps: Capabilities, world_size: int) -> None:
+    """Refuse multi-GPU configurations that can't reasonably work.
+
+    Sprint 23 scope: CUDA only. MPS doesn't do DDP; CPU multi-process
+    training is technically possible but a terrible user experience.
+    Heterogeneous CUDA GPUs (different SM families) produce
+    inconsistent mixed-precision results; refuse rather than let the
+    slower arch silently dictate the precision.
+
+    ROCm multi-GPU is explicitly out of scope for this sprint per the
+    Sprint 23 plan; refuse with a pointer so users don't chase
+    phantom bugs.
+    """
+    if world_size < 2:
+        return
+    if caps.backend == Backend.MPS:
+        raise ResolutionError(
+            "Multi-GPU training on Apple Silicon (MPS) is not supported; "
+            "MPS has no DDP path. Train single-GPU or on a CUDA host.",
+        )
+    if caps.backend == Backend.CPU:
+        raise ResolutionError(
+            "Multi-GPU training on CPU is not supported. "
+            "Drop `--gpus` or run single-process.",
+        )
+    if caps.backend == Backend.ROCM:
+        raise ResolutionError(
+            "Multi-GPU training on ROCm is out of scope for Sprint 23; "
+            "train single-GPU on ROCm or use a CUDA host for multi-GPU runs.",
+        )
+    # CUDA path: heterogeneous detection is the caller's responsibility,
+    # since `Capabilities` only reports a single device. Callers that
+    # assemble multi-device state (the launcher) should call
+    # `assert_homogeneous_cuda` directly before spawning ranks.
+
+
+def assert_homogeneous_cuda(sm_per_device: list[tuple[int, int] | None]) -> None:
+    """Refuse if the configured CUDA devices span different SM families.
+
+    Accepts the list of SM tuples the launcher collected from
+    `torch.cuda.get_device_capability(i)` for each selected device.
+    Mixed-precision behavior on heterogeneous GPUs (e.g. Ampere +
+    Turing) is unreliable: bf16 paths silently fall back to fp16 on
+    the Turing card and the two ranks drift.
+    """
+    unique = {sm for sm in sm_per_device if sm is not None}
+    if len(unique) > 1:
+        raise ResolutionError(
+            f"Heterogeneous CUDA GPUs detected (SM families: {sorted(unique)}); "
+            "multi-GPU training requires matching compute capability. "
+            "Select GPUs of the same generation via `--gpus 0,1` etc.",
+        )
+
+
 def _effective_adapter(training: TrainingConfig) -> str:
     """Return the adapter type effectively in force.
104158