tenseleyflow/sway / 88c59e5


tests/integration: HF multi-rank null calibration — correctness + no-reload timing

Authored by espadonne
SHA: 88c59e5500ec5572a4f7e39fc9e72d9c9faa8b0c
Parent: c5e9de2
Tree: c151bc1

1 changed file

A  tests/integration/test_hf_multi_rank_null.py  (+137, -0)

tests/integration/test_hf_multi_rank_null.py (added)
@@ -0,0 +1,137 @@
+"""Integration test: HF backend under multi-rank null calibration (S10 / F4).
+
+Two contracts:
+
+1. **Correctness.** Three ``rank_multipliers`` against SmolLM2-135M
+   produce three per-rank null-stats groups; each rank's ``delta_kl``
+   stats are finite with strictly positive std (larger rank_scale →
+   larger effective noise by the sqrt(r) scaling the backend uses).
+2. **Performance.** Wall time for ``rank_multipliers=[0.5, 1.0, 2.0]``
+   stays under ``4×`` the single-rank baseline — i.e., we don't
+   reload the base model per multiplier. The noise-scaling approach
+   makes rank switches free (no tensor reshape, no reload).
+
+Marked ``slow+online``.
+"""
+
+from __future__ import annotations
+
+import time
+from collections.abc import Iterator
+from pathlib import Path
+
+import pytest
+
+from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
+from dlm_sway.core.model import ModelSpec
+from dlm_sway.core.result import Verdict
+from dlm_sway.probes.base import RunContext, build_probe
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    adapter_dir = tmp_path_factory.mktemp("multi-rank-random-adapter")
+    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
+    return adapter_dir
+
+
+@pytest.fixture(scope="module")
+def hf_backend(
+    tiny_model_dir: Path, random_adapter: Path
+) -> Iterator[HuggingFaceDifferentialBackend]:
+    backend = HuggingFaceDifferentialBackend(
+        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
+        adapter_path=random_adapter,
+    )
+    yield backend
+    backend.close()
+
+
+def _run_null(
+    backend: HuggingFaceDifferentialBackend, rank_multipliers: list[float]
+) -> tuple[float, dict[str, dict[str, dict[str, float]]]]:
+    """Run null_adapter once and return (wall_seconds, null_stats_by_rank)."""
+    probe, spec = build_probe(
+        {
+            "name": "null",
+            "kind": "null_adapter",
+            "runs": 2,
+            "rank_multipliers": rank_multipliers,
+            "calibrate_kinds": ["delta_kl"],
+            "cache": False,  # force real compute for the timing comparison
+        }
+    )
+    ctx = RunContext(backend=backend)
+    t0 = time.perf_counter()
+    result = probe.run(spec, ctx)
+    wall = time.perf_counter() - t0
+    assert result.verdict == Verdict.PASS, result.message
+    return wall, dict(result.evidence["null_stats_by_rank"])
+
+
+def test_three_ranks_produce_three_stats_groups(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    _, by_rank = _run_null(hf_backend, [0.5, 1.0, 2.0])
+    assert set(by_rank) == {"rank_0.50", "rank_1.00", "rank_2.00"}
+    for rkey, kind_stats in by_rank.items():
+        delta_kl = kind_stats.get("delta_kl")
+        assert delta_kl is not None, f"{rkey} missing delta_kl stats"
+        assert delta_kl["n"] == 2.0
+        assert delta_kl["std"] > 0.0
+
+
+def test_multi_rank_does_not_reload_base(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    """Three ranks must scale ~linearly with probe iterations, *not*
+    incur a per-rank base-model reload.
+
+    Three ranks × two seeds = 6 calibration iterations vs single-rank's
+    2 iterations — so a linear-compute upper bound is ≈3×. The S07
+    forward-pass cache on the base view can save more, but doesn't
+    always (null-side view_ids are distinct per rank and seed). We
+    assert < 4× as the clear "no reload" ceiling: a per-rank base
+    reload would blow this past 10× on a 135M model.
+    """
+    # Warmup: first call amortizes the base-model load. Without this
+    # the single-rank baseline absorbs the load cost and the ratio
+    # becomes uninformative.
+    _run_null(hf_backend, [1.0])
+
+    single_wall, _ = _run_null(hf_backend, [1.0])
+    multi_wall, _ = _run_null(hf_backend, [0.5, 1.0, 2.0])
+    ratio = multi_wall / max(single_wall, 0.01)
+    assert ratio < 4.0, (
+        f"multi-rank wall {multi_wall:.2f}s is {ratio:.2f}× single-rank {single_wall:.2f}s "
+        "(threshold: < 4× — a true base-model reload would exceed 10×)"
+    )
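
The module docstring leans on "the sqrt(r) scaling the backend uses" without showing it. Below is a minimal sketch of how such scaling could behave, assuming a Gaussian null perturbation and a hypothetical null_delta helper; the names, the base std, and the exact formula are illustrative and are not dlm_sway's actual backend code.

import math

import torch


def null_delta(
    weight: torch.Tensor, r: int, rank_multiplier: float,
    base_std: float = 0.05, seed: int = 0,
) -> torch.Tensor:
    """Hypothetical null perturbation for one weight matrix.

    The noise std grows with sqrt(rank_multiplier * r), so switching the
    multiplier only rescales the same noise draw: no tensor reshape and
    no base-model reload.
    """
    gen = torch.Generator().manual_seed(seed)
    noise = torch.randn(weight.shape, generator=gen)
    return noise * base_std * math.sqrt(rank_multiplier * r)


# Larger multipliers give strictly larger perturbation std, the ordering
# the correctness contract is built around.
w = torch.zeros(256, 256)
stds = [null_delta(w, r=8, rank_multiplier=m).std().item() for m in (0.5, 1.0, 2.0)]
assert stds[0] < stds[1] < stds[2]

Under this reading, a rank switch is a scalar rescale of an existing perturbation rather than a rebuild of adapter tensors, which is the property the timing test relies on.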
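The < 4× ceiling in test_multi_rank_does_not_reload_base is iteration counting plus headroom; restated as plain arithmetic (not project code):

runs = 2                             # probe spec: "runs": 2
single_rank_iters = 1 * runs         # baseline: rank_multipliers=[1.0]
multi_rank_iters = 3 * runs          # rank_multipliers=[0.5, 1.0, 2.0]
linear_ratio = multi_rank_iters / single_rank_iters
assert linear_ratio == 3.0   # no-reload expectation: compute scales with iterations
assert linear_ratio < 4.0    # the test's ceiling adds headroom for timing jitter;
                             # a per-multiplier base-model reload would land past 10x

Locally the tests would typically be selected by marker, e.g. pytest -m "slow and online" tests/integration/test_hf_multi_rank_null.py, assuming the project's pytest configuration does not deselect slow/online runs by default.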