1 """Integration test: PEFT ``disable_adapter`` actually changes logits.
2
3 This is the load-bearing sanity check for the whole differential design.
4 If a future ``peft`` release subtly breaks the disable-context semantics,
5 sway's KL / SIS / ablation probes would all silently report zero signal.
6 We catch that here, before the rest of the test battery runs.
7
The test builds a random-init LoRA adapter on a tiny model, so there is
no network dependency beyond the base model snapshot itself.
10 """
11
12 from __future__ import annotations
13
14 from pathlib import Path
15
16 import pytest
17
18 from dlm_sway.backends.hf import HuggingFaceDifferentialBackend
19 from dlm_sway.core.model import ModelSpec
20
21 pytestmark = [pytest.mark.slow, pytest.mark.online]
22
23
def _build_random_lora_adapter(
    base_dir: Path, out_dir: Path, *, seed: int = 0, scale: float = 0.05
) -> None:
    """Construct a LoRA adapter with random-init weights on ``base_dir``.

    The weights are kept small so the toggle-delta is clear but the
    adapter is structurally valid (correct ``adapter_config.json``,
    tokenizer files, safetensors layout).

    Args:
        base_dir: Directory holding the base model snapshot.
        out_dir: Destination directory for the adapter + tokenizer files.
        seed: Torch RNG seed; fixed by default so the generated adapter is
            reproducible across test runs.
        scale: Standard deviation applied to the random ``lora_B`` init —
            small enough to keep outputs numerically tame, large enough
            that toggling the adapter visibly changes logits.
    """
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.manual_seed(seed)

    tokenizer = AutoTokenizer.from_pretrained(str(base_dir))
    # Tiny test models may ship without a pad token; fall back to EOS so
    # downstream tokenization doesn't fail.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(str(base_dir), torch_dtype=torch.float32)

    cfg = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base, cfg)

    # PEFT zero-initialises lora_B, which would make the adapter an exact
    # identity transform. Real training moves it via gradients; we emulate
    # that with a scaled normal so the adapter actually changes outputs.
    with torch.no_grad():
        for name, param in peft_model.named_parameters():
            if "lora_B" in name:
                param.copy_(torch.randn_like(param) * scale)

    peft_model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
62
63
@pytest.fixture(scope="module")
def random_adapter(tiny_model_dir: Path, tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Build the random-init LoRA adapter once per module; return its dir."""
    target = tmp_path_factory.mktemp("random-adapter")
    _build_random_lora_adapter(tiny_model_dir, target)
    return target
69
70
def test_disable_adapter_changes_logits(
    tiny_model_dir: Path, random_adapter: Path
) -> None:
    """The keystone invariant: base view ≠ ft view on the same prompt."""
    import numpy as np

    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    try:
        prompt = "The quick brown fox"
        with backend.as_base() as b:
            base_dist = b.next_token_dist(prompt, top_k=32)
        with backend.as_finetuned() as f:
            ft_dist = f.next_token_dist(prompt, top_k=32)

        # The adapter may reorder the top-k, so demand only that
        # *something* differs: either the returned ids or the logprobs.
        ids_match = np.array_equal(base_dist.token_ids, ft_dist.token_ids)
        lps_match = np.allclose(base_dist.logprobs, ft_dist.logprobs, atol=1e-5)
        assert not (
            ids_match and lps_match
        ), "adapter toggle did not change next-token distribution"
    finally:
        backend.close()
95
96
def test_roundtrip_toggle_restores_base(
    tiny_model_dir: Path, random_adapter: Path
) -> None:
    """as_base → as_finetuned → as_base yields a stable base view."""
    import numpy as np

    backend = HuggingFaceDifferentialBackend(
        base_spec=ModelSpec(base=str(tiny_model_dir), kind="hf", dtype="fp32", device="cpu"),
        adapter_path=random_adapter,
    )
    try:
        prompt = "hello"

        def _base_logprobs():
            with backend.as_base() as view:
                return view.next_token_dist(prompt, top_k=16).logprobs

        before = _base_logprobs()
        # Pass through the fine-tuned view; it must not leak into base state.
        with backend.as_finetuned() as view:
            view.next_token_dist(prompt, top_k=16)
        after = _base_logprobs()
        np.testing.assert_allclose(before, after, rtol=1e-5, atol=1e-6)
    finally:
        backend.close()