1 """Integration test: ``HF.as_scaled_adapter`` and the response-curve invariants.
2
3 The adapter-ablation probe (the sway signature primitive) leans on
4 ``as_scaled_adapter(lam)`` to walk a λ sweep. Two things must hold:
5
6 1. **Monotonicity of the *signal* across λ**: divergence at λ=1.25
7 should be strictly larger than divergence at λ=0 (which is base).
8 We don't claim a smooth curve here — the unit tests on the probe
9 itself cover curve shape — only that the scaling actually scales.
10 2. **State restoration on exit**: every ``LoraLayer.scaling[adapter_name]``
11 value the context manager touched must be back to its original
12 number after the ``with`` block. Anything else corrupts subsequent
13 probes.
14
15 Marked ``slow+online`` to share the tiny-model fixture with the rest
16 of the integration suite.
17 """

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import Any

import numpy as np
import pytest

from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
from dlm_sway.core.model import ModelSpec
from dlm_sway.probes._divergence import divergence

pytestmark = [pytest.mark.slow, pytest.mark.online]


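# A minimal sketch (an assumption for readability, not the backend's real code)
# of the ``as_scaled_adapter`` semantics the tests below rely on: on entry,
# multiply every ``LoraLayer.scaling[key]`` entry by ``lam``; on exit, restore
# the original values even if the body raised. The helper name and signature
# are hypothetical and the function is not used by the tests.
from contextlib import contextmanager


@contextmanager
def _scaled_adapter_sketch(peft_model: Any, lam: float):
    saved: dict[Any, dict[str, float]] = {}
    for module in peft_model.modules():
        scaling = getattr(module, "scaling", None)
        if isinstance(scaling, dict):
            saved[module] = dict(scaling)
            for key in scaling:
                scaling[key] = saved[module][key] * lam
    try:
        yield peft_model
    finally:
        # Restoration is unconditional; this is the invariant the two
        # ``test_scaling_restored_*`` tests pin down.
        for module, originals in saved.items():
            module.scaling.update(originals)

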
def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
    """Same shape as the toggle-test adapter — a small but non-zero LoRA."""
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(0)
    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)
    cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base, cfg)
    with torch.no_grad():
        # lora_B is zero-initialised by default, which would make the adapter a
        # no-op; give it small random values so scaling it actually moves the
        # output distribution.
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                param.copy_(torch.randn_like(param) * 0.05)
    peft_model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))


@pytest.fixture(scope="module")
def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
    adapter_dir = tmp_path_factory.mktemp("scaled-random-adapter")
    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
    return adapter_dir


@pytest.fixture(scope="module")
def hf_backend(tiny_model_dir: Path, random_adapter: Path) -> Iterator[HuggingFaceDifferentialBackend]:
    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    yield backend
    backend.close()


def _captured_scalings(backend: HuggingFaceDifferentialBackend) -> dict[tuple[int, str], float]:
    """Snapshot every ``LoraLayer.scaling[key]`` keyed by (id, key)."""
    import peft

    lora_cls: Any = peft.tuners.lora.LoraLayer
    out: dict[tuple[int, str], float] = {}
    for module in backend._peft_model.modules():  # type: ignore[attr-defined]
        if not isinstance(module, lora_cls):
            continue
        scaling = getattr(module, "scaling", None)
        if not isinstance(scaling, dict):
            continue
        for key, value in scaling.items():
            out[(id(module), key)] = float(value)
    return out


def test_lambda_sweep_monotonic_in_signal(hf_backend: HuggingFaceDifferentialBackend) -> None:
    """JS divergence between λ=0 and λ=1.25 must exceed the λ=0 self-divergence (0)."""
    prompt = "The quick brown fox"
    with hf_backend.as_scaled_adapter(0.0) as v0:
        d0 = v0.next_token_dist(prompt, top_k=64)
    with hf_backend.as_scaled_adapter(1.25) as v_over:
        d_over = v_over.next_token_dist(prompt, top_k=64)

    div_at_zero = divergence(d0, d0, kind="js")
    div_at_overshoot = divergence(d0, d_over, kind="js")
    assert div_at_zero == pytest.approx(0.0, abs=1e-9), (
        f"self-divergence at λ=0 should be ~0; got {div_at_zero}"
    )
    assert div_at_overshoot > 1e-6, f"λ=1.25 should drift far from λ=0; got {div_at_overshoot}"


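# A hedged sketch of what ``divergence(..., kind="js")`` is assumed to compute:
# the Jensen-Shannon divergence between two probability vectors, which is zero
# exactly when the two vectors coincide. That is why the self-divergence
# assertion above expects ~0. The helper below is illustrative only and is not
# used by the tests (it takes already-aligned probability arrays; the real
# probe presumably handles top-k support alignment itself).
def _js_divergence_sketch(p: np.ndarray, q: np.ndarray) -> float:
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)

    def _kl(a: np.ndarray, b: np.ndarray) -> float:
        mask = a > 0
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)

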
def test_lambda_one_matches_finetuned_within_tolerance(
    hf_backend: HuggingFaceDifferentialBackend,
) -> None:
    """``as_scaled_adapter(1.0)`` should be functionally identical to ``as_finetuned()``."""
    prompt = "hello"
    with hf_backend.as_finetuned() as ft:
        d_ft = ft.next_token_dist(prompt, top_k=32)
    with hf_backend.as_scaled_adapter(1.0) as v1:
        d1 = v1.next_token_dist(prompt, top_k=32)
    np.testing.assert_allclose(d_ft.logprobs, d1.logprobs, rtol=1e-5, atol=1e-6)


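# Why λ=1.0 should reproduce ``as_finetuned()``: a LoRA layer's forward pass is
# assumed to be ``W @ x + scaling * (B @ (A @ x))``, and ``as_scaled_adapter``
# is assumed to multiply ``scaling`` by λ, so λ=1.0 leaves the fine-tuned pass
# untouched and λ=0.0 collapses to the base model. Toy-matrix sketch (names and
# shapes are illustrative, not the backend's internals):
def _scaled_lora_forward_sketch(
    W: np.ndarray, A: np.ndarray, B: np.ndarray, x: np.ndarray, scaling: float, lam: float
) -> np.ndarray:
    # Base projection plus the λ-scaled low-rank update.
    return W @ x + (lam * scaling) * (B @ (A @ x))

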
def test_scaling_restored_on_clean_exit(hf_backend: HuggingFaceDifferentialBackend) -> None:
    """Every LoraLayer.scaling[key] is back to its original value after exit."""
    before = _captured_scalings(hf_backend)
    with hf_backend.as_scaled_adapter(0.42) as v:
        v.next_token_dist("anything", top_k=8)
    after = _captured_scalings(hf_backend)
    assert before == after, "scaling table not restored after as_scaled_adapter context"


def test_scaling_restored_on_exception(hf_backend: HuggingFaceDifferentialBackend) -> None:
    """Same restoration invariant when the body raises."""
    before = _captured_scalings(hf_backend)
    with pytest.raises(RuntimeError, match="boom"):
        with hf_backend.as_scaled_adapter(0.7):
            raise RuntimeError("boom")
    after = _captured_scalings(hf_backend)
    assert before == after, "scaling table not restored after exception inside the context"