Python · 6692 bytes Raw Blame History
1 """Attach a control vector to a model at inference time.
2
3 The apply path is a thin `forward_pre_hook` over one residual-
4 stream layer. On every forward pass, the hook adds
5 `strength * vector` to each token's hidden state. No weight
6 changes, no retraining — the steering is purely a forward-time
7 perturbation, which is why extraction takes seconds instead of
8 hours.
9
10 Strength semantics: positive pushes toward the `chosen`
11 distribution the vector was extracted against; negative pushes
12 away. Typical range is `[-2, 2]`; beyond `±3` the model tends
13 to collapse into repetition or nonsense.
14
15 Usage:
16 ```python
17 from dlm.control import apply_control, extract_control_vector
18
19 vec = extract_control_vector(chosen, rejected)
20 with apply_control(model, vec.direction, layer_index=12, strength=1.5):
21 out = model.generate(...)
22 ```
23
24 The context manager guarantees the hook is removed on exit,
25 even when the wrapped block raises — we can't leave a stray
26 hook on the model, because subsequent unrelated forward passes
27 would silently keep steering.
28 """
29
30 from __future__ import annotations
31
32 from collections.abc import Iterator
33 from contextlib import contextmanager
34 from typing import TYPE_CHECKING, Any
35
36 import numpy as np
37
38 from dlm.control.errors import ControlApplyError
39
40 if TYPE_CHECKING:
41 import torch
42
43
44 def _resolve_layer(model: Any, layer_index: int) -> Any:
45 """Locate the residual-stream module at `layer_index`.
46
47 HF decoder-only models expose `model.model.layers[i]` (Llama,
48 Qwen, SmolLM, Phi — the canonical path). PEFT wraps that under a
49 `base_model.model.model.layers[i]` chain (`PeftModel.base_model`
50 is a `LoraModel` whose `.model` is the HF model). Rather than
51 hard-code two shapes, we walk down repeated `base_model` / `model`
52 hops — the first node that exposes a `layers` attribute wins.
53 `layer_index` can be negative (`-1` is the last layer).
54
55 Raises `ControlApplyError` when the walker can't find a `layers`
56 attribute anywhere down the chain or when the index is out of
57 bounds.
58 """
59 layers = None
60 node: Any = model
61 # Cap the walk so a pathological graph can't spin forever; real
62 # wrappers are at most 2-3 deep (PEFT adds 2, a rare Accelerate
63 # wrapper adds 1).
64 for _ in range(6):
65 layers = getattr(node, "layers", None)
66 if layers is not None:
67 break
68 next_node = getattr(node, "base_model", None) or getattr(node, "model", None)
69 if next_node is None or next_node is node:
70 break
71 node = next_node
72
73 if layers is None:
74 raise ControlApplyError(
75 "model exposes no `layers` attribute along the "
76 "`base_model` / `model` chain — don't know where to "
77 "attach the forward hook. Pass a HF decoder-only model "
78 "(or a PEFT-wrapped one), or file an issue with the model "
79 "class for wiring."
80 )
81 try:
82 return layers[layer_index]
83 except (IndexError, TypeError) as exc:
84 raise ControlApplyError(
85 f"layer_index={layer_index} out of bounds for a {len(layers)}-layer model"
86 ) from exc
87
88
89 def _make_hook(vector: torch.Tensor, strength: float) -> Any:
90 """Build a `forward_pre_hook` that adds `strength * vector` to inputs.
91
92 The hook receives `(module, args)` where `args[0]` is the
93 hidden-state tensor of shape `(batch, seq, hidden_dim)`. We
94 broadcast the vector across the `batch` and `seq` axes — the
95 same steering direction applies to every token position, which
96 is the canonical control-vector interpretation (steer the entire
97 generation, not one token).
98
99 Returns the new args tuple with the perturbed hidden state in
100 position 0. HF layers accept positional args for the hidden
101 state; kwargs flow through untouched.
102 """
103
104 def _hook(_module: Any, args: tuple[Any, ...]) -> tuple[Any, ...]:
105 if not args:
106 return args
107 hidden = args[0]
108 # Move/cast vector to match hidden's device + dtype.
109 steer = vector.to(device=hidden.device, dtype=hidden.dtype)
110 perturbed = hidden + strength * steer
111 return (perturbed, *args[1:])
112
113 return _hook
114
115
def _probe_hidden_dim(layer: Any) -> int | None:
    """Best-effort read of `layer`'s hidden width.

    Different architectures put the attention input-projection under
    different names, so we try the common submodule names and the
    common projection names. Linear weights are
    `(out_features, in_features)` — the input side is the
    residual-stream hidden dim. Returns `None` when nothing matches;
    callers then skip dim validation rather than fail.
    """
    for attn_attr in ("self_attn", "attention", "attn"):
        attn = getattr(layer, attn_attr, None)
        if attn is None:
            continue
        for proj_attr in ("q_proj", "qkv_proj"):
            weight = getattr(getattr(attn, proj_attr, None), "weight", None)
            if weight is not None:
                return int(weight.shape[-1])
    return None


@contextmanager
def apply_control(
    model: Any,
    vector: np.ndarray,
    *,
    layer_index: int,
    strength: float = 1.0,
) -> Iterator[Any]:
    """Attach `strength * vector` to the residual stream at `layer_index`.

    Yields the model for use inside a `with` block. On exit — whether
    clean or via exception — the forward hook is removed. No
    weights change; the effect is forward-pass-only.

    Raises `ControlApplyError` on shape mismatch, non-finite values
    (in the vector or in `strength`), or invalid layer index. All
    validation happens up front, not inside the hook, so a malformed
    input fails before any compute burns.

    `vector` is accepted as NumPy (the storage format) or any
    array-like coercible via `np.asarray` — dtype matching to the
    model's hidden state happens inside the hook, so a float32 vector
    can steer a bf16 model without explicit casting by the caller.
    """
    import torch  # deferred — apply is runtime-only

    # No-op for an ndarray input (the documented path); lets lists,
    # tuples, and other array-likes through as a convenience.
    vector = np.asarray(vector)
    if vector.ndim != 1:
        raise ControlApplyError(
            f"control vector must be 1D (hidden_dim,), got shape {vector.shape}"
        )
    if not np.isfinite(vector).all():
        raise ControlApplyError("control vector contains non-finite values")
    # A NaN/inf strength would silently poison every forward pass the
    # hook touches — reject it here instead.
    if not np.isfinite(strength):
        raise ControlApplyError(f"strength must be finite, got {strength}")

    target_layer = _resolve_layer(model, layer_index)
    # Validate vector length against a weight the layer actually owns.
    expected_dim = _probe_hidden_dim(target_layer)
    if expected_dim is not None and vector.shape[0] != expected_dim:
        raise ControlApplyError(
            f"control vector dim {vector.shape[0]} does not match model "
            f"hidden dim {expected_dim} at layer {layer_index}"
        )

    vec_tensor = torch.from_numpy(np.ascontiguousarray(vector))
    hook = _make_hook(vec_tensor, strength)
    handle = target_layer.register_forward_pre_hook(hook)
    try:
        yield model
    finally:
        # Always detach — a stray hook would silently keep steering
        # subsequent unrelated forward passes.
        handle.remove()