tenseleyflow/sway / 6c8cd81

Browse files

tests/adapter_ablation: probe-level saturation reason coverage (C8, B3 test side)

Authored by espadonne
SHA
6c8cd8155f698702a443a74a24240e53cbb7425b
Parents
4346b16
Tree
a5b03a2

1 changed file

Status File + -
M tests/unit/test_probe_adapter_ablation.py 80 0
tests/unit/test_probe_adapter_ablation.py modified
@@ -7,6 +7,7 @@ the full probe path without loading a real model.
77
 from __future__ import annotations
88
 
99
 import numpy as np
10
+import pytest
1011
 
1112
 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
1213
 from dlm_sway.core.result import Verdict
@@ -90,6 +91,85 @@ class TestShapeMetrics:
9091
         assert sat is None
9192
         assert reason == "flat_curve"
9293
 
94
+    def test_saturation_monotonically_decreasing(self) -> None:
95
+        """A curve that goes *down* with λ — adapter is anti-correlated
96
+        with its own effect (the degenerate case where lam=1 is closer
97
+        to base than lam=0)."""
98
+        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
99
+        divs = np.asarray([0.9, 0.6, 0.3, 0.1], dtype=np.float64)
100
+        sat, reason = _saturation_lambda(lambdas, divs)
101
+        # max is at λ=0 → smallest λ where divs ≥ 0.9*0.9=0.81 is λ=0 itself.
102
+        assert sat == 0.0
103
+        # Curve is strictly decreasing; the monotonic-non-decreasing check
104
+        # on divs[:0+1] is trivially satisfied (length 1), so we classify
105
+        # as "found" — the probe's overshoot / linearity checks pick up
106
+        # the real pathology here.
107
+        assert reason == "found"
108
+
109
+
110
+class TestProbeVerdictPropagatesSaturationReason:
111
+    """C8 + B3 test side: the probe's ``evidence["saturation_reason"]``
112
+    reflects the helper's return value for each curve shape. We force
113
+    specific curves by monkeypatching the probe's ``divergence()`` call
114
+    so the tests are deterministic and cheap."""
115
+
116
+    def _run_with_curve(self, divs_by_lambda: list[float]) -> dict:
117
+        from dlm_sway.probes import adapter_ablation as ab_mod
118
+
119
+        lambdas = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25]
120
+        assert len(lambdas) == len(divs_by_lambda)
121
+        call_count = {"i": 0}
122
+
123
+        def _fake_divergence(a, b, *, kind):  # noqa: ANN001 — test-local
124
+            del a, b, kind
125
+            # One prompt × len(lambdas) invocations; round-robin the curve.
126
+            idx = call_count["i"] % len(lambdas)
127
+            call_count["i"] += 1
128
+            return divs_by_lambda[idx]
129
+
130
+        backend = _diverging_backend()
131
+        probe, spec = build_probe(
132
+            {
133
+                "name": "abl",
134
+                "kind": "adapter_ablation",
135
+                "prompts": ["q1"],
136
+                "lambdas": lambdas,
137
+                "assert_linearity_gte": 0.0,
138
+                "assert_overshoot_gte": 0.0,
139
+            }
140
+        )
141
+        ctx = RunContext(backend=backend)
142
+        mp = pytest.MonkeyPatch()
143
+        mp.setattr(ab_mod, "divergence", _fake_divergence)
144
+        try:
145
+            result = probe.run(spec, ctx)
146
+        finally:
147
+            mp.undo()
148
+        return dict(result.evidence)
149
+
150
+    def test_flat_curve_surfaces_flat_reason(self) -> None:
151
+        ev = self._run_with_curve([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
152
+        assert ev["saturation_reason"] == "flat_curve"
153
+        assert ev["saturation_lambda"] is None
154
+
155
+    def test_healthy_monotonic_curve_found(self) -> None:
156
+        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.85, 0.95, 1.0])
157
+        assert ev["saturation_reason"] == "found"
158
+        assert ev["saturation_lambda"] is not None
159
+
160
+    def test_overshoot_dip_still_found_via_max_reference(self) -> None:
161
+        """B3 fix: curve peaks at λ=0.75, dips at λ=1.0, recovers at 1.25."""
162
+        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.95, 0.7, 1.0])
163
+        # max=1.0 at λ=1.25 → 0.9*1.0 = 0.9 → smallest λ where div ≥ 0.9
164
+        # is λ=0.75 (div=0.95). Monotonic through 0.75 → "found".
165
+        assert ev["saturation_reason"] == "found"
166
+        assert ev["saturation_lambda"] == 0.75
167
+
168
+    def test_non_monotonic_before_saturation(self) -> None:
169
+        ev = self._run_with_curve([0.0, 0.6, 0.4, 0.95, 1.0, 1.05])
170
+        assert ev["saturation_reason"] == "non_monotonic"
171
+        assert ev["saturation_lambda"] == 0.75
172
+
93173
 
94174
 def _diverging_backend() -> DummyDifferentialBackend:
95175
     """Backend where base ≠ ft at a few prompts; distributions interpolate