Python · 8939 bytes Raw Blame History
1 """`apply_control` — forward_pre_hook attach/detach + arithmetic."""
2
3 from __future__ import annotations
4
5 import numpy as np
6 import pytest
7 import torch
8 from torch import nn
9
10 from dlm.control import ControlApplyError, apply_control
11 from dlm.control.apply import _make_hook
12
13
14 class _ToyLayer(nn.Module):
15 """A stand-in for an HF decoder layer.
16
17 Owns `self_attn.q_proj` so the dim-validation path has something
18 to inspect; the forward just passes the hidden state through so
19 the hook's perturbation is visible on the output.
20 """
21
22 def __init__(self, hidden_dim: int) -> None:
23 super().__init__()
24 self.self_attn = nn.Module()
25 # `nn.Linear(hidden_dim, hidden_dim)` weight shape is
26 # `(out, in)` — the `[-1]` the apply path reads is `in = hidden_dim`.
27 self.self_attn.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
28
29 def forward(self, hidden: torch.Tensor) -> torch.Tensor:
30 return hidden
31
32
class _ToyModel(nn.Module):
    """Smallest HF-shaped container: layers live at `model.model.layers[i]`."""

    def __init__(self, n_layers: int, hidden_dim: int) -> None:
        super().__init__()
        inner = nn.Module()
        inner.layers = nn.ModuleList([_ToyLayer(hidden_dim) for _ in range(n_layers)])
        self.model = inner
40
41
42 def _run_through_layer(model: _ToyModel, layer_index: int, hidden: torch.Tensor) -> torch.Tensor:
43 return model.model.layers[layer_index](hidden)
44
45
class TestHookArithmetic:
    """The hook's `hidden + strength * vector` arithmetic and broadcasting."""

    def test_hook_with_no_args_returns_original_tuple(self) -> None:
        hook = _make_hook(torch.ones(4), 1.0)
        no_args: tuple[object, ...] = ()
        assert hook(None, no_args) == no_args

    def test_adds_scaled_vector_to_hidden(self) -> None:
        model = _ToyModel(n_layers=4, hidden_dim=8)
        control = np.ones(8, dtype=np.float32)
        with apply_control(model, control, layer_index=2, strength=2.5):
            out = _run_through_layer(model, 2, torch.zeros(1, 3, 8))
        # An all-ones vector at strength 2.5 shifts every element to 2.5.
        assert torch.allclose(out, torch.full_like(out, 2.5))

    def test_zero_strength_is_passthrough(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=4)
        control = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        original = torch.randn(2, 5, 4)
        with apply_control(model, control, layer_index=0, strength=0.0):
            out = _run_through_layer(model, 0, original)
        assert torch.allclose(out, original)

    def test_negative_strength_pushes_opposite(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=3)
        control = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        with apply_control(model, control, layer_index=0, strength=-1.5):
            out = _run_through_layer(model, 0, torch.zeros(1, 1, 3))
        assert torch.allclose(out, torch.tensor([[[-1.5, 0.0, 0.0]]]))

    def test_vector_broadcasts_across_batch_and_seq(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        control = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        with apply_control(model, control, layer_index=0, strength=1.0):
            out = _run_through_layer(model, 0, torch.zeros(3, 7, 4))  # batch=3, seq=7
        # Channel 0 gets the full shift at every (batch, seq) position;
        # the remaining channels stay untouched.
        assert torch.allclose(out[..., 0], torch.ones(3, 7))
        assert torch.allclose(out[..., 1:], torch.zeros(3, 7, 3))
88
89
class TestHookLifecycle:
    """The context manager must detach the hook on every exit path."""

    def test_hook_removed_on_clean_exit(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        with apply_control(model, np.ones(4, dtype=np.float32), layer_index=0, strength=1.0):
            pass
        # Once the context closes, the layer must be back to identity.
        hidden = torch.zeros(1, 1, 4)
        assert torch.allclose(_run_through_layer(model, 0, hidden), hidden)

    def test_hook_removed_on_exception(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        with pytest.raises(RuntimeError, match="boom"):
            with apply_control(model, np.ones(4, dtype=np.float32), layer_index=0, strength=1.0):
                raise RuntimeError("boom")
        # Detach must happen even when the body raises.
        hidden = torch.zeros(1, 1, 4)
        assert torch.allclose(_run_through_layer(model, 0, hidden), hidden)
110 assert torch.allclose(out, hidden)
111
112
class TestValidation:
    """Input checks, layer resolution, and wrapper traversal."""

    def test_non_1d_vector_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        bad = np.zeros((2, 4), dtype=np.float32)
        with pytest.raises(ControlApplyError, match="1D"):
            with apply_control(model, bad, layer_index=0):
                pass

    def test_non_finite_vector_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        bad = np.array([1.0, float("nan"), 0.0, 0.0], dtype=np.float32)
        with pytest.raises(ControlApplyError, match="non-finite"):
            with apply_control(model, bad, layer_index=0):
                pass

    def test_dim_mismatch_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        bad = np.zeros(8, dtype=np.float32)  # 8 != layer hidden_dim of 4
        with pytest.raises(ControlApplyError, match="hidden dim"):
            with apply_control(model, bad, layer_index=0):
                pass

    def test_out_of_bounds_layer_rejected(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=4)
        with pytest.raises(ControlApplyError, match="out of bounds"):
            with apply_control(model, np.ones(4, dtype=np.float32), layer_index=99):
                pass

    def test_negative_layer_index_works(self) -> None:
        # `-1` should select the last layer, matching list semantics.
        model = _ToyModel(n_layers=3, hidden_dim=4)
        control = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        with apply_control(model, control, layer_index=-1, strength=2.0):
            out = model.model.layers[-1](torch.zeros(1, 1, 4))
        assert out[0, 0, 0].item() == 2.0

    def test_model_without_layers_attribute_rejected(self) -> None:
        layerless = nn.Linear(4, 4)
        with pytest.raises(ControlApplyError, match="layers"):
            with apply_control(layerless, np.ones(4, dtype=np.float32), layer_index=0):
                pass

    def test_peft_wrapped_model_resolves_through_base_model(self) -> None:
        """`apply_control` must walk through `PeftModel.base_model.model`.

        PEFT wraps the HF model one extra hop down: the layers live at
        `peft_model.base_model.model.layers[i]` rather than
        `model.model.layers[i]`. The resolver must find them anyway;
        otherwise every PEFT-loaded inference session silently cannot
        apply a control vector.
        """

        class _FakeLoraWrapper(nn.Module):
            def __init__(self, hf_model: _ToyModel) -> None:
                super().__init__()
                # `PeftModel.base_model` is a LoraModel; its `.model`
                # points at the underlying HF model. Mirror that shape.
                self.model = hf_model

        class _FakePeftModel(nn.Module):
            def __init__(self, hf_model: _ToyModel) -> None:
                super().__init__()
                self.base_model = _FakeLoraWrapper(hf_model)

        hf_model = _ToyModel(n_layers=4, hidden_dim=8)
        peft_like = _FakePeftModel(hf_model)
        with apply_control(peft_like, np.ones(8, dtype=np.float32), layer_index=2, strength=1.0):
            out = _run_through_layer(hf_model, 2, torch.zeros(1, 3, 8))
        assert torch.allclose(out, torch.ones_like(out))

    def test_falls_through_sparse_projection_paths(self) -> None:
        # Layer whose `self_attn.q_proj` exists but has no weight; the
        # usable projection lives at `attn.qkv_proj` instead.
        class _SparseProjLayer(nn.Module):
            def __init__(self, hidden_dim: int) -> None:
                super().__init__()
                self.self_attn = nn.Module()
                self.self_attn.q_proj = nn.Module()
                self.attn = nn.Module()
                self.attn.qkv_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

            def forward(self, hidden: torch.Tensor) -> torch.Tensor:
                return hidden

        class _SparseProjModel(nn.Module):
            def __init__(self, hidden_dim: int) -> None:
                super().__init__()
                self.model = nn.Module()
                self.model.layers = nn.ModuleList([_SparseProjLayer(hidden_dim)])

        model = _SparseProjModel(hidden_dim=4)
        control = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)

        with apply_control(model, control, layer_index=0, strength=1.0):
            out = model.model.layers[0](torch.zeros(1, 1, 4))

        assert out[0, 0, 0].item() == 1.0