1 """F28: multi-adapter + QLoRA VRAM refusal (Sprint 20b)."""
2
3 from __future__ import annotations
4
5 from dataclasses import replace
6
7 import pytest
8
9 from dlm.doc.schema import AdapterConfig, TrainingConfig
10 from dlm.hardware.capabilities import probe
11 from dlm.hardware.refusals import ResolutionError, check_refusals
12 from tests.fixtures.hardware_mocks import force_cuda
13
14
def _qlora_multi_doc(num: int) -> TrainingConfig:
    """Build a TrainingConfig declaring `num` QLoRA adapters at the default rank."""
    spec: dict[str, AdapterConfig] = {}
    for idx in range(num):
        spec[f"a{idx}"] = AdapterConfig(adapter="qlora")
    return TrainingConfig.model_validate({"adapters": spec})
19
20
def _qlora_multi_doc_with_rank(num: int, lora_r: int) -> TrainingConfig:
    """Build a TrainingConfig with `num` QLoRA adapters, each at rank `lora_r`."""
    spec = {
        f"a{idx}": AdapterConfig(adapter="qlora", lora_r=lora_r)
        for idx in range(num)
    }
    return TrainingConfig.model_validate({"adapters": spec})
25
26
class TestF28MultiAdapterQLoraRefusal:
    """F28: multi-adapter QLoRA must be refused when it overflows the VRAM budget."""

    def test_large_base_high_rank_refused(self) -> None:
        # 7B base with 3 QLoRA adapters at r=64 on a 12 GB card:
        # 3.5 (base) + 3 × 7 (adapters) + 3.5 (activations) = 28 GB,
        # far above the 12 × 0.85 = 10.2 GB budget → must refuse.
        with force_cuda(vram_gb=12.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            doc = _qlora_multi_doc_with_rank(3, 64)
            with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
                check_refusals(doc, capabilities, base_params=7_000_000_000, num_adapters=3)

    def test_error_message_points_to_adapter_lora_alternative(self) -> None:
        with force_cuda(vram_gb=12.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            doc = _qlora_multi_doc_with_rank(3, 64)
            with pytest.raises(ResolutionError) as err:
                check_refusals(doc, capabilities, base_params=7_000_000_000, num_adapters=3)
            text = str(err.value)
            assert "adapter: lora" in text
            assert "reduce the number of adapters" in text

    def test_error_message_names_offending_adapters(self) -> None:
        """Audit-07 M7/N3: the refusal message enumerates the adapters that triggered it."""
        with force_cuda(vram_gb=12.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            doc = _qlora_multi_doc_with_rank(3, 64)
            with pytest.raises(ResolutionError) as err:
                check_refusals(doc, capabilities, base_params=7_000_000_000, num_adapters=3)
            text = str(err.value)
            assert "offending adapters" in text
            assert "'a0'" in text

    def test_single_adapter_qlora_not_affected_by_f28(self) -> None:
        # With num_adapters=1 the F28 gate (`num_adapters > 1`) never fires.
        with force_cuda(vram_gb=4.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            single = TrainingConfig.model_validate({"adapter": "qlora"})
            check_refusals(single, capabilities, base_params=1_500_000_000, num_adapters=1)

    def test_multi_adapter_lora_not_refused(self) -> None:
        with force_cuda(vram_gb=4.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            doc = TrainingConfig.model_validate(
                {"adapters": {"a0": AdapterConfig(), "a1": AdapterConfig()}}
            )
            # Plain LoRA adapters never enter the QLoRA refusal path.
            check_refusals(doc, capabilities, base_params=1_500_000_000, num_adapters=2)

    def test_small_base_low_rank_multi_qlora_passes(self) -> None:
        """The previous formula wrongly refused small-base multi-QLoRA; the
        current one is correctly permissive — 1.5B at r=8 fits in 4 GB."""
        with force_cuda(vram_gb=4.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            # 1.5B base, 2 adapters at r=8:
            # 0.75 + 2 × ~0.19 + 0.75 ≈ 1.9 GB vs the 4 × 0.85 = 3.4 GB budget.
            check_refusals(
                _qlora_multi_doc_with_rank(2, 8),
                capabilities,
                base_params=1_500_000_000,
                num_adapters=2,
            )

    def test_multi_adapter_qlora_on_large_vram_passes(self) -> None:
        with force_cuda(vram_gb=80.0):  # H100-class device
            capabilities = replace(probe(), has_bitsandbytes=True)
            # 7B + 3 adapters at r=64 needs ~28 GB — under 80 × 0.85 = 68 GB.
            check_refusals(
                _qlora_multi_doc_with_rank(3, 64),
                capabilities,
                base_params=7_000_000_000,
                num_adapters=3,
            )
111
112
class TestEffectiveAdapter:
    """Mixed adapter docs: only QLoRA-typed adapters count toward the F28 budget."""

    @staticmethod
    def _mixed_doc() -> TrainingConfig:
        """One QLoRA adapter plus two plain LoRA adapters, all at r=64."""
        return TrainingConfig.model_validate(
            {
                "adapters": {
                    "qlora_one": {"adapter": "qlora", "lora_r": 64},
                    "lora_a": {"adapter": "lora", "lora_r": 64},
                    "lora_b": {"adapter": "lora", "lora_r": 64},
                },
            }
        )

    def test_mixed_multi_adapter_refusal_only_counts_qlora_adapters(self) -> None:
        """Audit-07 M7: LoRA adapters in a mixed doc carry no per-adapter QLoRA
        VRAM charge; only the QLoRA-typed adapter enters the budget line."""
        with force_cuda(vram_gb=12.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            # 7B base: 3.5 + 7 × 1 (the single QLoRA) + 3.5 = 14 GB, still above
            # the 12 × 0.85 = 10.2 GB budget → refused even at a 1-adapter charge.
            with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
                check_refusals(
                    self._mixed_doc(),
                    capabilities,
                    base_params=7_000_000_000,
                    num_adapters=3,
                )

    def test_mixed_adapter_error_names_only_qlora_offenders(self) -> None:
        with force_cuda(vram_gb=12.0):
            capabilities = replace(probe(), has_bitsandbytes=True)
            with pytest.raises(ResolutionError) as err:
                check_refusals(
                    self._mixed_doc(),
                    capabilities,
                    base_params=7_000_000_000,
                    num_adapters=3,
                )
            text = str(err.value)
            assert "qlora_one" in text
            assert "lora_a" not in text
            assert "lora_b" not in text