1 """Integration test: PEFT ``disable_adapter`` actually changes logits.
2
3 This is the load-bearing sanity check for the whole differential design.
4 If a future ``peft`` release subtly breaks the disable-context semantics,
5 sway's KL / SIS / ablation probes would all silently report zero signal.
6 We catch that here, before the rest of the test battery runs.
7
8 The test builds a random-init LoRA adapter on a tiny model so no network
9 dependency beyond the base model snapshot itself.
10 """

from __future__ import annotations

from pathlib import Path

import pytest

from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
from dlm_sway.core.model import ModelSpec

pytestmark = [pytest.mark.slow, pytest.mark.online]
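
# For orientation, the raw PEFT contract these tests pin down is roughly the
# sketch below (illustrative only, not executed; ``peft_model`` and
# ``inputs`` are hypothetical stand-ins for a loaded ``PeftModel`` and its
# tokenized prompt). ``PeftModel.disable_adapter()`` is a context manager
# that bypasses the adapter inside the ``with`` block and must re-attach it
# cleanly on exit:
#
#     with peft_model.disable_adapter():
#         base_logits = peft_model(**inputs).logits   # adapter bypassed
#     ft_logits = peft_model(**inputs).logits         # adapter active again
#
# ``as_base`` / ``as_finetuned`` on the backend are assumed to be thin
# wrappers around this context manager.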


def _build_random_lora_adapter(base_dir: Path, out_dir: Path) -> None:
    """Construct a LoRA adapter with random-init weights on ``base_dir``.

    The injected weights are small in magnitude, so the perturbation stays
    well-behaved, yet non-zero, so the toggle delta is clearly measurable.
    The adapter is structurally valid (correct ``adapter_config.json``,
    tokenizer files, safetensors layout).
    """
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(0)

    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)

    cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base, cfg)

    # Explicitly move lora_B off its PEFT-default zero-init so the adapter
    # actually changes outputs. Real training does this via gradients; we
    # do it with a scaled normal.
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                param.copy_(torch.randn_like(param) * 0.05)
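
    # If ``target_modules`` ever stops matching this model's layer names,
    # the loop above silently no-ops and the adapter stays all-zero; fail
    # here with a clear message rather than with a confusing logit
    # comparison later.
    assert any(
        "lora_B" in name for name, _ in peft_model.named_parameters()
    ), "no lora_B parameters matched; the adapter would be a no-op"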

    peft_model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))


@pytest.fixture(scope="module")
def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
    adapter_dir = tmp_path_factory.mktemp("random-adapter")
    _build_random_lora_adapter(tiny_model_dir, adapter_dir)
    return adapter_dir


def test_disable_adapter_changes_logits(tiny_model_dir: Path, random_adapter: Path) -> None:
    """The keystone invariant: base view ≠ ft view on the same prompt."""
    import numpy as np

    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    try:
        prompt = "The quick brown fox"
        with backend.as_base() as b:
            base_dist = b.next_token_dist(prompt, top_k=32)
        with backend.as_finetuned() as f:
            ft_dist = f.next_token_dist(prompt, top_k=32)

        # The adapter may change which tokens make the top-k, so treat
        # either a changed index set or changed logprobs as evidence of a
        # real delta, rather than asserting anything about ordering.
        assert not np.array_equal(base_dist.token_ids, ft_dist.token_ids) or not np.allclose(
            base_dist.logprobs, ft_dist.logprobs, atol=1e-5
        ), "adapter toggle did not change next-token distribution"
    finally:
        backend.close()


def test_roundtrip_toggle_restores_base(tiny_model_dir: Path, random_adapter: Path) -> None:
    """as_base → as_finetuned → as_base yields a stable base view."""
    import numpy as np

    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    try:
        prompt = "hello"
        with backend.as_base() as b:
            first = b.next_token_dist(prompt, top_k=16).logprobs
        with backend.as_finetuned() as f:
            f.next_token_dist(prompt, top_k=16)  # toggle through the ft view
        with backend.as_base() as b:
            second = b.next_token_dist(prompt, top_k=16).logprobs
        np.testing.assert_allclose(first, second, rtol=1e-5, atol=1e-6)
    finally:
        backend.close()


def test_disable_re_enable_bit_identical_logits(tiny_model_dir: Path, random_adapter: Path) -> None:
    """B15 mitigation: ft → base → ft produces bit-identical ft logits.

    Subtle state corruption inside ``disable_adapter()`` (e.g. a wrong
    re-attach order on context exit) could shift the second ft pass by a
    tiny amount that ``assert_allclose`` would tolerate but
    ``assert_array_equal`` doesn't. Pin the stricter claim on fp32 + CPU
    so the test stays deterministic across hosts.
    """
    import numpy as np

    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    try:
        prompt = "the disable_adapter contract is"
        with backend.as_finetuned() as f:
            first = np.array(f.next_token_dist(prompt, top_k=32).logprobs, copy=True)
        with backend.as_base() as b:
            b.next_token_dist(prompt, top_k=32)  # toggle through the base view
        with backend.as_finetuned() as f:
            second = np.array(f.next_token_dist(prompt, top_k=32).logprobs, copy=True)
        np.testing.assert_array_equal(
            first,
            second,
            err_msg=(
                "ft logits drifted across a base toggle; disable_adapter exit "
                "may have corrupted the adapter state (B15)"
            ),
        )
    finally:
        backend.close()