| 1 | """Tests for :mod:`dlm_sway.probes.adapter_ablation`. |
| 2 | |
| 3 | Uses the dummy backend's lam-interpolation implementation to exercise |
| 4 | the full probe path without loading a real model. |
| 5 | """ |
| 6 | |
| 7 | from __future__ import annotations |
| 8 | |
| 9 | import numpy as np |
| 10 | import pytest |
| 11 | |
| 12 | from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses |
| 13 | from dlm_sway.core.result import Verdict |
| 14 | from dlm_sway.core.scoring import ScalableDifferentialBackend, TokenDist |
| 15 | from dlm_sway.probes.adapter_ablation import ( |
| 16 | _overshoot, |
| 17 | _r_squared, |
| 18 | _saturation_lambda, |
| 19 | ) |
| 20 | from dlm_sway.probes.base import RunContext, build_probe |
| 21 | |
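# Contract assumed by the shape-metric tests below (a reference sketch
# inferred from this file's own cases, not from the helpers' source; see
# dlm_sway.probes.adapter_ablation for the authoritative versions):
#
#   _r_squared(x, y)           R² of a least-squares linear fit of y on x;
#                              defined as 1.0 when y is flat (ss_tot == 0).
#   _saturation_lambda(l, d)   smallest λ where d ≥ 0.9 × max(d), paired with
#                              "found", or "non_monotonic" when the curve dips
#                              before that point; (None, "flat_curve") when
#                              max(d) is NaN or ≤ 0.
#   _overshoot(l, d)           divergence reached in the λ > 1 region (exact
#                              definition per the helper's source).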

class TestShapeMetrics:
    def test_r_squared_perfect_linear(self) -> None:
        x = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        y = 2 * x + 0.1
        assert _r_squared(x, y) > 0.99

    def test_r_squared_zero_slope_defined(self) -> None:
        x = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        y = np.zeros_like(x)
        # Flat y → ss_tot = 0 → defined as 1.0 (perfect fit).
        assert _r_squared(x, y) == 1.0

    def test_saturation_lambda_expected(self) -> None:
        lambdas = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
        divs = np.asarray([0.0, 0.5, 0.8, 0.95, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat == 0.75  # 0.95 / 1.0 = 0.95 ≥ 0.9
        assert reason == "found"

    def test_overshoot_recovered(self) -> None:
        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        divs = np.asarray([0.0, 0.5, 1.0, 1.15], dtype=np.float64)
        assert _overshoot(lambdas, divs) == 1.15

    def test_saturation_flat_curve(self) -> None:
        """Adapter that produces no signal — every divergence is zero."""
        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        divs = np.asarray([0.0, 0.0, 0.0, 0.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat is None
        assert reason == "flat_curve"

    def test_saturation_overshoot_peak_above_one(self) -> None:
| 56 | """B3 fix: curve peaks at λ=1.25 (0.95 reaches 90% of max=1.0). |
| 57 | Saturation now picks the smallest λ where divs ≥ 0.9 × 1.0 = 0.9 — |
| 58 | namely λ=0.5 with div=0.95. Pre-fix this returned ``None`` because |
| 59 | the search was bounded at λ ≤ 1.0 with `div(λ=1)` as reference.""" |
        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        divs = np.asarray([0.0, 0.95, 0.7, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat == 0.5
        # Curve was monotonic up to and including the saturation point;
        # the dip happened *after* it, which the per-saturation-point
        # monotonicity check accepts.
        assert reason == "found"

    def test_saturation_non_monotonic_before_saturation(self) -> None:
        """A curve that zigzags up to the saturation point gets a WARN."""
        lambdas = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
        # 0 → 0.6 → 0.4 (dip) → 0.95 (≥ 90% of max=1.0)
        divs = np.asarray([0.0, 0.6, 0.4, 0.95, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat == 0.75
        assert reason == "non_monotonic"

    def test_saturation_below_floor_with_negative_max(self) -> None:
        """Pathological negative-only curve (shouldn't happen with JS but
        guards against numerical drift / future probe variants)."""
        lambdas = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        divs = np.asarray([-0.5, -0.3, -0.1], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat is None
        assert reason == "flat_curve"  # max ≤ 0 → flat by definition

    def test_saturation_nan_max_treated_as_flat(self) -> None:
        lambdas = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        divs = np.asarray([0.1, np.nan, 0.5], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        assert sat is None
        assert reason == "flat_curve"

    def test_saturation_monotonically_decreasing(self) -> None:
        """A curve that goes *down* with λ — adapter is anti-correlated
        with its own effect (the degenerate case where lam=1 is closer
        to base than lam=0)."""
        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        divs = np.asarray([0.9, 0.6, 0.3, 0.1], dtype=np.float64)
        sat, reason = _saturation_lambda(lambdas, divs)
        # max is at λ=0 → smallest λ where divs ≥ 0.9*0.9=0.81 is λ=0 itself.
        assert sat == 0.0
        # Curve is strictly decreasing; the monotonic-non-decreasing check
        # on divs[:0+1] is trivially satisfied (length 1), so we classify
        # as "found" — the probe's overshoot / linearity checks pick up
        # the real pathology here.
        assert reason == "found"
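
    # Added invariant sketch: the two curves below are taken from the cases
    # above; the check restates the 0.9 × max saturation rule this file
    # documents, rather than anything read from the helper's source, so it
    # should be adjusted if that rule changes.
    @pytest.mark.parametrize(
        ("divs", "expected_sat"),
        [
            ([0.0, 0.5, 0.8, 0.95, 1.0], 0.75),
            ([0.0, 0.6, 0.4, 0.95, 1.0], 0.75),
        ],
    )
    def test_saturation_point_meets_threshold(
        self, divs: list[float], expected_sat: float
    ) -> None:
        lambdas = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
        d = np.asarray(divs, dtype=np.float64)
        sat, _reason = _saturation_lambda(lambdas, d)
        assert sat == expected_sat
        # The found λ is one of the sampled points and clears 90% of max.
        idx = int(np.where(lambdas == sat)[0][0])
        assert d[idx] >= 0.9 * d.max()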


class TestProbeVerdictPropagatesSaturationReason:
    """C8 + B3 test side: the probe's ``evidence["saturation_reason"]``
    reflects the helper's return value for each curve shape. We force
    specific curves by monkeypatching the probe's ``divergence()`` call
    so the tests are deterministic and cheap."""

    def _run_with_curve(self, divs_by_lambda: list[float]) -> dict:
        from dlm_sway.probes import adapter_ablation as ab_mod

        lambdas = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25]
        assert len(lambdas) == len(divs_by_lambda)
        call_count = {"i": 0}

        def _fake_divergence(a, b, *, kind):  # noqa: ANN001 — test-local
            del a, b, kind
            # One prompt × len(lambdas) invocations; round-robin the curve.
            idx = call_count["i"] % len(lambdas)
            call_count["i"] += 1
            return divs_by_lambda[idx]

        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "abl",
                "kind": "adapter_ablation",
                "prompts": ["q1"],
                "lambdas": lambdas,
                "assert_linearity_gte": 0.0,
                "assert_overshoot_gte": 0.0,
            }
        )
        ctx = RunContext(backend=backend)
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(ab_mod, "divergence", _fake_divergence)
            result = probe.run(spec, ctx)
        return dict(result.evidence)

    def test_flat_curve_surfaces_flat_reason(self) -> None:
        ev = self._run_with_curve([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        assert ev["saturation_reason"] == "flat_curve"
        assert ev["saturation_lambda"] is None

    def test_healthy_monotonic_curve_found(self) -> None:
        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.85, 0.95, 1.0])
        assert ev["saturation_reason"] == "found"
        assert ev["saturation_lambda"] is not None

    def test_overshoot_dip_still_found_via_max_reference(self) -> None:
| 161 | """B3 fix: curve peaks at λ=0.75, dips at λ=1.0, recovers at 1.25.""" |
        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.95, 0.7, 1.0])
        # max=1.0 at λ=1.25 → 0.9*1.0 = 0.9 → smallest λ where div ≥ 0.9
        # is λ=0.75 (div=0.95). Monotonic through 0.75 → "found".
        assert ev["saturation_reason"] == "found"
        assert ev["saturation_lambda"] == 0.75

    def test_non_monotonic_before_saturation(self) -> None:
        ev = self._run_with_curve([0.0, 0.6, 0.4, 0.95, 1.0, 1.05])
        assert ev["saturation_reason"] == "non_monotonic"
        assert ev["saturation_lambda"] == 0.75


def _diverging_backend() -> DummyDifferentialBackend:
    """Backend where base ≠ ft at a few prompts; distributions interpolate
    smoothly under lam-blending in DummyDifferentialBackend.as_scaled_adapter."""
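    # Presumed blend inside as_scaled_adapter (an assumption based on the
    # "log-space blend" note in TestProbe below, not on the backend's source):
    #   logprobs_lam = (1 - lam) * logprobs_base + lam * logprobs_ft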
    base = DummyResponses(
        token_dists={
            "q1": TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(np.array([0.9, 0.05, 0.05], dtype=np.float32)),
                vocab_size=100,
            ),
            "q2": TokenDist(
                token_ids=np.array([5, 6], dtype=np.int64),
                logprobs=np.log(np.array([0.8, 0.2], dtype=np.float32)),
                vocab_size=100,
            ),
        }
    )
    ft = DummyResponses(
        token_dists={
            "q1": TokenDist(
                token_ids=np.array([1, 2, 3], dtype=np.int64),
                logprobs=np.log(np.array([0.2, 0.4, 0.4], dtype=np.float32)),
                vocab_size=100,
            ),
            "q2": TokenDist(
                token_ids=np.array([5, 6], dtype=np.int64),
                logprobs=np.log(np.array([0.3, 0.7], dtype=np.float32)),
                vocab_size=100,
            ),
        }
    )
    return DummyDifferentialBackend(base=base, ft=ft)


class TestProbe:
    def test_backend_implements_scalable_protocol(self) -> None:
        backend = _diverging_backend()
        assert isinstance(backend, ScalableDifferentialBackend)

    def test_probe_runs_and_emits_shape_metrics(self) -> None:
        probe, spec = build_probe(
            {
                "name": "abl",
                "kind": "adapter_ablation",
                "prompts": ["q1", "q2"],
                "lambdas": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
                # Very permissive to tolerate the log-space blend of a
                # tiny synthetic fixture.
                "assert_linearity_gte": 0.3,
                "assert_overshoot_gte": 1.0,
            }
        )
        ctx = RunContext(backend=_diverging_backend())
        result = probe.run(spec, ctx)
        assert result.verdict in (Verdict.PASS, Verdict.FAIL)
        assert "lambdas" in result.evidence
        assert "mean_divergence_per_lambda" in result.evidence
        assert len(result.evidence["mean_divergence_per_lambda"]) == 6
        # Divergence should increase as λ grows from 0 toward ft.
        divs = result.evidence["mean_divergence_per_lambda"]
        # λ=0 → 0 divergence from itself. λ>0 should be non-decreasing
        # for the bulk of the curve.
        assert divs[-2] >= divs[0]

    def test_skip_when_backend_not_scalable(self) -> None:
        class _NonScalable:
            def as_base(self):  # noqa: ANN202
                raise NotImplementedError

            def as_finetuned(self):  # noqa: ANN202
                raise NotImplementedError

        probe, spec = build_probe(
            {
                "name": "abl",
                "kind": "adapter_ablation",
                "prompts": ["q1"],
            }
        )
        ctx = RunContext(backend=_NonScalable())  # type: ignore[arg-type]
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.SKIP
        assert "ScalableDifferentialBackend" in result.message

    def test_error_on_empty_prompts(self) -> None:
        backend = _diverging_backend()
        probe, spec = build_probe({"name": "abl", "kind": "adapter_ablation", "prompts": []})
        ctx = RunContext(backend=backend)
        result = probe.run(spec, ctx)
        assert result.verdict == Verdict.ERROR