@@ -24,7 +24,7 @@ from dlm.doc.schema import TrainingConfig
 from dlm.hardware.backend import Backend
 from dlm.hardware.capabilities import Capabilities
 from dlm.hardware.memory import estimate_peak_vram_gb, estimate_step_seconds
-from dlm.hardware.refusals import check_refusals
+from dlm.hardware.refusals import check_multi_gpu_refusals, check_refusals

 AttnImpl = Literal["flash_attention_2", "sdpa", "eager"]
 Precision = Literal["bf16", "fp16"]
@@ -39,6 +39,10 @@ class TrainingPlan:
     """Resolved training plan for the current host.

     Fields mirror the knobs the trainer (Sprint 09) actually consumes.
+    `world_size` (Sprint 23) is the number of data-parallel ranks;
+    it is 1 on single-GPU / single-process paths.
+    `effective_batch_size` already folds `world_size` in, so users
+    reading the plan don't have to do the multiplication themselves.
     """

     precision: Precision
@@ -49,6 +53,7 @@ class TrainingPlan:
     grad_accum: int
     effective_batch_size: int
     gradient_checkpointing: bool
+    world_size: int
     est_peak_vram_gb: float
     est_step_seconds: float
     reason: str
@@ -70,6 +75,7 @@ def resolve(
     force: bool = False,
     phase: Phase = "sft",
     num_adapters: int = 1,
+    world_size: int = 1,
 ) -> TrainingPlan:
     """Produce a concrete plan from a frontmatter config + host caps.

@@ -82,7 +88,19 @@ def resolve(
     `num_adapters` lets multi-adapter callers surface the count so
     F28 (multi-adapter QLoRA VRAM refusal) can fire before training
     starts. Single-adapter docs keep the default.
+
+    `world_size` (Sprint 23) is the number of data-parallel ranks.
+    It multiplies the reported `effective_batch_size` (each rank
+    processes a micro-batch independently) and scales the per-rank
+    step-time estimate down: more GPUs mean less wall-clock time
+    per global step, up to communication overhead. `world_size > 1`
+    triggers the multi-GPU refusal matrix (MPS/CPU refusal,
+    heterogeneous CUDA refusal).
     """
+    if world_size < 1:
+        raise ValueError(f"world_size must be >= 1, got {world_size}")
+    if world_size > 1:
+        check_multi_gpu_refusals(caps, world_size)
     check_refusals(
         training, caps, base_params, force=force, num_adapters=num_adapters
     )
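The refusal matrix itself is not part of this hunk. As a rough sketch of the shape the docstring describes (hypothetical: the real `check_multi_gpu_refusals` lives in `dlm.hardware.refusals`, and the `Capabilities` attributes used here, `backend` and `gpu_names`, are assumptions rather than anything shown in the diff):

```python
# Hypothetical sketch only; the real check_multi_gpu_refusals may read
# different Capabilities attributes and raise a project-specific error.
from dlm.hardware.capabilities import Capabilities


def check_multi_gpu_refusals(caps: Capabilities, world_size: int) -> None:
    # MPS/CPU refusal: data-parallel ranks are assumed to be CUDA devices.
    if caps.backend in ("mps", "cpu"):  # `backend` attribute is assumed
        raise RuntimeError(
            f"multi-GPU training (world_size={world_size}) requires CUDA, "
            f"got backend={caps.backend!r}"
        )
    # Heterogeneous CUDA refusal: mixed device models make the per-rank
    # VRAM and step-time estimates unreliable.
    names = caps.gpu_names[:world_size]  # `gpu_names` attribute is assumed
    if len(set(names)) > 1:
        raise RuntimeError(
            f"heterogeneous CUDA devices refused: {sorted(set(names))}"
        )
```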
@@ -132,8 +150,9 @@ def resolve(
         quant_compute_dtype=quant_dtype,
         micro_batch_size=micro_batch,
         grad_accum=grad_accum,
-        effective_batch_size=micro_batch * grad_accum,
+        effective_batch_size=micro_batch * grad_accum * world_size,
         gradient_checkpointing=gradient_checkpointing,
+        world_size=world_size,
         est_peak_vram_gb=round(est_peak, 2),
         est_step_seconds=round(est_step, 2),
         reason=reason,
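End to end, the plan arithmetic reads like this; a minimal usage sketch, assuming the planner resolves a micro-batch of 2 and grad-accum of 8 (the config, capabilities, and parameter count are made-up stand-ins):

```python
# Hypothetical usage; constructing `training` (TrainingConfig) and `caps`
# (Capabilities) is elided, and the resolved micro_batch_size/grad_accum
# values below are assumptions, not outputs the diff guarantees.
plan = resolve(training, caps, base_params=7_000_000_000, world_size=4)

# With micro_batch_size=2 and grad_accum=8, the plan reports
# effective_batch_size = 2 * 8 * 4 = 64: per-rank micro-batch, times
# gradient-accumulation steps, times data-parallel ranks, matching the
# new formula in the diff.
assert plan.effective_batch_size == (
    plan.micro_batch_size * plan.grad_accum * plan.world_size
)
```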