feat(control): reject near-orthogonal sign alignment in extract_control_vector
- SHA: 61d382f6fe01f1689331bbc5190b21ce1eab440b
- Parents: d12f6a7
- Tree: eb5fa4d

| Status | File | + | - |
|---|---|---|---|
| M | src/dlm/control/extract.py | 26 | 1 |
src/dlm/control/extract.py — modified
@@ -31,6 +31,12 @@ import numpy as np | ||
| 31 | 31 | |
| 32 | 32 | from dlm.control.errors import ControlExtractError, ControlPolicyRefusal |
| 33 | 33 | |
| 34 | +# Minimum |cos(angle)| between principal SVD direction and mean(diffs) | |
| 35 | +# for the sign-alignment step to be meaningful. 0.1 is a practical | |
| 36 | +# "loosely aligned" floor — below it, the mean pull is essentially | |
| 37 | +# orthogonal to the direction and the sign flip is dominated by noise. | |
| 38 | +_SIGN_ALIGN_THRESHOLD: float = 0.1 | |
| 39 | + | |
| 34 | 40 | |
| 35 | 41 | @dataclass(frozen=True) |
| 36 | 42 | class ControlVector: |
@@ -118,8 +124,27 @@ def extract_control_vector( | ||
| 118 | 124 | # is unique only up to sign), which would make extraction |
| 119 | 125 | # non-reproducible across numpy versions. Convention: positive |
| 120 | 126 | # strength pushes toward chosen, so align with mean(diffs). |
| 127 | + # | |
| 128 | + # Threshold check: when the principal direction is near-orthogonal | |
| 129 | + # to the mean pull, |cos(angle)| is near zero and the sign decision | |
| 130 | + # is dominated by numerical noise — two runs on slightly different | |
| 131 | + # data would produce opposite signs. That's a meaningless vector, | |
| 132 | + # not a steering direction; reject rather than ship a coin-flip. | |
| 121 | 133 | mean_pull = diffs.mean(axis=0) |
| 122 | - if float(np.dot(direction, mean_pull)) < 0: | |
| 134 | + mean_pull_norm = float(np.linalg.norm(mean_pull)) | |
| 135 | + dot = float(np.dot(direction, mean_pull)) | |
| 136 | + if mean_pull_norm > 0.0: | |
| 137 | + cos_align = abs(dot) / mean_pull_norm # direction is already unit-length | |
| 138 | + if cos_align < _SIGN_ALIGN_THRESHOLD: | |
| 139 | + raise ControlExtractError( | |
| 140 | + "principal SVD direction is near-orthogonal to mean(diffs) " | |
| 141 | + f"(|cos|={cos_align:.3f} < {_SIGN_ALIGN_THRESHOLD}). " | |
| 142 | + "This means the preference pairs disagree enough that the " | |
| 143 | + "sign of the extracted direction is unstable — two runs on " | |
| 144 | + "similar data would emit opposite vectors. Gather more " | |
| 145 | + "coherent pairs or check that chosen/rejected are not swapped." | |
| 146 | + ) | |
| 147 | + if dot < 0: | |
| 123 | 148 | direction = -direction |
| 124 | 149 | |
| 125 | 150 | explained = float(singular_values[0] ** 2 / total_energy) |