tenseleyflow/sway / 6c7391b

tests/hf_scaled_adapter: λ sweep monotonicity + scaling restoration on clean/exception exit

Authored by espadonne
SHA: 6c7391bc4facdbc2d6abd2b898695e83417edbb7
Parents: 1439538
Tree: 17d6b06

1 changed file

Status  File                                           +    -
A       tests/integration/test_hf_scaled_adapter.py   140   0

tests/integration/test_hf_scaled_adapter.py (added)
@@ -0,0 +1,140 @@
+"""Integration test: ``HF.as_scaled_adapter`` and the response-curve invariants.
+
+The adapter-ablation probe (the sway signature primitive) leans on
+``as_scaled_adapter(lam)`` to walk a λ sweep. Two things must hold:
+
+1. **Monotonicity of the *signal* across λ**: divergence at λ=1.25
+   should be strictly larger than divergence at λ=0 (which is base).
+   We don't claim a smooth curve here — the unit tests on the probe
+   itself cover curve shape — only that the scaling actually scales.
+2. **State restoration on exit**: every ``LoraLayer.scaling[adapter_name]``
+   value the context manager touched must be back to its original
+   number after the ``with`` block. Anything else corrupts subsequent
+   probes.
+
+Marked ``slow+online`` to share the tiny-model fixture with the rest
+of the integration suite.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Iterator
+
+import numpy as np
+import pytest
+
+from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
+from dlm_sway.core.model import ModelSpec
+from dlm_sway.probes._divergence import divergence
+
+pytestmark = [pytest.mark.slow, pytest.mark.online]
+
+
+def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
+    """Same shape as the toggle-test adapter — a small but non-zero LoRA."""
+    import torch
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    torch.manual_seed(0)
+    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
+    cfg = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=["q_proj", "v_proj"],
+        lora_dropout=0.0,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    peft_model = get_peft_model(base, cfg)
+    with torch.no_grad():
+        for name, param in peft_model.named_parameters():
+            if "lora_B" in name:
+                param.copy_(torch.randn_like(param) * 0.05)
+    peft_model.save_pretrained(str(out_dir))
+    tokenizer.save_pretrained(str(out_dir))
+
+
+@pytest.fixture(scope="module")
+def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    adapter_dir = tmp_path_factory.mktemp("scaled-random-adapter")
+    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
+    return adapter_dir
+
+
+@pytest.fixture(scope="module")
+def hf_backend(tiny_model_dir: Path, random_adapter: Path) -> Iterator[HuggingFaceDifferentialBackend]:
+    backend = HuggingFaceDifferentialBackend(
+        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
+        adapter_path=random_adapter,
+    )
+    yield backend
+    backend.close()
+
+
+def _captured_scalings(backend: HuggingFaceDifferentialBackend) -> dict[tuple[int, str], float]:
+    """Snapshot every ``LoraLayer.scaling[key]`` keyed by (id, key)."""
+    import peft
+
+    lora_cls: Any = peft.tuners.lora.LoraLayer
+    out: dict[tuple[int, str], float] = {}
+    for module in backend._peft_model.modules():  # type: ignore[attr-defined]
+        if not isinstance(module, lora_cls):
+            continue
+        scaling = getattr(module, "scaling", None)
+        if not isinstance(scaling, dict):
+            continue
+        for key, value in scaling.items():
+            out[(id(module), key)] = float(value)
+    return out
+
+
+def test_lambda_sweep_monotonic_in_signal(hf_backend: HuggingFaceDifferentialBackend) -> None:
+    """Divergence(@λ=0, @λ=1.25) > divergence(@λ=0, @λ=0) (which is 0)."""
+    prompt = "The quick brown fox"
+    with hf_backend.as_scaled_adapter(0.0) as v0:
+        d0 = v0.next_token_dist(prompt, top_k=64)
+    with hf_backend.as_scaled_adapter(1.25) as v_over:
+        d_over = v_over.next_token_dist(prompt, top_k=64)
+
+    div_at_zero = divergence(d0, d0, kind="js")
+    div_at_overshoot = divergence(d0, d_over, kind="js")
+    assert div_at_zero == pytest.approx(0.0, abs=1e-9), (
+        f"self-divergence at λ=0 should be ~0; got {div_at_zero}"
+    )
+    assert div_at_overshoot > 1e-6, f"λ=1.25 should drift far from λ=0; got {div_at_overshoot}"
+
+
+def test_lambda_one_matches_finetuned_within_tolerance(
+    hf_backend: HuggingFaceDifferentialBackend,
+) -> None:
+    """``as_scaled_adapter(1.0)`` should be functionally identical to ``as_finetuned()``."""
+    prompt = "hello"
+    with hf_backend.as_finetuned() as ft:
+        d_ft = ft.next_token_dist(prompt, top_k=32)
+    with hf_backend.as_scaled_adapter(1.0) as v1:
+        d1 = v1.next_token_dist(prompt, top_k=32)
+    np.testing.assert_allclose(d_ft.logprobs, d1.logprobs, rtol=1e-5, atol=1e-6)
+
+
+def test_scaling_restored_on_clean_exit(hf_backend: HuggingFaceDifferentialBackend) -> None:
+    """Every LoraLayer.scaling[key] is back to its original value after exit."""
+    before = _captured_scalings(hf_backend)
+    with hf_backend.as_scaled_adapter(0.42) as v:
+        v.next_token_dist("anything", top_k=8)
+    after = _captured_scalings(hf_backend)
+    assert before == after, "scaling table not restored after as_scaled_adapter context"
+
+
+def test_scaling_restored_on_exception(hf_backend: HuggingFaceDifferentialBackend) -> None:
+    """Same restoration invariant when the body raises."""
+    before = _captured_scalings(hf_backend)
+    with pytest.raises(RuntimeError, match="boom"):
+        with hf_backend.as_scaled_adapter(0.7):
+            raise RuntimeError("boom")
+    after = _captured_scalings(hf_backend)
+    assert before == after, "scaling table not restored after exception inside the context"
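
For context, the λ sweep these invariants protect looks roughly like the sketch below. It is illustrative, not part of this commit: it reuses only the API surface the tests exercise (as_scaled_adapter, next_token_dist, divergence), while the helper name sweep_divergences and the λ grid are assumptions.

from __future__ import annotations

from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
from dlm_sway.probes._divergence import divergence


def sweep_divergences(
    backend: HuggingFaceDifferentialBackend,
    prompt: str,
    lambdas: tuple[float, ...] = (0.0, 0.25, 0.5, 1.0, 1.25),  # illustrative grid, not the probe's defaults
) -> dict[float, float]:
    """Map each λ to its JS divergence from the λ=0 (adapter-off) distribution."""
    # Reference distribution with the adapter fully ablated.
    with backend.as_scaled_adapter(0.0) as base_view:
        d_base = base_view.next_token_dist(prompt, top_k=64)

    curve: dict[float, float] = {}
    for lam in lambdas:
        # The restoration invariant (the last two tests above) is what makes
        # back-to-back iterations safe: each one starts from the original
        # scaling table.
        with backend.as_scaled_adapter(lam) as view:
            d_lam = view.next_token_dist(prompt, top_k=64)
        curve[lam] = divergence(d_base, d_lam, kind="js")
    return curve

The monotonicity test guarantees that curve[1.25] sits well above curve[0.0], which is ~0 by construction, so a sweep like this carries signal even though the exact curve shape is left to the probe's unit tests.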