tenseleyflow/documentlanguagemodel / 61d382f

Browse files

feat(control): reject near-orthogonal sign alignment in extract_control_vector

Authored by espadonne
SHA
61d382f6fe01f1689331bbc5190b21ce1eab440b
Parents
d12f6a7
Tree
eb5fa4d

1 changed file

Status File + -
M src/dlm/control/extract.py 26 1
src/dlm/control/extract.py (modified)
@@ -31,6 +31,12 @@ import numpy as np
3131
 
3232
 from dlm.control.errors import ControlExtractError, ControlPolicyRefusal
3333
 
34
+# Minimum |cos(angle)| between principal SVD direction and mean(diffs)
35
+# for the sign-alignment step to be meaningful. 0.1 is a practical
36
+# "loosely aligned" floor — below it, the mean pull is essentially
37
+# orthogonal to the direction and the sign flip is dominated by noise.
38
+_SIGN_ALIGN_THRESHOLD: float = 0.1
39
+
3440
 
3541
 @dataclass(frozen=True)
3642
 class ControlVector:
@@ -118,8 +124,27 @@ def extract_control_vector(
118124
     # is unique only up to sign), which would make extraction
119125
     # non-reproducible across numpy versions. Convention: positive
120126
     # strength pushes toward chosen, so align with mean(diffs).
127
+    #
128
+    # Threshold check: when the principal direction is near-orthogonal
129
+    # to the mean pull, |cos(angle)| is near zero and the sign decision
130
+    # is dominated by numerical noise — two runs on slightly different
131
+    # data would produce opposite signs. That's a meaningless vector,
132
+    # not a steering direction; reject rather than ship a coin-flip.
121133
     mean_pull = diffs.mean(axis=0)
122
-    if float(np.dot(direction, mean_pull)) < 0:
134
+    mean_pull_norm = float(np.linalg.norm(mean_pull))
135
+    dot = float(np.dot(direction, mean_pull))
136
+    if mean_pull_norm > 0.0:
137
+        cos_align = abs(dot) / mean_pull_norm  # direction is already unit-length
138
+        if cos_align < _SIGN_ALIGN_THRESHOLD:
139
+            raise ControlExtractError(
140
+                "principal SVD direction is near-orthogonal to mean(diffs) "
141
+                f"(|cos|={cos_align:.3f} < {_SIGN_ALIGN_THRESHOLD}). "
142
+                "This means the preference pairs disagree enough that the "
143
+                "sign of the extracted direction is unstable — two runs on "
144
+                "similar data would emit opposite vectors. Gather more "
145
+                "coherent pairs or check that chosen/rejected are not swapped."
146
+            )
147
+    if dot < 0:
123148
         direction = -direction
124149
 
125150
     explained = float(singular_values[0] ** 2 / total_energy)