hardware(F28): scale per-adapter VRAM by base_params × avg lora_r; count only QLoRA adapters; name offenders (audit-07 M7)
- SHA: 4a21a48058e015fce9eb6d1f248fe628ad5adae8
- Parents: 6dfd5b5
- Tree: 153d3ae

| Status | File | + | - |
|---|---|---|---|
| M | src/dlm/hardware/refusals.py | 44 | 7 |
| M | tests/unit/doc/test_migrate.py | 7 | 1 |
| M | tests/unit/doc/test_versioned.py | 4 | 0 |
| M | tests/unit/hardware/test_f28_multi_adapter_qlora.py | 86 | 26 |
`src/dlm/hardware/refusals.py` (modified)

@@ -59,20 +59,34 @@ def check_refusals(
         and caps.backend == Backend.CUDA
         and caps.vram_gb is not None
     ):
-        # F28: coarse VRAM estimate. QLoRA base at 4-bit ≈ params * 0.5 bytes;
-        # each named adapter carries its own LoRA + optimizer state (~1 GB
-        # worst-case on the adapter sizes we ship). Plus a 25% activation
-        # overhead. When the sum exceeds the 85%-of-VRAM headroom, refuse.
+        # F28: estimate peak VRAM for the multi-adapter + QLoRA case.
+        # Base lives once in VRAM at 4-bit. Each adapter carries its own
+        # LoRA params + AdamW state + gradients, scaling with `base_params`
+        # and the adapter's `lora_r` (our LoRA params ≈ 2 * base * r / hidden;
+        # AdamW state ≈ 2×LoRA params in fp32). A 7B QLoRA r=16 adapter
+        # lands around 300-500 MB; a 135M r=8 adapter is ~10 MB. A flat
+        # 1 GB/adapter (pre-audit) was 30× too high for small bases and
+        # 2× too low for large ones. The formula below scales linearly
+        # with `avg_lora_r × base_params`; the 0.1 GB floor keeps tiny
+        # multi-adapter setups from false-greenlighting.
+        avg_lora_r = _avg_lora_r(training)
         base_gb = base_params * 0.5 / 1e9  # 4-bit base
-        per_adapter_gb = 1.0
+        per_adapter_gb = max(0.1, base_params * avg_lora_r / (1e9 * 64))
         activations_gb = base_params * 2.0 / 1e9 * 0.25
-        est_peak = base_gb + per_adapter_gb * num_adapters + activations_gb
+        qlora_adapter_count = _qlora_adapter_count(training, num_adapters)
+        est_peak = (
+            base_gb + per_adapter_gb * qlora_adapter_count + activations_gb
+        )
         budget = caps.vram_gb * 0.85
         if est_peak > budget:
+            offenders = _qlora_adapter_names(training)
+            offender_note = (
+                f" (offending adapters: {sorted(offenders)})" if offenders else ""
+            )
             raise ResolutionError(
                 "Multi-adapter QLoRA would exceed VRAM "
                 f"(~{est_peak:.1f} GB estimated vs {budget:.1f} GB budget "
-                f"for {caps.vram_gb:.0f} GB device); "
+                f"for {caps.vram_gb:.0f} GB device){offender_note}; "
                 "try `adapter: lora` instead of `qlora`, or reduce the "
                 "number of adapters.",
             )
@@ -100,6 +114,29 @@ def _effective_adapter(training: TrainingConfig) -> str:
     return "lora"


+def _avg_lora_r(training: TrainingConfig) -> float:
+    """Average LoRA rank across declared adapters (fallback: flat lora_r)."""
+    if training.adapters is None or not training.adapters:
+        return float(training.lora_r)
+    return sum(a.lora_r for a in training.adapters.values()) / len(
+        training.adapters
+    )
+
+
+def _qlora_adapter_count(training: TrainingConfig, fallback: int) -> int:
+    """Return the count of QLoRA-typed adapters; `fallback` for flat docs."""
+    if training.adapters is None:
+        return fallback
+    return sum(1 for a in training.adapters.values() if a.adapter == "qlora")
+
+
+def _qlora_adapter_names(training: TrainingConfig) -> list[str]:
+    """Return the declared adapter names using QLoRA (empty on flat docs)."""
+    if training.adapters is None:
+        return []
+    return [n for n, a in training.adapters.items() if a.adapter == "qlora"]
+
+
 def _refuse_qlora(caps: Capabilities) -> None:
     if caps.backend == Backend.MPS:
         raise ResolutionError(
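For reference, here is the new arithmetic restated as a standalone function, followed by the two scenarios the tests exercise. This is a minimal sketch for checking the numbers; `estimate_peak_vram_gb` is an illustrative name, not part of `dlm.hardware.refusals`.

```python
def estimate_peak_vram_gb(
    base_params: int, avg_lora_r: float, qlora_adapter_count: int
) -> float:
    """Restate the F28 estimate above so the worked numbers can be checked."""
    base_gb = base_params * 0.5 / 1e9  # 4-bit base weights
    per_adapter_gb = max(0.1, base_params * avg_lora_r / (1e9 * 64))
    activations_gb = base_params * 2.0 / 1e9 * 0.25
    return base_gb + per_adapter_gb * qlora_adapter_count + activations_gb


# 7B base, 3 QLoRA adapters at r=64: 3.5 + 7.0 * 3 + 3.5 = 28.0 GB,
# refused on a 12 GB card (budget 12 * 0.85 = 10.2 GB).
print(estimate_peak_vram_gb(7_000_000_000, 64.0, 3))  # 28.0

# 1.5B base, 2 QLoRA adapters at r=8: 0.75 + 0.1875 * 2 + 0.75 ≈ 1.9 GB,
# accepted on a 4 GB card (budget 4 * 0.85 = 3.4 GB).
print(estimate_peak_vram_gb(1_500_000_000, 8.0, 2))  # 1.875
```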
`tests/unit/doc/test_migrate.py` (modified)

@@ -48,16 +48,22 @@ def bumped_current(scratch_registry: None) -> Iterator[int]:
     """Pretend CURRENT_SCHEMA_VERSION is one higher than shipped."""
     original = versioned_module.CURRENT_SCHEMA_VERSION
     bumped = original + 1
-    # Also patch the migrate module's import-time constant.
+    # Also patch the migrate + schema module constants. Schema is where
+    # the pydantic field validator reads from (audit-07 M6 landed
+    # defense-in-depth there); without this patch the validator
+    # rejects the bumped version.
     import dlm.doc.migrate as migrate_module
+    import dlm.doc.schema as schema_module

     versioned_module.CURRENT_SCHEMA_VERSION = bumped
     migrate_module.CURRENT_SCHEMA_VERSION = bumped
+    schema_module.CURRENT_SCHEMA_VERSION = bumped
     try:
         yield bumped
     finally:
         versioned_module.CURRENT_SCHEMA_VERSION = original
         migrate_module.CURRENT_SCHEMA_VERSION = original
+        schema_module.CURRENT_SCHEMA_VERSION = original


 class TestIdempotent:
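The same three-constant patch can also be expressed with pytest's `monkeypatch`, which restores the originals without a `try`/`finally`. A sketch under one assumption: `versioned_module` is `dlm.doc.versioned` (its import is not shown in this hunk).

```python
import pytest

import dlm.doc.migrate as migrate_module
import dlm.doc.schema as schema_module
import dlm.doc.versioned as versioned_module  # assumed path for versioned_module


@pytest.fixture
def bumped_current(scratch_registry: None, monkeypatch: pytest.MonkeyPatch) -> int:
    """Pretend CURRENT_SCHEMA_VERSION is one higher than shipped."""
    bumped = versioned_module.CURRENT_SCHEMA_VERSION + 1
    # Patch every import-time copy of the constant, including dlm.doc.schema,
    # where the pydantic field validator reads it (audit-07 M6).
    for module in (versioned_module, migrate_module, schema_module):
        monkeypatch.setattr(module, "CURRENT_SCHEMA_VERSION", bumped)
    return bumped
```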
`tests/unit/doc/test_versioned.py` (modified)

@@ -82,7 +82,10 @@ class TestMigratedPath:
             # Simulated migration: drop an obsolete field the sub-current doc had.
             return {k: v for k, v in raw.items() if k != "legacy_field"}

+        import dlm.doc.schema as schema_module
+
         versioned_module.CURRENT_SCHEMA_VERSION = original + 1
+        schema_module.CURRENT_SCHEMA_VERSION = original + 1
         try:
             fm = validate_versioned(
                 {
@@ -95,6 +98,7 @@ class TestMigratedPath:
             assert fm.dlm_version == original + 1
         finally:
             versioned_module.CURRENT_SCHEMA_VERSION = original
+            schema_module.CURRENT_SCHEMA_VERSION = original

     def test_missing_migrator_raises(self, scratch_registry: None) -> None:
         """Sub-current doc + empty registry → UnsupportedMigrationError."""
`tests/unit/hardware/test_f28_multi_adapter_qlora.py` (modified)

@@ -20,43 +20,63 @@ def _qlora_multi_doc(num: int) -> TrainingConfig:
     return TrainingConfig.model_validate({"adapters": adapters})


+def _qlora_multi_doc_with_rank(num: int, lora_r: int) -> TrainingConfig:
+    """Multi-adapter doc with `num` QLoRA adapters at the given lora_r."""
+    adapters = {
+        f"a{i}": AdapterConfig(adapter="qlora", lora_r=lora_r) for i in range(num)
+    }
+    return TrainingConfig.model_validate({"adapters": adapters})
+
+
 class TestF28MultiAdapterQLoraRefusal:
-    def test_two_adapters_on_small_vram_refused(self) -> None:
-        with force_cuda(vram_gb=4.0):
+    def test_large_base_high_rank_refused(self) -> None:
+        # 7B QLoRA, 3 adapters at r=64 on a 12GB device.
+        # base: 7*0.5=3.5 GB; per_adapter: 7*64/64=7 GB (×3=21 GB);
+        # activations: 7*2*0.25=3.5 GB → 28 GB > 12*0.85=10.2 GB budget.
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
-            # 1.5B-param base at 4-bit ≈ 0.75 GB + 2 * 1 GB + ~0.75 GB activ
-            # ≈ 3.5 GB > 4 * 0.85 = 3.4 GB budget.
             with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
                 check_refusals(
-                    _qlora_multi_doc(2),
+                    _qlora_multi_doc_with_rank(3, 64),
                     caps,
-                    base_params=1_500_000_000,
-                    num_adapters=2,
+                    base_params=7_000_000_000,
+                    num_adapters=3,
                 )

     def test_error_message_points_to_adapter_lora_alternative(self) -> None:
-        with force_cuda(vram_gb=4.0):
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
             with pytest.raises(ResolutionError) as exc_info:
                 check_refusals(
-                    _qlora_multi_doc(3),
+                    _qlora_multi_doc_with_rank(3, 64),
                     caps,
-                    base_params=1_500_000_000,
+                    base_params=7_000_000_000,
                     num_adapters=3,
                 )
             message = str(exc_info.value)
             assert "adapter: lora" in message
             assert "reduce the number of adapters" in message

+    def test_error_message_names_offending_adapters(self) -> None:
+        """Audit-07 M7/N3: refusal lists which adapters triggered it."""
+        with force_cuda(vram_gb=12.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+            with pytest.raises(ResolutionError) as exc_info:
+                check_refusals(
+                    _qlora_multi_doc_with_rank(3, 64),
+                    caps,
+                    base_params=7_000_000_000,
+                    num_adapters=3,
+                )
+            message = str(exc_info.value)
+            assert "offending adapters" in message
+            assert "'a0'" in message
+
     def test_single_adapter_qlora_not_affected_by_f28(self) -> None:
-        # num_adapters=1 on a small VRAM box: the multi-adapter F28 gate
-        # doesn't fire (single-adapter QLoRA is the normal path); other
-        # refusals still apply but F28 specifically does not.
+        # num_adapters=1: F28's `num_adapters > 1` gate skips entirely.
        with force_cuda(vram_gb=4.0):
             caps = replace(probe(), has_bitsandbytes=True)
             flat = TrainingConfig.model_validate({"adapter": "qlora"})
-            # No raise: the QLoRA checks pass (bnb present) and num_adapters
-            # defaults to 1, so F28's `num_adapters > 1` gate skips.
             check_refusals(flat, caps, base_params=1_500_000_000, num_adapters=1)

     def test_multi_adapter_lora_not_refused(self) -> None:
@@ -69,34 +89,74 @@ class TestF28MultiAdapterQLoraRefusal:
                 lora_multi, caps, base_params=1_500_000_000, num_adapters=2
             )

+    def test_small_base_low_rank_multi_qlora_passes(self) -> None:
+        """The old formula falsely refused small-base multi-QLoRA.
+        The new formula is correctly permissive — 1.5B with r=8 fits in 4GB."""
+        with force_cuda(vram_gb=4.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+            # 1.5B base, r=8, 2 adapters:
+            # base 0.75 + per_adapter ~0.19 × 2 + activations 0.75 ≈ 1.9 GB
+            # vs 4 × 0.85 = 3.4 GB budget → accepts.
+            check_refusals(
+                _qlora_multi_doc_with_rank(2, 8),
+                caps,
+                base_params=1_500_000_000,
+                num_adapters=2,
+            )
+
     def test_multi_adapter_qlora_on_large_vram_passes(self) -> None:
         with force_cuda(vram_gb=80.0):  # H100
             caps = replace(probe(), has_bitsandbytes=True)
-            # 1.5B base → 0.75 + 3*1 + 0.75 ≈ 4.5 GB, well under 80 * 0.85 = 68.
+            # Even 7B + 3 adapters at r=64 (28 GB) fits under 80 × 0.85 = 68.
             check_refusals(
-                _qlora_multi_doc(3),
+                _qlora_multi_doc_with_rank(3, 64),
                 caps,
-                base_params=1_500_000_000,
+                base_params=7_000_000_000,
                 num_adapters=3,
             )


 class TestEffectiveAdapter:
-    def test_mixed_multi_adapter_treated_as_qlora_for_refusals(self) -> None:
-        """If any declared adapter is QLoRA, F28 applies."""
-        with force_cuda(vram_gb=4.0):
+    def test_mixed_multi_adapter_refusal_only_counts_qlora_adapters(self) -> None:
+        """Audit-07 M7: mixed doc with one QLoRA + many LoRA doesn't
+        get charged the per-adapter VRAM for LoRAs. The formula counts
+        only QLoRA-typed adapters in the per-adapter budget line."""
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
-            # One LoRA + one QLoRA → the "any qlora" rule still triggers
-            # F28's per-adapter math against num_adapters=2.
+            # 7B base, 1 QLoRA + 2 LoRA at r=64. Only the 1 QLoRA counts:
+            # base 3.5 + per_adapter 7 × 1 + activations 3.5 = 14 GB vs
+            # 12 × 0.85 = 10.2 GB budget → refuses (with 1-adapter charge).
             mixed = TrainingConfig.model_validate(
                 {
                     "adapters": {
-                        "lora_one": {"adapter": "lora"},
-                        "qlora_two": {"adapter": "qlora"},
+                        "qlora_one": {"adapter": "qlora", "lora_r": 64},
+                        "lora_a": {"adapter": "lora", "lora_r": 64},
+                        "lora_b": {"adapter": "lora", "lora_r": 64},
                     },
                 }
             )
             with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
                 check_refusals(
-                    mixed, caps, base_params=1_500_000_000, num_adapters=2
+                    mixed, caps, base_params=7_000_000_000, num_adapters=3
+                )
+
+    def test_mixed_adapter_error_names_only_qlora_offenders(self) -> None:
+        with force_cuda(vram_gb=12.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+            mixed = TrainingConfig.model_validate(
+                {
+                    "adapters": {
+                        "qlora_one": {"adapter": "qlora", "lora_r": 64},
+                        "lora_a": {"adapter": "lora", "lora_r": 64},
+                        "lora_b": {"adapter": "lora", "lora_r": 64},
+                    },
+                }
+            )
+            with pytest.raises(ResolutionError) as exc_info:
+                check_refusals(
+                    mixed, caps, base_params=7_000_000_000, num_adapters=3
                 )
+            message = str(exc_info.value)
+            assert "qlora_one" in message
+            assert "lora_a" not in message
+            assert "lora_b" not in message
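The mixed-adapter expectations in the last two tests come down to simple filtering over the declared adapters. Below is a self-contained sketch of the same selection logic on a plain dict; the real helpers (`_qlora_adapter_count`, `_qlora_adapter_names`, `_avg_lora_r`) operate on `TrainingConfig`, so this only makes the expected counts and names concrete.

```python
# Mixed doc from the tests above: one QLoRA adapter alongside two plain LoRA.
adapters = {
    "qlora_one": {"adapter": "qlora", "lora_r": 64},
    "lora_a": {"adapter": "lora", "lora_r": 64},
    "lora_b": {"adapter": "lora", "lora_r": 64},
}

qlora_names = [name for name, cfg in adapters.items() if cfg["adapter"] == "qlora"]
qlora_count = len(qlora_names)
avg_lora_r = sum(cfg["lora_r"] for cfg in adapters.values()) / len(adapters)

assert qlora_names == ["qlora_one"]  # only this name lands in the refusal message
assert qlora_count == 1              # only one adapter is charged per-adapter VRAM
assert avg_lora_r == 64.0            # rank is averaged over all declared adapters
```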