@@ -7,6 +7,7 @@ the full probe path without loading a real model. |
| 7 | 7 | from __future__ import annotations |
| 8 | 8 | |
| 9 | 9 | import numpy as np |
| 10 | +import pytest |
| 10 | 11 | |
| 11 | 12 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 12 | 13 | from dlm_sway.core.result import Verdict |
@@ -90,6 +91,85 @@ class TestShapeMetrics: |
| 90 | 91 | assert sat is None |
| 91 | 92 | assert reason == "flat_curve" |
| 92 | 93 | |
| 94 | + def test_saturation_monotonically_decreasing(self) -> None: |
| 95 | + """A curve that goes *down* with λ — adapter is anti-correlated |
| 96 | + with its own effect (the degenerate case where λ=1 is closer |
| 97 | + to base than λ=0).""" |
| 98 | + lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64) |
| 99 | + divs = np.asarray([0.9, 0.6, 0.3, 0.1], dtype=np.float64) |
| 100 | + sat, reason = _saturation_lambda(lambdas, divs) |
| 101 | + # max is at λ=0 → smallest λ where divs ≥ 0.9*0.9=0.81 is λ=0 itself. |
| 102 | + assert sat == 0.0 |
| 103 | + # Curve is strictly decreasing; the monotonic-non-decreasing check |
| 104 | + # on divs[:0+1] is trivially satisfied (length 1), so we classify |
| 105 | + # as "found" — the probe's overshoot / linearity checks pick up |
| 106 | + # the real pathology here. |
| 107 | + assert reason == "found" |
| 108 | + |
| 109 | + |
| 110 | +class TestProbeVerdictPropagatesSaturationReason: |
| 111 | + """C8 + B3 test side: the probe's ``evidence["saturation_reason"]`` |
| 112 | + reflects the helper's return value for each curve shape. We force |
| 113 | + specific curves by monkeypatching the probe's ``divergence()`` call |
| 114 | + so the tests are deterministic and cheap.""" |
| 115 | + |
| 116 | + def _run_with_curve(self, divs_by_lambda: list[float]) -> dict: |
| 117 | + from dlm_sway.probes import adapter_ablation as ab_mod |
| 118 | + |
| 119 | + lambdas = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25] |
| 120 | + assert len(lambdas) == len(divs_by_lambda) |
| 121 | + call_count = {"i": 0} |
| 122 | + |
| 123 | + def _fake_divergence(a, b, *, kind): # noqa: ANN001 — test-local |
| 124 | + del a, b, kind |
| 125 | + # One prompt × len(lambdas) invocations; round-robin the curve. |
| 126 | + idx = call_count["i"] % len(lambdas) |
| 127 | + call_count["i"] += 1 |
| 128 | + return divs_by_lambda[idx] |
| 129 | + |
| 130 | + backend = _diverging_backend() |
| 131 | + probe, spec = build_probe( |
| 132 | + { |
| 133 | + "name": "abl", |
| 134 | + "kind": "adapter_ablation", |
| 135 | + "prompts": ["q1"], |
| 136 | + "lambdas": lambdas, |
| 137 | + "assert_linearity_gte": 0.0, |
| 138 | + "assert_overshoot_gte": 0.0, |
| 139 | + } |
| 140 | + ) |
| 141 | + ctx = RunContext(backend=backend) |
| 142 | + mp = pytest.MonkeyPatch() |
| 143 | + mp.setattr(ab_mod, "divergence", _fake_divergence) |
| 144 | + try: |
| 145 | + result = probe.run(spec, ctx) |
| 146 | + finally: |
| 147 | + mp.undo() |
| 148 | + return dict(result.evidence) |
| 149 | + |
| 150 | + def test_flat_curve_surfaces_flat_reason(self) -> None: |
| 151 | + ev = self._run_with_curve([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) |
| 152 | + assert ev["saturation_reason"] == "flat_curve" |
| 153 | + assert ev["saturation_lambda"] is None |
| 154 | + |
| 155 | + def test_healthy_monotonic_curve_found(self) -> None: |
| 156 | + ev = self._run_with_curve([0.0, 0.3, 0.6, 0.85, 0.95, 1.0]) |
| 157 | + assert ev["saturation_reason"] == "found" |
| 158 | + assert ev["saturation_lambda"] is not None |
| 159 | + |
| 160 | + def test_overshoot_dip_still_found_via_max_reference(self) -> None: |
| 161 | + """B3 fix: curve peaks at λ=0.75, dips at λ=1.0, recovers at 1.25.""" |
| 162 | + ev = self._run_with_curve([0.0, 0.3, 0.6, 0.95, 0.7, 1.0]) |
| 163 | + # max=1.0 at λ=1.25 → 0.9*1.0 = 0.9 → smallest λ where div ≥ 0.9 |
| 164 | + # is λ=0.75 (div=0.95). Monotonic through 0.75 → "found". |
| 165 | + assert ev["saturation_reason"] == "found" |
| 166 | + assert ev["saturation_lambda"] == 0.75 |
| 167 | + |
| 168 | + def test_non_monotonic_before_saturation(self) -> None: |
| 169 | + ev = self._run_with_curve([0.0, 0.6, 0.4, 0.95, 1.0, 1.05]) |
| 170 | + assert ev["saturation_reason"] == "non_monotonic" |
| 171 | + assert ev["saturation_lambda"] == 0.75 |
| 172 | + |
| 93 | 173 | |
| 94 | 174 | def _diverging_backend() -> DummyDifferentialBackend: |
| 95 | 175 | """Backend where base ≠ ft at a few prompts; distributions interpolate |