tenseleyflow/documentlanguagemodel / 4a21a48


hardware(F28): scale per-adapter VRAM by base_params × avg lora_r; count only QLoRA adapters; name offenders (audit-07 M7)

Authored by espadonne
SHA: 4a21a48058e015fce9eb6d1f248fe628ad5adae8
Parents: 6dfd5b5
Tree: 153d3ae

4 changed files

Status  File                                                   +    -
M       src/dlm/hardware/refusals.py                          44    7
M       tests/unit/doc/test_migrate.py                         7    1
M       tests/unit/doc/test_versioned.py                       4    0
M       tests/unit/hardware/test_f28_multi_adapter_qlora.py   86   26
src/dlm/hardware/refusals.py (modified)
@@ -59,20 +59,34 @@ def check_refusals(
         and caps.backend == Backend.CUDA
         and caps.vram_gb is not None
     ):
-        # F28: coarse VRAM estimate. QLoRA base at 4-bit ≈ params * 0.5 bytes;
-        # each named adapter carries its own LoRA + optimizer state (~1 GB
-        # worst-case on the adapter sizes we ship). Plus a 25% activation
-        # overhead. When the sum exceeds the 85%-of-VRAM headroom, refuse.
+        # F28: estimate peak VRAM for the multi-adapter + QLoRA case.
+        # Base lives once in VRAM at 4-bit. Each adapter carries its own
+        # LoRA params + AdamW state + gradients, scaling with `base_params`
+        # and the adapter's `lora_r` (our LoRA params ≈ 2 * base * r / hidden;
+        # AdamW state ≈ 2×LoRA params in fp32). A 7B QLoRA r=16 adapter
+        # lands around 300-500 MB; a 135M r=8 adapter is ~10 MB. A flat
+        # 1 GB/adapter (pre-audit) was 30× too high for small bases and
+        # 2× too low for large ones. The formula below scales linearly
+        # with `avg_lora_r × base_params`; 0.1 GB floor keeps tiny
+        # multi-adapter setups from false-greenlighting.
+        avg_lora_r = _avg_lora_r(training)
         base_gb = base_params * 0.5 / 1e9  # 4-bit base
-        per_adapter_gb = 1.0
+        per_adapter_gb = max(0.1, base_params * avg_lora_r / (1e9 * 64))
         activations_gb = base_params * 2.0 / 1e9 * 0.25
-        est_peak = base_gb + per_adapter_gb * num_adapters + activations_gb
+        qlora_adapter_count = _qlora_adapter_count(training, num_adapters)
+        est_peak = (
+            base_gb + per_adapter_gb * qlora_adapter_count + activations_gb
+        )
         budget = caps.vram_gb * 0.85
         if est_peak > budget:
+            offenders = _qlora_adapter_names(training)
+            offender_note = (
+                f" (offending adapters: {sorted(offenders)})" if offenders else ""
+            )
             raise ResolutionError(
                 "Multi-adapter QLoRA would exceed VRAM "
                 f"(~{est_peak:.1f} GB estimated vs {budget:.1f} GB budget "
-                f"for {caps.vram_gb:.0f} GB device); "
+                f"for {caps.vram_gb:.0f} GB device){offender_note}; "
                 "try `adapter: lora` instead of `qlora`, or reduce the "
                 "number of adapters.",
             )
@@ -100,6 +114,29 @@ def _effective_adapter(training: TrainingConfig) -> str:
     return "lora"


+def _avg_lora_r(training: TrainingConfig) -> float:
+    """Average LoRA rank across declared adapters (fallback: flat lora_r)."""
+    if training.adapters is None or not training.adapters:
+        return float(training.lora_r)
+    return sum(a.lora_r for a in training.adapters.values()) / len(
+        training.adapters
+    )
+
+
+def _qlora_adapter_count(training: TrainingConfig, fallback: int) -> int:
+    """Return the count of QLoRA-typed adapters; `fallback` for flat docs."""
+    if training.adapters is None:
+        return fallback
+    return sum(1 for a in training.adapters.values() if a.adapter == "qlora")
+
+
+def _qlora_adapter_names(training: TrainingConfig) -> list[str]:
+    """Return the declared adapter names using QLoRA (empty on flat docs)."""
+    if training.adapters is None:
+        return []
+    return [n for n, a in training.adapters.items() if a.adapter == "qlora"]
+
+
 def _refuse_qlora(caps: Capabilities) -> None:
     if caps.backend == Backend.MPS:
         raise ResolutionError(
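
For reference, the arithmetic in the new hunk can be reproduced in a few lines. This is a minimal standalone sketch, not code from refusals.py: the `estimate_peak_gb` name is illustrative, and the two checked values mirror the worked numbers in the test comments further down.

def estimate_peak_gb(base_params: int, avg_lora_r: float, qlora_adapter_count: int) -> float:
    # Same three terms as the diff: 4-bit base weights, per-adapter LoRA/optimizer
    # state scaled by base size and rank (with a 0.1 GB floor), and 25% activations.
    base_gb = base_params * 0.5 / 1e9
    per_adapter_gb = max(0.1, base_params * avg_lora_r / (1e9 * 64))
    activations_gb = base_params * 2.0 / 1e9 * 0.25
    return base_gb + per_adapter_gb * qlora_adapter_count + activations_gb

# 7B at r=64 with 3 QLoRA adapters: 3.5 + 3 * 7.0 + 3.5 = 28.0 GB, over a 12 GB card's 10.2 GB budget.
assert estimate_peak_gb(7_000_000_000, 64, 3) == 28.0
# 1.5B at r=8 with 2 QLoRA adapters: 0.75 + 2 * 0.1875 + 0.75 = 1.875 GB, under a 4 GB card's 3.4 GB budget.
assert estimate_peak_gb(1_500_000_000, 8, 2) == 1.875
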
tests/unit/doc/test_migrate.py (modified)
@@ -48,16 +48,22 @@ def bumped_current(scratch_registry: None) -> Iterator[int]:
     """Pretend CURRENT_SCHEMA_VERSION is one higher than shipped."""
     original = versioned_module.CURRENT_SCHEMA_VERSION
     bumped = original + 1
-    # Also patch the migrate module's import-time constant.
+    # Also patch the migrate + schema module constants. Schema is where
+    # the pydantic field validator reads from (audit-07 M6 landed
+    # defense-in-depth there); without this patch the validator
+    # rejects the bumped version.
     import dlm.doc.migrate as migrate_module
+    import dlm.doc.schema as schema_module

     versioned_module.CURRENT_SCHEMA_VERSION = bumped
     migrate_module.CURRENT_SCHEMA_VERSION = bumped
+    schema_module.CURRENT_SCHEMA_VERSION = bumped
     try:
         yield bumped
     finally:
         versioned_module.CURRENT_SCHEMA_VERSION = original
         migrate_module.CURRENT_SCHEMA_VERSION = original
+        schema_module.CURRENT_SCHEMA_VERSION = original


 class TestIdempotent:
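
The extra schema patch follows from ordinary import semantics, as the fixture's own comment about an "import-time constant" suggests: a module that does `from dlm.doc.versioned import CURRENT_SCHEMA_VERSION` binds its own copy of the value, so rebinding the constant in one module never updates the copies held elsewhere. A minimal sketch of that behaviour, using synthetic module objects (the names here are illustrative, not the real dlm modules):

import types

versioned = types.ModuleType("versioned")
versioned.CURRENT_SCHEMA_VERSION = 3

schema = types.ModuleType("schema")
# Equivalent of `from versioned import CURRENT_SCHEMA_VERSION` at import time.
schema.CURRENT_SCHEMA_VERSION = versioned.CURRENT_SCHEMA_VERSION

versioned.CURRENT_SCHEMA_VERSION = 4       # bump only the source module
assert schema.CURRENT_SCHEMA_VERSION == 3  # the importer's copy is stale
schema.CURRENT_SCHEMA_VERSION = 4          # hence the explicit per-module patch in the fixture
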
tests/unit/doc/test_versioned.py (modified)
@@ -82,7 +82,10 @@ class TestMigratedPath:
             # Simulated migration: drop an obsolete field the sub-current doc had.
             return {k: v for k, v in raw.items() if k != "legacy_field"}

+        import dlm.doc.schema as schema_module
+
         versioned_module.CURRENT_SCHEMA_VERSION = original + 1
+        schema_module.CURRENT_SCHEMA_VERSION = original + 1
         try:
             fm = validate_versioned(
                 {
@@ -95,6 +98,7 @@ class TestMigratedPath:
             assert fm.dlm_version == original + 1
         finally:
             versioned_module.CURRENT_SCHEMA_VERSION = original
+            schema_module.CURRENT_SCHEMA_VERSION = original

     def test_missing_migrator_raises(self, scratch_registry: None) -> None:
         """Sub-current doc + empty registry → UnsupportedMigrationError."""
tests/unit/hardware/test_f28_multi_adapter_qlora.py (modified)
@@ -20,43 +20,63 @@ def _qlora_multi_doc(num: int) -> TrainingConfig:
     return TrainingConfig.model_validate({"adapters": adapters})


+def _qlora_multi_doc_with_rank(num: int, lora_r: int) -> TrainingConfig:
+    """Multi-adapter doc with `num` QLoRA adapters at the given lora_r."""
+    adapters = {
+        f"a{i}": AdapterConfig(adapter="qlora", lora_r=lora_r) for i in range(num)
+    }
+    return TrainingConfig.model_validate({"adapters": adapters})
+
+
 class TestF28MultiAdapterQLoraRefusal:
-    def test_two_adapters_on_small_vram_refused(self) -> None:
-        with force_cuda(vram_gb=4.0):
+    def test_large_base_high_rank_refused(self) -> None:
+        # 7B QLoRA, 3 adapters at r=64 on a 12GB device.
+        # base: 7*0.5=3.5 GB; per_adapter: 7*64/64=7 GB (×3=21 GB);
+        # activations: 7*2*0.25=3.5 GB → 28 GB > 12*0.85=10.2 GB budget.
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
-        # 1.5B-param base at 4-bit ≈ 0.75 GB + 2 * 1 GB + ~0.75 GB activ
-        # ≈ 3.5 GB > 4 * 0.85 = 3.4 GB budget.
         with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
             check_refusals(
-                _qlora_multi_doc(2),
+                _qlora_multi_doc_with_rank(3, 64),
                 caps,
-                base_params=1_500_000_000,
-                num_adapters=2,
+                base_params=7_000_000_000,
+                num_adapters=3,
             )

     def test_error_message_points_to_adapter_lora_alternative(self) -> None:
-        with force_cuda(vram_gb=4.0):
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
         with pytest.raises(ResolutionError) as exc_info:
             check_refusals(
-                _qlora_multi_doc(3),
+                _qlora_multi_doc_with_rank(3, 64),
                 caps,
-                base_params=1_500_000_000,
+                base_params=7_000_000_000,
                 num_adapters=3,
             )
         message = str(exc_info.value)
         assert "adapter: lora" in message
         assert "reduce the number of adapters" in message

+    def test_error_message_names_offending_adapters(self) -> None:
+        """Audit-07 M7/N3: refusal lists which adapters triggered it."""
+        with force_cuda(vram_gb=12.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+        with pytest.raises(ResolutionError) as exc_info:
+            check_refusals(
+                _qlora_multi_doc_with_rank(3, 64),
+                caps,
+                base_params=7_000_000_000,
+                num_adapters=3,
+            )
+        message = str(exc_info.value)
+        assert "offending adapters" in message
+        assert "'a0'" in message
+
     def test_single_adapter_qlora_not_affected_by_f28(self) -> None:
-        # num_adapters=1 on a small VRAM box: the multi-adapter F28 gate
-        # doesn't fire (single-adapter QLoRA is the normal path); other
-        # refusals still apply but F28 specifically does not.
+        # num_adapters=1: F28's `num_adapters > 1` gate skips entirely.
        with force_cuda(vram_gb=4.0):
             caps = replace(probe(), has_bitsandbytes=True)
         flat = TrainingConfig.model_validate({"adapter": "qlora"})
-        # No raise: the QLoRA checks pass (bnb present) and num_adapters
-        # defaults to 1, so F28's `num_adapters > 1` gate skips.
         check_refusals(flat, caps, base_params=1_500_000_000, num_adapters=1)

     def test_multi_adapter_lora_not_refused(self) -> None:
@@ -69,34 +89,74 @@ class TestF28MultiAdapterQLoraRefusal:
             lora_multi, caps, base_params=1_500_000_000, num_adapters=2
         )

+    def test_small_base_low_rank_multi_qlora_passes(self) -> None:
+        """The old formula falsely refused small-base multi-QLoRA.
+        The new formula is correctly permissive — 1.5B with r=8 fits in 4GB."""
+        with force_cuda(vram_gb=4.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+        # 1.5B base, r=8, 2 adapters:
+        # base 0.75 + per_adapter ~0.19 × 2 + activations 0.75 ≈ 1.9 GB
+        # vs 4 × 0.85 = 3.4 GB budget → accepts.
+        check_refusals(
+            _qlora_multi_doc_with_rank(2, 8),
+            caps,
+            base_params=1_500_000_000,
+            num_adapters=2,
+        )
+
     def test_multi_adapter_qlora_on_large_vram_passes(self) -> None:
         with force_cuda(vram_gb=80.0):  # H100
             caps = replace(probe(), has_bitsandbytes=True)
-        # 1.5B base → 0.75 + 3*1 + 0.75 ≈ 4.5 GB, well under 80 * 0.85 = 68.
+        # Even 7B + 3 adapters at r=64 (28 GB) fits under 80 × 0.85 = 68.
         check_refusals(
-            _qlora_multi_doc(3),
+            _qlora_multi_doc_with_rank(3, 64),
             caps,
-            base_params=1_500_000_000,
+            base_params=7_000_000_000,
             num_adapters=3,
         )


 class TestEffectiveAdapter:
-    def test_mixed_multi_adapter_treated_as_qlora_for_refusals(self) -> None:
-        """If any declared adapter is QLoRA, F28 applies."""
-        with force_cuda(vram_gb=4.0):
+    def test_mixed_multi_adapter_refusal_only_counts_qlora_adapters(self) -> None:
+        """Audit-07 M7: mixed doc with one QLoRA + many LoRA doesn't
+        get charged the per-adapter VRAM for LoRAs. The formula counts
+        only QLoRA-typed adapters in the per-adapter budget line."""
+        with force_cuda(vram_gb=12.0):
             caps = replace(probe(), has_bitsandbytes=True)
-        # One LoRA + one QLoRA → the "any qlora" rule still triggers
-        # F28's per-adapter math against num_adapters=2.
+        # 7B base, 1 QLoRA + 2 LoRA at r=64. Only the 1 QLoRA counts:
+        # base 3.5 + per_adapter 7 × 1 + activations 3.5 = 14 GB vs
+        # 12 × 0.85 = 10.2 GB budget → refuses (with 1-adapter charge).
         mixed = TrainingConfig.model_validate(
             {
                 "adapters": {
-                    "lora_one": {"adapter": "lora"},
-                    "qlora_two": {"adapter": "qlora"},
+                    "qlora_one": {"adapter": "qlora", "lora_r": 64},
+                    "lora_a": {"adapter": "lora", "lora_r": 64},
+                    "lora_b": {"adapter": "lora", "lora_r": 64},
                 },
             }
         )
         with pytest.raises(ResolutionError, match="Multi-adapter QLoRA"):
             check_refusals(
-                mixed, caps, base_params=1_500_000_000, num_adapters=2
+                mixed, caps, base_params=7_000_000_000, num_adapters=3
+            )
+
+    def test_mixed_adapter_error_names_only_qlora_offenders(self) -> None:
+        with force_cuda(vram_gb=12.0):
+            caps = replace(probe(), has_bitsandbytes=True)
+        mixed = TrainingConfig.model_validate(
+            {
+                "adapters": {
+                    "qlora_one": {"adapter": "qlora", "lora_r": 64},
+                    "lora_a": {"adapter": "lora", "lora_r": 64},
+                    "lora_b": {"adapter": "lora", "lora_r": 64},
+                },
+            }
+        )
+        with pytest.raises(ResolutionError) as exc_info:
+            check_refusals(
+                mixed, caps, base_params=7_000_000_000, num_adapters=3
             )
+        message = str(exc_info.value)
+        assert "qlora_one" in message
+        assert "lora_a" not in message
+        assert "lora_b" not in message
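
For the mixed-adapter case the last two tests exercise (one QLoRA plus two LoRA adapters, 7B base, r=64, on a 12 GB device), the refusal text those assertions probe would look roughly as follows. This is a hand-formatted reproduction of the f-string from refusals.py with the worked numbers plugged in, not captured pytest output:

# Rough reconstruction of the refusal message for the mixed-adapter case.
est_peak = 3.5 + 7.0 * 1 + 3.5   # base + one charged QLoRA adapter + activations = 14.0 GB
budget = 12.0 * 0.85             # 10.2 GB
offenders = ["qlora_one"]
offender_note = f" (offending adapters: {sorted(offenders)})" if offenders else ""
message = (
    "Multi-adapter QLoRA would exceed VRAM "
    f"(~{est_peak:.1f} GB estimated vs {budget:.1f} GB budget "
    f"for {12.0:.0f} GB device){offender_note}; "
    "try `adapter: lora` instead of `qlora`, or reduce the "
    "number of adapters."
)
# message reads:
#   Multi-adapter QLoRA would exceed VRAM (~14.0 GB estimated vs 10.2 GB budget
#   for 12 GB device) (offending adapters: ['qlora_one']); try `adapter: lora`
#   instead of `qlora`, or reduce the number of adapters.
assert "qlora_one" in message and "lora_a" not in message
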