Python · 8254 bytes Raw Blame History
1 """Extract a steering direction from paired hidden states.
2
3 The math: given N preference pairs, each mapped through the base
4 model to `(hidden_chosen_i, hidden_rejected_i)` at some residual-
5 stream layer, the difference `d_i = hidden_chosen_i - hidden_rejected_i`
6 is a "pull toward chosen" vector for that example. The first
7 principal component of the stack of differences is the direction
8 these pulls agree on — the steering direction that captures the
9 preference shared across examples.
10
11 We compute the raw (uncentered) SVD of the difference stack —
12 matching the "Steering Llama" literature (Panickssery et al.).
13 When every pair agrees, the principal component is the common
14 direction; when pairs disagree, it's the direction maximizing
15 the sum of squared projections. Sign is fixed by aligning with
16 the mean pull, so extraction is reproducible across numpy
17 versions. A single example's direction collapses to itself
18 normalized — the expected limit case.
19
20 The unit-test path takes synthetic NumPy arrays; no HF model
21 needed. Wiring a real base model's forward hooks to produce the
22 hidden states is a later-sprint concern.
23 """
24
25 from __future__ import annotations
26
27 from collections.abc import Iterable, Mapping
28 from dataclasses import dataclass
29
30 import numpy as np
31
32 from dlm.control.errors import ControlExtractError, ControlPolicyRefusal
33
34 # Minimum |cos(angle)| between principal SVD direction and mean(diffs)
35 # for the sign-alignment step to be meaningful. 0.1 is a practical
36 # "loosely aligned" floor — below it, the mean pull is essentially
37 # orthogonal to the direction and the sign flip is dominated by noise.
38 _SIGN_ALIGN_THRESHOLD: float = 0.1
39
40
@dataclass(frozen=True)
class ControlVector:
    """A single extracted steering direction.

    `direction` is a unit vector of length `hidden_dim`
    (`extract_control_vector` emits it as float32).

    `n_pairs` lets callers reconstruct how many examples fed the
    extraction when rendering audit output.

    `explained_variance` is the leading singular value squared over
    the total squared norm of the (uncentered) difference stack. That
    is the fraction of the preference "energy" captured by
    `direction`, not a variance about the mean.
    - A 1.0 reading means every chosen-minus-rejected difference lies
      on a single line. Differences pointing the opposite way along
      that line also count, because only the axis matters here.
    - 0.25 means the principal component explains a quarter of the
      preference spread; the rest is noise or contradictory pairs.

    NOTE: `direction` is an ndarray, so the dataclass-generated
    `__eq__` / `__hash__` cannot handle it. Comparing two instances
    holding multi-element arrays with `==` is ambiguous, and arrays
    are unhashable. Compare fields explicitly instead, e.g. with
    `np.array_equal`.
    """

    # Unit-norm steering direction, shape (hidden_dim,).
    direction: np.ndarray
    # Number of (chosen, rejected) pairs the direction was extracted from.
    n_pairs: int
    # Fraction of total squared difference norm captured by `direction`.
    explained_variance: float
57
58
def extract_control_vector(
    hidden_chosen: np.ndarray,
    hidden_rejected: np.ndarray,
) -> ControlVector:
    """Compute the steering direction from paired hidden states.

    `hidden_chosen` / `hidden_rejected` are `(N, hidden_dim)` float
    arrays of hidden states at one residual-stream layer. Both are
    upcast to float64 before differencing. The output `direction` is
    a float32 unit vector oriented so that positive strength pushes
    toward `chosen`.

    Raises `ControlExtractError` on:
    - mismatched shapes
    - arrays that are not 2D
    - `N < 1`
    - non-finite entries (NaN hidden states from a bad forward pass)
    - zero-variance differences (every chosen identical to rejected →
      no signal to extract)
    - an unorientable direction, in either of two cases:
      - the principal direction is near-orthogonal to mean(diffs)
        (|cos| below `_SIGN_ALIGN_THRESHOLD`);
      - the pulls cancel, so mean(diffs) is zero up to rounding
        (e.g. pairs `d` and `-d`).
    """
    if hidden_chosen.shape != hidden_rejected.shape:
        raise ControlExtractError(
            f"chosen/rejected shape mismatch: {hidden_chosen.shape} vs {hidden_rejected.shape}"
        )
    if hidden_chosen.ndim != 2:
        raise ControlExtractError(f"expected 2D (N, hidden_dim) arrays, got {hidden_chosen.ndim}D")
    if hidden_chosen.shape[0] < 1:
        raise ControlExtractError("need at least one (chosen, rejected) pair")
    if not (np.isfinite(hidden_chosen).all() and np.isfinite(hidden_rejected).all()):
        raise ControlExtractError("hidden states contain non-finite values")

    diffs = hidden_chosen.astype(np.float64) - hidden_rejected.astype(np.float64)
    # Single-pair limit case: the direction is just that pair,
    # normalized. SVD on one row works, but this short-circuit keeps
    # the explained-variance denominator well-defined (it's 1.0 by
    # definition when there's only one component to explain).
    if diffs.shape[0] == 1:
        norm = float(np.linalg.norm(diffs[0]))
        if norm == 0.0:
            raise ControlExtractError("single pair has zero chosen/rejected difference")
        return ControlVector(
            direction=(diffs[0] / norm).astype(np.float32),
            n_pairs=1,
            explained_variance=1.0,
        )

    # Raw (uncentered) SVD on the difference stack — matches the
    # control-vector literature (Panickssery et al., "Steering
    # Llama"). Centering would wipe the signal when every pair
    # agrees exactly. Uncentered, the principal component is the
    # direction maximizing the sum of squared projections. It
    # coincides with the mean pull when all diffs align and tracks
    # the dominant direction otherwise.
    total_energy = float(np.sum(diffs**2))
    if total_energy == 0.0:
        raise ControlExtractError(
            "zero chosen/rejected differences across all pairs — no signal to extract"
        )

    # Thin SVD: full_matrices=False so we don't allocate an
    # (N, N) left matrix we never use.
    _u, singular_values, vh = np.linalg.svd(diffs, full_matrices=False)
    # Principal direction is the first right-singular vector.
    direction = vh[0]

    # Orient so that the direction points *toward* chosen. SVD is
    # unique only up to sign, so without this the sign of the first
    # singular vector is arbitrary, which would make extraction
    # non-reproducible across numpy versions. Convention: positive
    # strength pushes toward chosen, so align with mean(diffs).
    #
    # Threshold check: when the principal direction is near-orthogonal
    # to the mean pull, |cos(angle)| is near zero and the sign decision
    # is dominated by numerical noise. Two runs on slightly different
    # data would then produce opposite signs. That's a meaningless
    # vector, not a steering direction; reject rather than ship a
    # coin-flip.
    mean_pull = diffs.mean(axis=0)
    mean_pull_norm = float(np.linalg.norm(mean_pull))
    dot = float(np.dot(direction, mean_pull))
    # When the pulls cancel (e.g. pairs d and -d), mean(diffs) is zero,
    # or only float rounding residue, and carries no orientation at all.
    # Treat that as |cos| = 0 so it is rejected by the same threshold
    # instead of slipping through with an arbitrary sign. The floor is
    # machine epsilon scaled by the Frobenius norm of the diff stack.
    cancel_floor = float(np.finfo(np.float64).eps) * float(np.sqrt(total_energy))
    if mean_pull_norm > cancel_floor:
        cos_align = abs(dot) / mean_pull_norm  # direction is already unit-length
    else:
        cos_align = 0.0
    if cos_align < _SIGN_ALIGN_THRESHOLD:
        raise ControlExtractError(
            "principal SVD direction is near-orthogonal to mean(diffs) "
            f"(|cos|={cos_align:.3f} < {_SIGN_ALIGN_THRESHOLD}). "
            "This means the preference pairs disagree enough that the "
            "sign of the extracted direction is unstable — two runs on "
            "similar data would emit opposite vectors. Gather more "
            "coherent pairs or check that chosen/rejected are not swapped."
        )
    if dot < 0:
        direction = -direction

    # Fraction of the total squared norm captured by the leading component.
    explained = float(singular_values[0] ** 2 / total_energy)
    return ControlVector(
        direction=direction.astype(np.float32),
        n_pairs=int(diffs.shape[0]),
        explained_variance=explained,
    )
156
157
_SAFETY_POLICY_VALUE = "safety"
_POLICY_TAG_KEY = "policy"


def refuse_if_policy_safety(
    section_tags: Iterable[Mapping[str, str]],
) -> None:
    """Refuse extraction when any source section carries `policy: safety`.

    A control vector over safety-flagged preference pairs would, by
    construction, be a "more safety vs less safety" steering
    direction. Applied at negative strength, it erodes the exact
    behavior the document is trying to preserve. We don't offer the
    user that footgun. This is meant to run at extraction entry, so
    that the artifact never reaches disk.

    Takes a flat iterable of per-section tag dicts so callers can
    pass whatever source their sections were collected from
    (preference sections, a mix of types, etc.). The iterable is
    consumed lazily and the scan stops at the first match, so the
    cost is at most linear in the section count.

    Matching is deliberately lenient because this is a guard: the
    tag value is compared after stripping surrounding whitespace and
    case-folding. So `Safety` or ` SAFETY ` are refused too, and a
    capitalisation difference is not a way around the check.
    Non-string values never match.

    Raises `ControlPolicyRefusal` on the first section whose policy
    tag matches; returns `None` otherwise.
    """
    for tags in section_tags:
        value = tags.get(_POLICY_TAG_KEY)
        if isinstance(value, str) and value.strip().casefold() == _SAFETY_POLICY_VALUE:
            raise ControlPolicyRefusal(
                "refusing to extract a control vector from preference "
                "sections tagged `policy: safety` — the resulting steering "
                "direction could be used at negative strength to undo the "
                "safety training the document is trying to preserve. "
                "Extract separately from non-safety preferences instead."
            )