"""`apply_control` — forward_pre_hook attach/detach + arithmetic."""

from __future__ import annotations

import numpy as np
import pytest
import torch
from torch import nn

from dlm.control import ControlApplyError, apply_control
from dlm.control.apply import _make_hook
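
# The contract the arithmetic tests below pin down (a sketch, not the real
# implementation; the actual `_make_hook` lives in `dlm/control/apply.py`):
# the returned forward_pre_hook presumably looks roughly like
#
#     def hook(module, args):
#         if not args:
#             return args
#         return (args[0] + strength * vector, *args[1:])
#
# i.e. `strength * vector` is added to the incoming hidden state,
# broadcasting over the batch and sequence dimensions.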


class _ToyLayer(nn.Module):
    """A stand-in for an HF decoder layer.

    Owns `self_attn.q_proj` so the dim-validation path has something
    to inspect; the forward just passes the hidden state through so
    the hook's perturbation is visible on the output.
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.self_attn = nn.Module()
        # `nn.Linear(hidden_dim, hidden_dim)` weight shape is
        # `(out, in)` — the `[-1]` the apply path reads is `in = hidden_dim`.
        self.self_attn.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden


class _ToyModel(nn.Module):
    """Minimal HF-shaped wrapper: `model.model.layers[i]`."""

    def __init__(self, n_layers: int, hidden_dim: int) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.layers = nn.ModuleList([_ToyLayer(hidden_dim) for _ in range(n_layers)])


def _run_through_layer(model: _ToyModel, layer_index: int, hidden: torch.Tensor) -> torch.Tensor:
    return model.model.layers[layer_index](hidden)


class TestHookArithmetic:
    def test_hook_with_no_args_returns_original_tuple(self) -> None:
        hook = _make_hook(torch.ones(4), 1.0)
        empty: tuple[object, ...] = ()

        assert hook(None, empty) == empty

    def test_adds_scaled_vector_to_hidden(self) -> None:
        model = _ToyModel(n_layers=4, hidden_dim=8)
        vector = np.ones(8, dtype=np.float32)
        hidden = torch.zeros(1, 3, 8)
        with apply_control(model, vector, layer_index=2, strength=2.5):
            out = _run_through_layer(model, 2, hidden)
        # strength=2.5, vector=[1,1,1,...] → each output element is 2.5.
        assert torch.allclose(out, torch.full_like(out, 2.5))

    def test_zero_strength_is_passthrough(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=4)
        vector = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        hidden = torch.randn(2, 5, 4)
        with apply_control(model, vector, layer_index=0, strength=0.0):
            out = _run_through_layer(model, 0, hidden)
        assert torch.allclose(out, hidden)

    def test_negative_strength_pushes_opposite(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=3)
        vector = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        hidden = torch.zeros(1, 1, 3)
        with apply_control(model, vector, layer_index=0, strength=-1.5):
            out = _run_through_layer(model, 0, hidden)
        expected = torch.tensor([[[-1.5, 0.0, 0.0]]])
        assert torch.allclose(out, expected)

    def test_vector_broadcasts_across_batch_and_seq(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        hidden = torch.zeros(3, 7, 4)  # batch=3, seq=7
        with apply_control(model, vector, layer_index=0, strength=1.0):
            out = _run_through_layer(model, 0, hidden)
        # Every (batch, seq, 0) is 1.0; every (batch, seq, 1..3) is 0.
        assert torch.allclose(out[..., 0], torch.ones(3, 7))
        assert torch.allclose(out[..., 1:], torch.zeros(3, 7, 3))
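

# Lifecycle contract exercised below (a sketch, assuming `apply_control` is a
# context manager over PyTorch's hook-handle API; the real code lives in
# `dlm/control/apply.py`):
#
#     handle = layer.register_forward_pre_hook(_make_hook(tensor, strength))
#     try:
#         yield
#     finally:
#         handle.remove()
#
# The `finally` is what the exception test guards: the hook must be detached
# on every exit path, clean or raising.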


class TestHookLifecycle:
    def test_hook_removed_on_clean_exit(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.ones(4, dtype=np.float32)
        hidden = torch.zeros(1, 1, 4)
        with apply_control(model, vector, layer_index=0, strength=1.0):
            pass
        # After the block, the layer should not perturb anymore.
        out = _run_through_layer(model, 0, hidden)
        assert torch.allclose(out, hidden)

    def test_hook_removed_on_exception(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.ones(4, dtype=np.float32)
        hidden = torch.zeros(1, 1, 4)
        with pytest.raises(RuntimeError, match="boom"):
            with apply_control(model, vector, layer_index=0, strength=1.0):
                raise RuntimeError("boom")
        # Hook must be gone even after an exception.
        out = _run_through_layer(model, 0, hidden)
        assert torch.allclose(out, hidden)
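
# Validation sketch (an assumption about the apply path, consistent with the
# `_ToyLayer` comment up top; the real checks live in `dlm/control/apply.py`):
# the hidden size is presumably read off a projection weight, roughly
#
#     hidden_dim = layer.self_attn.q_proj.weight.shape[-1]
#     if vector.shape[0] != hidden_dim:
#         raise ControlApplyError(...)
#
# alongside shape/finiteness checks on the vector and bounds checks on the
# layer index, all surfacing as ControlApplyError when the context is entered.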


class TestValidation:
    def test_non_1d_vector_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.zeros((2, 4), dtype=np.float32)
        with pytest.raises(ControlApplyError, match="1D"):
            with apply_control(model, vector, layer_index=0):
                pass

    def test_non_finite_vector_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.array([1.0, float("nan"), 0.0, 0.0], dtype=np.float32)
        with pytest.raises(ControlApplyError, match="non-finite"):
            with apply_control(model, vector, layer_index=0):
                pass

    def test_dim_mismatch_rejected(self) -> None:
        model = _ToyModel(n_layers=1, hidden_dim=4)
        vector = np.zeros(8, dtype=np.float32)  # wrong dim
        with pytest.raises(ControlApplyError, match="hidden dim"):
            with apply_control(model, vector, layer_index=0):
                pass

    def test_out_of_bounds_layer_rejected(self) -> None:
        model = _ToyModel(n_layers=2, hidden_dim=4)
        vector = np.ones(4, dtype=np.float32)
        with pytest.raises(ControlApplyError, match="out of bounds"):
            with apply_control(model, vector, layer_index=99):
                pass

    def test_negative_layer_index_works(self) -> None:
        # `-1` should select the last layer, matching list semantics.
        model = _ToyModel(n_layers=3, hidden_dim=4)
        vector = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        hidden = torch.zeros(1, 1, 4)
        with apply_control(model, vector, layer_index=-1, strength=2.0):
            out = model.model.layers[-1](hidden)
        assert out[0, 0, 0].item() == 2.0

    def test_model_without_layers_attribute_rejected(self) -> None:
        bare = nn.Linear(4, 4)
        vector = np.ones(4, dtype=np.float32)
        with pytest.raises(ControlApplyError, match="layers"):
            with apply_control(bare, vector, layer_index=0):
                pass

    def test_peft_wrapped_model_resolves_through_base_model(self) -> None:
        """`apply_control` must walk through `PeftModel.base_model.model`.

        PEFT wraps the HF model one extra hop down: the layers live at
        `peft_model.base_model.model.model.layers[i]` rather than
        `model.model.layers[i]`. The resolver must find them anyway;
        otherwise every PEFT-loaded inference session silently cannot
        apply a control vector.
        """

        class _FakeLoraWrapper(nn.Module):
            def __init__(self, hf_model: _ToyModel) -> None:
                super().__init__()
                # `PeftModel.base_model` is a LoraModel; `.model` points
                # at the underlying HF model. Mirror that shape here.
                self.model = hf_model

        class _FakePeftModel(nn.Module):
            def __init__(self, hf_model: _ToyModel) -> None:
                super().__init__()
                self.base_model = _FakeLoraWrapper(hf_model)

        inner = _ToyModel(n_layers=4, hidden_dim=8)
        wrapped = _FakePeftModel(inner)
        vector = np.ones(8, dtype=np.float32)
        hidden = torch.zeros(1, 3, 8)
        with apply_control(wrapped, vector, layer_index=2, strength=1.0):
            out = _run_through_layer(inner, 2, hidden)
        assert torch.allclose(out, torch.ones_like(out))

    def test_falls_through_sparse_projection_paths(self) -> None:
        class _SparseProjLayer(nn.Module):
            def __init__(self, hidden_dim: int) -> None:
                super().__init__()
                self.self_attn = nn.Module()
                self.self_attn.q_proj = nn.Module()
                self.attn = nn.Module()
                self.attn.qkv_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

            def forward(self, hidden: torch.Tensor) -> torch.Tensor:
                return hidden

        class _SparseProjModel(nn.Module):
            def __init__(self, hidden_dim: int) -> None:
                super().__init__()
                self.model = nn.Module()
                self.model.layers = nn.ModuleList([_SparseProjLayer(hidden_dim)])

        model = _SparseProjModel(hidden_dim=4)
        vector = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        hidden = torch.zeros(1, 1, 4)

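        # `self_attn.q_proj` here is a bare nn.Module with no `.weight`, so
        # the dim check cannot read a hidden size from it. The behaviour this
        # test encodes (an assumption about the apply path): candidate
        # projection attributes, e.g. `self_attn.q_proj` then `attn.qkv_proj`,
        # are probed in order and the first one carrying a weight wins,
        # rather than the first miss raising an error.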
        with apply_control(model, vector, layer_index=0, strength=1.0):
            out = model.model.layers[0](hidden)

        assert out[0, 0, 0].item() == 1.0