"""Extract a steering direction from paired hidden states.

The math: given N preference pairs, each mapped through the base
model to `(hidden_chosen_i, hidden_rejected_i)` at some residual-stream
layer, the difference `d_i = hidden_chosen_i - hidden_rejected_i` is a
"pull toward chosen" vector for that example. The leading right-singular
vector of the stack of differences (its uncentered first principal
component) is the direction these pulls agree on — the steering
direction that captures the preference shared across examples.

We compute the raw (uncentered) SVD of the difference stack, matching
the "Steering Llama" literature (Panickssery et al.). When every pair
agrees, the leading singular vector is the common direction; when
pairs disagree, it is the direction maximizing the sum of squared
projections. The sign is fixed by aligning with the mean pull, so
extraction is reproducible across numpy versions; extraction is
refused when the principal direction is too close to orthogonal to the
mean pull for that alignment to be meaningful. A single example's
direction collapses to its own difference, normalized — the expected
limit case.

The unit-test path takes synthetic NumPy arrays; no HF model is
needed. Wiring a real base model's forward hooks to produce the
hidden states is a later-sprint concern.
"""
| 24 |
|
| 25 |
from __future__ import annotations |
| 26 |
|
| 27 |
from collections.abc import Iterable, Mapping |
| 28 |
from dataclasses import dataclass |
| 29 |
|
| 30 |
import numpy as np |
| 31 |
|
| 32 |
from dlm.control.errors import ControlExtractError, ControlPolicyRefusal |
| 33 |
|
| 34 |
# Lower bound on |cos| between the principal SVD direction and the mean
# of the chosen-minus-rejected differences. The extracted direction's
# sign is picked by agreeing with that mean pull, which is only a real
# decision when the two are at least loosely aligned. Below 0.1 the
# mean pull is close to orthogonal, and the sign would be settled by
# noise rather than by the data.
_SIGN_ALIGN_THRESHOLD: float = 0.1
| 39 |
|
| 40 |
|
| 41 |
@dataclass(frozen=True)
class ControlVector:
    """A single extracted steering direction.

    `direction` is a unit vector of length `hidden_dim`.
    `n_pairs` lets callers reconstruct how many examples fed the
    extraction when rendering audit output. `explained_variance`
    is the leading singular value squared over the sum of all squared
    singular values (equivalently, over the total squared norm of the
    difference stack). A 1.0 reading means every pair's difference
    lies along a single line — the pulls may still differ in magnitude,
    or even in sign. A reading of 0.25 means the principal component
    explains a quarter of the preference spread; the rest is noise or
    contradictory pairs.

    Equality compares `direction` element-wise (via `np.array_equal`)
    together with the scalar fields. `frozen=True` blocks reassigning
    attributes but does not make the `direction` array read-only, and
    instances stay unhashable because an ndarray field cannot be hashed.
    """

    direction: np.ndarray
    n_pairs: int
    explained_variance: float

    def __eq__(self, other: object) -> bool:
        # The dataclass-generated __eq__ compares field tuples, which
        # evaluates `ndarray == ndarray` and then calls bool() on the
        # element-wise result — a ValueError for any non-trivial array.
        # Compare the array explicitly instead.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.n_pairs == other.n_pairs
            and self.explained_variance == other.explained_variance
            and np.array_equal(self.direction, other.direction)
        )
| 57 |
|
| 58 |
|
| 59 |
def extract_control_vector(
    hidden_chosen: np.ndarray,
    hidden_rejected: np.ndarray,
) -> ControlVector:
    """Compute the steering direction from paired hidden states.

    `hidden_chosen` / `hidden_rejected` are `(N, hidden_dim)` numeric
    arrays of hidden states at one residual-stream layer (anything
    `np.asarray` accepts). They are upcast to float64 for the math.
    The output `direction` is a float32 unit vector oriented so that
    positive strength pushes toward `chosen`.

    Raises `ControlExtractError` on:
      - mismatched shapes, or inputs that are not 2-D
      - `N < 1`
      - non-finite entries (NaN hidden states from a bad forward pass)
      - zero-variance differences (every chosen identical to rejected →
        no signal to extract)
      - an undecidable sign: the principal direction is near-orthogonal
        to the mean pull, or the pulls cancel so that the mean pull is
        zero up to float round-off
    """
    hidden_chosen = np.asarray(hidden_chosen)
    hidden_rejected = np.asarray(hidden_rejected)

    if hidden_chosen.shape != hidden_rejected.shape:
        raise ControlExtractError(
            f"chosen/rejected shape mismatch: {hidden_chosen.shape} vs {hidden_rejected.shape}"
        )
    if hidden_chosen.ndim != 2:
        raise ControlExtractError(f"expected 2D (N, hidden_dim) arrays, got {hidden_chosen.ndim}D")
    if hidden_chosen.shape[0] < 1:
        raise ControlExtractError("need at least one (chosen, rejected) pair")
    if not (np.isfinite(hidden_chosen).all() and np.isfinite(hidden_rejected).all()):
        raise ControlExtractError("hidden states contain non-finite values")

    diffs = hidden_chosen.astype(np.float64) - hidden_rejected.astype(np.float64)
    n_pairs = int(diffs.shape[0])

    # Single-pair limit case: the direction is just that pair's pull,
    # normalized, and it trivially explains all of the energy. SVD would
    # reach the same answer; the short-circuit skips the factorization
    # and the sign bookkeeping.
    if n_pairs == 1:
        norm = float(np.linalg.norm(diffs[0]))
        if norm == 0.0:
            raise ControlExtractError("single pair has zero chosen/rejected difference")
        return ControlVector(
            direction=(diffs[0] / norm).astype(np.float32),
            n_pairs=1,
            explained_variance=1.0,
        )

    # Raw (uncentered) SVD on the difference stack — matches the
    # control-vector literature (Panickssery et al., "Steering
    # Llama"). Centering would wipe the signal when every pair
    # agrees exactly; uncentered, the principal component is the
    # direction maximizing the sum of squared projections, which
    # coincides with the mean pull when all diffs align and tracks
    # the dominant direction otherwise.
    total_energy = float(np.sum(diffs**2))
    if total_energy == 0.0:
        raise ControlExtractError(
            "zero chosen/rejected differences across all pairs — no signal to extract"
        )

    # Thin SVD: full_matrices=False so we don't allocate an
    # (N, N) left matrix we never use.
    _u, singular_values, vh = np.linalg.svd(diffs, full_matrices=False)
    # Principal direction is the first right-singular vector (unit length).
    direction = vh[0]

    # Orient so that the direction points *toward* chosen. SVD is unique
    # only up to sign, so without this the output sign could differ
    # across numpy versions. Convention: positive strength pushes toward
    # chosen, so align with mean(diffs).
    #
    # Threshold check: when the principal direction is near-orthogonal
    # to the mean pull, |cos(angle)| is near zero and the sign decision
    # is dominated by numerical noise — two runs on slightly different
    # data would produce opposite signs. That's a meaningless vector,
    # not a steering direction; reject rather than ship a coin-flip.
    mean_pull = diffs.mean(axis=0)
    mean_pull_norm = float(np.linalg.norm(mean_pull))
    dot = float(np.dot(direction, mean_pull))

    # Pulls that cancel (e.g. d and -d) leave a mean that is exactly zero
    # or only float round-off — on the order of N * eps * the largest
    # pull. Such a mean carries no orientation at all, so it counts as
    # |cos| = 0 and is rejected. Without this, the zero-mean case would
    # bypass the check and keep whatever sign the SVD happened to return.
    largest_pull = float(np.linalg.norm(diffs, axis=1).max())
    noise_floor = n_pairs * float(np.finfo(np.float64).eps) * largest_pull
    cos_align = abs(dot) / mean_pull_norm if mean_pull_norm > noise_floor else 0.0

    if cos_align < _SIGN_ALIGN_THRESHOLD:
        raise ControlExtractError(
            "principal SVD direction is near-orthogonal to mean(diffs) "
            f"(|cos|={cos_align:.3f} < {_SIGN_ALIGN_THRESHOLD}). "
            "This means the preference pairs disagree enough that the "
            "sign of the extracted direction is unstable — two runs on "
            "similar data would emit opposite vectors. Gather more "
            "coherent pairs or check that chosen/rejected are not swapped."
        )
    if dot < 0:
        direction = -direction

    # Fraction of the uncentered energy captured by the leading component;
    # the squared singular values sum to total_energy.
    explained = float(singular_values[0] ** 2 / total_energy)
    return ControlVector(
        direction=direction.astype(np.float32),
        n_pairs=n_pairs,
        explained_variance=explained,
    )
| 156 |
|
| 157 |
|
| 158 |
# Section-tag key/value pair that marks content as safety policy.
# refuse_if_policy_safety looks up this key in each section's tags and
# refuses extraction when the value matches.
_POLICY_TAG_KEY = "policy"
_SAFETY_POLICY_VALUE = "safety"
| 160 |
|
| 161 |
|
| 162 |
def refuse_if_policy_safety(
    section_tags: Iterable[Mapping[str, str]],
) -> None:
    """Raise `ControlPolicyRefusal` if any section is tagged `policy: safety`.

    Extracting from safety-flagged preference pairs would yield a
    "more safety vs. less safety" direction by construction; steering
    with it at negative strength would wear down exactly the behavior
    the document is trying to preserve, so that option is not offered.
    The check is meant to run at extraction entry so such an artifact
    never reaches disk.

    `section_tags` is a flat iterable of per-section tag mappings, so
    callers may gather sections from whatever source they like
    (preference sections, a mix of types, ...). Scanning stops at the
    first flagged section; otherwise each mapping is inspected once
    and the function returns None. That linear pass is negligible next
    to an HF forward pass.
    """
    flagged = any(
        tags.get(_POLICY_TAG_KEY) == _SAFETY_POLICY_VALUE for tags in section_tags
    )
    if not flagged:
        return
    raise ControlPolicyRefusal(
        "refusing to extract a control vector from preference "
        "sections tagged `policy: safety` — the resulting steering "
        "direction could be used at negative strength to undo the "
        "safety training the document is trying to preserve. "
        "Extract separately from non-safety preferences instead."
    )