tenseleyflow/sway / 2473839

tests/hf_null_adapter: seed determinism + adapter restoration on clean/exception exit

Authored by espadonne
SHA: 247383937a976f28d9dc2ea1cacbc270d7d8a4ce
Parents: 6c7391b
Tree: f0bbfc3

1 changed file

Status  File                                         +    -
A       tests/integration/test_hf_null_adapter.py  142    0

tests/integration/test_hf_null_adapter.py (added)
@@ -0,0 +1,142 @@
+"""Integration test: ``HF.as_null_adapter`` determinism + restoration.
+
+Two contracts the calibration matrix depends on:
+
+1. **Same seed → same null weights.** ``as_null_adapter(seed=0)`` called
+   twice in a row must produce bit-identical lora_A / lora_B tensors
+   inside the context. If a future PEFT release randomized something
+   we missed (e.g. dropout sampling), the cached null stats would be
+   silently inconsistent across runs.
+2. **Original adapter restored on exit.** After the context manager
+   returns, every ``lora_A`` / ``lora_B`` parameter must equal its
+   pre-context value. Otherwise the next probe runs against a
+   randomly poisoned ft view.
+
+Marked ``slow`` + ``online`` to share the tiny-model fixture.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
+from dlm_sway.core.model import ModelSpec
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
+    """Same shape as the toggle-test adapter."""
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    adapter_dir = tmp_path_factory.mktemp("null-random-adapter")
+    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
+    return adapter_dir
+
+
+@pytest.fixture(scope="module")
+def hf_backend(tiny_model_dir: Path, random_adapter: Path) -> HuggingFaceDifferentialBackend:
+    backend = HuggingFaceDifferentialBackend(
+        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
+        adapter_path=random_adapter,
+    )
+    yield backend
+    backend.close()
+
+
+def _snapshot_lora_params(backend: HuggingFaceDifferentialBackend) -> dict[str, np.ndarray]:
+    """Capture every lora_A / lora_B parameter as a numpy copy."""
+    out: dict[str, np.ndarray] = {}
+    for pname, param in backend._peft_model.named_parameters():  # type: ignore[attr-defined]
+        if any(key in pname for key in ("lora_A", "lora_B")):
+            out[pname] = param.detach().cpu().numpy().copy()
+    return out
+
+
+def test_same_seed_produces_identical_null_weights(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    with hf_backend.as_null_adapter(seed=0):
+        first = _snapshot_lora_params(hf_backend)
+    with hf_backend.as_null_adapter(seed=0):
+        second = _snapshot_lora_params(hf_backend)
+    assert set(first) == set(second)
+    for key in first:
+        np.testing.assert_array_equal(
+            first[key],
+            second[key],
+            err_msg=f"as_null_adapter(seed=0) was not deterministic for {key!r}",
+        )
+
+
+def test_different_seeds_produce_different_null_weights(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    with hf_backend.as_null_adapter(seed=0):
+        a = _snapshot_lora_params(hf_backend)
+    with hf_backend.as_null_adapter(seed=1):
+        b = _snapshot_lora_params(hf_backend)
+    different = any(not np.array_equal(a[k], b[k]) for k in a)
+    assert different, "null adapters at seed=0 and seed=1 produced identical weights"
+
+
+def test_original_adapter_restored_on_exit(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    before = _snapshot_lora_params(hf_backend)
+    with hf_backend.as_null_adapter(seed=42):
+        # confirm the inner state is *not* the original
+        inner = _snapshot_lora_params(hf_backend)
+        assert any(not np.array_equal(before[k], inner[k]) for k in before)
+    after = _snapshot_lora_params(hf_backend)
+    for key in before:
+        np.testing.assert_array_equal(
+            before[key],
+            after[key],
+            err_msg=f"original adapter not restored for {key!r}",
+        )
+
+
+def test_original_adapter_restored_on_exception(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    before = _snapshot_lora_params(hf_backend)
+    with pytest.raises(RuntimeError, match="boom"):
+        with hf_backend.as_null_adapter(seed=99):
+            raise RuntimeError("boom")
+    after = _snapshot_lora_params(hf_backend)
+    for key in before:
+        np.testing.assert_array_equal(
+            before[key],
+            after[key],
+            err_msg=f"original adapter not restored after exception for {key!r}",
+        )
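
Note: the two contracts the docstring names reduce to a snapshot / reseed / ``finally``-restore pattern. The sketch below shows the shape these tests assume; the class name, the ``lora_params`` mapping, and the noise scale are illustrative placeholders, not the actual ``HuggingFaceDifferentialBackend`` internals (which are outside this commit).

    # Minimal sketch (assumed shape, not the real backend): snapshot the LoRA
    # tensors, overwrite them from a seeded generator, restore in ``finally``.
    from contextlib import contextmanager
    from typing import Dict, Iterator

    import torch


    class NullAdapterSketch:
        """Hypothetical stand-in holding the LoRA tensors the context swaps."""

        def __init__(self, lora_params: Dict[str, torch.Tensor]) -> None:
            self.lora_params = lora_params  # name -> lora_A / lora_B tensor

        @contextmanager
        def as_null_adapter(self, seed: int) -> Iterator[None]:
            # 1. Snapshot the current (fine-tuned) LoRA weights.
            saved = {name: p.detach().clone() for name, p in self.lora_params.items()}
            gen = torch.Generator().manual_seed(seed)
            try:
                # 2. Overwrite with seeded noise: the same seed must always
                #    yield the same "null" weights (determinism contract).
                with torch.no_grad():
                    for p in self.lora_params.values():
                        p.copy_(torch.randn(p.shape, generator=gen) * 0.05)
                yield
            finally:
                # 3. Restore the originals on clean *and* exceptional exit
                #    (restoration contract).
                with torch.no_grad():
                    for name, p in self.lora_params.items():
                        p.copy_(saved[name])

The ``finally`` block is what the last test pins down: even when the body raises, the fine-tuned adapter comes back intact for the next probe.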