1 """Tests for :mod:`dlm_sway.probes.adapter_ablation`.
2
3 Uses the dummy backend's lam-interpolation implementation to exercise
4 the full probe path without loading a real model.
5 """
6
7 from __future__ import annotations
8
9 import numpy as np
10 import pytest
11
12 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
13 from dlm_sway.core.result import Verdict
14 from dlm_sway.core.scoring import ScalableDifferentialBackend, TokenDist
15 from dlm_sway.probes.adapter_ablation import (
16 _overshoot,
17 _r_squared,
18 _saturation_lambda,
19 )
20 from dlm_sway.probes.base import RunContext, build_probe
21
22
class TestShapeMetrics:
    def test_r_squared_perfect_linear(self) -> None:
        xs = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        ys = 2 * xs + 0.1
        assert _r_squared(xs, ys) > 0.99

    def test_r_squared_zero_slope_defined(self) -> None:
        xs = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        ys = np.zeros_like(xs)
        # A constant y makes ss_tot = 0; that case is defined as a perfect fit (1.0).
        assert _r_squared(xs, ys) == 1.0

    def test_saturation_lambda_expected(self) -> None:
        lam = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
        div = np.asarray([0.0, 0.5, 0.8, 0.95, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        # 0.95 / 1.0 = 0.95 ≥ 0.9, so λ=0.75 is the first saturated point.
        assert sat == 0.75
        assert reason == "found"

    def test_overshoot_recovered(self) -> None:
        lam = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        div = np.asarray([0.0, 0.5, 1.0, 1.15], dtype=np.float64)
        assert _overshoot(lam, div) == 1.15

    def test_saturation_flat_curve(self) -> None:
        """A signal-free adapter: divergence is zero at every λ."""
        lam = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        div = np.asarray([0.0, 0.0, 0.0, 0.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        assert sat is None
        assert reason == "flat_curve"

    def test_saturation_overshoot_peak_above_one(self) -> None:
        """B3 fix: the curve peaks at λ=1.25 (max = 1.0), and 0.95 already
        reaches 90% of that max. Saturation therefore picks the smallest λ
        with div ≥ 0.9 × 1.0 = 0.9 — λ=0.5 (div = 0.95). Before the fix this
        returned ``None``: the search stopped at λ ≤ 1.0 and used div(λ=1)
        as the reference."""
        lam = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        div = np.asarray([0.0, 0.95, 0.7, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        assert sat == 0.5
        # Up to and including the saturation point the curve is monotonic;
        # the dip comes only afterwards, which the per-saturation-point
        # monotonicity check tolerates.
        assert reason == "found"

    def test_saturation_non_monotonic_before_saturation(self) -> None:
        """Zigzagging before the saturation point is reported as a WARN."""
        lam = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
        # 0 → 0.6 → 0.4 (dip) → 0.95 (which is 90% of the max, 1.0)
        div = np.asarray([0.0, 0.6, 0.4, 0.95, 1.0], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        assert sat == 0.75
        assert reason == "non_monotonic"

    def test_saturation_below_floor_with_negative_max(self) -> None:
        """All-negative curve — pathological for JS divergence, but kept as a
        guard against numerical drift or future probe variants."""
        lam = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        div = np.asarray([-0.5, -0.3, -0.1], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        assert sat is None
        # A non-positive maximum is classified as flat by definition.
        assert reason == "flat_curve"

    def test_saturation_nan_max_treated_as_flat(self) -> None:
        lam = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
        div = np.asarray([0.1, np.nan, 0.5], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        assert sat is None
        assert reason == "flat_curve"

    def test_saturation_monotonically_decreasing(self) -> None:
        """Divergence *falls* as λ grows — the degenerate adapter whose λ=1
        output sits closer to base than its λ=0 output does."""
        lam = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
        div = np.asarray([0.9, 0.6, 0.3, 0.1], dtype=np.float64)
        sat, reason = _saturation_lambda(lam, div)
        # The max sits at λ=0, so the smallest λ with div ≥ 0.9*0.9=0.81 is λ=0.
        assert sat == 0.0
        # divs[:0+1] has length 1, so the non-decreasing check passes
        # trivially and we get "found"; the probe's overshoot / linearity
        # assertions are what flag the real pathology here.
        assert reason == "found"
108
109
class TestProbeVerdictPropagatesSaturationReason:
    """C8 + B3 test side: the probe's ``evidence["saturation_reason"]``
    reflects the helper's return value for each curve shape. We force
    specific curves by monkeypatching the probe's ``divergence()`` call
    so the tests are deterministic and cheap."""

    def _run_with_curve(self, divs_by_lambda: list[float]) -> dict:
        """Run the ablation probe with ``divergence`` patched to return the
        given curve (round-robin over successive calls) and return the
        result's evidence as a plain dict."""
        from dlm_sway.probes import adapter_ablation as ab_mod

        lambdas = [0.0, 0.25, 0.5, 0.75, 1.0, 1.25]
        assert len(lambdas) == len(divs_by_lambda)
        call_count = {"i": 0}

        def _fake_divergence(a, b, *, kind):  # noqa: ANN001 — test-local
            del a, b, kind
            # One prompt × len(lambdas) invocations; round-robin the curve.
            idx = call_count["i"] % len(lambdas)
            call_count["i"] += 1
            return divs_by_lambda[idx]

        backend = _diverging_backend()
        probe, spec = build_probe(
            {
                "name": "abl",
                "kind": "adapter_ablation",
                "prompts": ["q1"],
                "lambdas": lambdas,
                "assert_linearity_gte": 0.0,
                "assert_overshoot_gte": 0.0,
            }
        )
        ctx = RunContext(backend=backend)
        # Idiomatic acquire/release: MonkeyPatch.context() undoes the patch
        # on exit even if probe.run raises — replaces the manual
        # MonkeyPatch() + try/finally: mp.undo() dance.
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(ab_mod, "divergence", _fake_divergence)
            result = probe.run(spec, ctx)
        return dict(result.evidence)

    def test_flat_curve_surfaces_flat_reason(self) -> None:
        ev = self._run_with_curve([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        assert ev["saturation_reason"] == "flat_curve"
        assert ev["saturation_lambda"] is None

    def test_healthy_monotonic_curve_found(self) -> None:
        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.85, 0.95, 1.0])
        assert ev["saturation_reason"] == "found"
        assert ev["saturation_lambda"] is not None

    def test_overshoot_dip_still_found_via_max_reference(self) -> None:
        """B3 fix: curve peaks at λ=0.75, dips at λ=1.0, recovers at 1.25."""
        ev = self._run_with_curve([0.0, 0.3, 0.6, 0.95, 0.7, 1.0])
        # max=1.0 at λ=1.25 → 0.9*1.0 = 0.9 → smallest λ where div ≥ 0.9
        # is λ=0.75 (div=0.95). Monotonic through 0.75 → "found".
        assert ev["saturation_reason"] == "found"
        assert ev["saturation_lambda"] == 0.75

    def test_non_monotonic_before_saturation(self) -> None:
        ev = self._run_with_curve([0.0, 0.6, 0.4, 0.95, 1.0, 1.05])
        assert ev["saturation_reason"] == "non_monotonic"
        assert ev["saturation_lambda"] == 0.75
172
173
def _diverging_backend() -> DummyDifferentialBackend:
    """Backend where base ≠ ft at a few prompts; distributions interpolate
    smoothly under lam-blending in DummyDifferentialBackend.as_scaled_adapter."""

    def _dist(ids: list[int], probs: list[float]) -> TokenDist:
        # Small factory: turn raw probabilities into a log-space TokenDist.
        return TokenDist(
            token_ids=np.array(ids, dtype=np.int64),
            logprobs=np.log(np.array(probs, dtype=np.float32)),
            vocab_size=100,
        )

    base = DummyResponses(
        token_dists={
            "q1": _dist([1, 2, 3], [0.9, 0.05, 0.05]),
            "q2": _dist([5, 6], [0.8, 0.2]),
        }
    )
    ft = DummyResponses(
        token_dists={
            "q1": _dist([1, 2, 3], [0.2, 0.4, 0.4]),
            "q2": _dist([5, 6], [0.3, 0.7]),
        }
    )
    return DummyDifferentialBackend(base=base, ft=ft)
206
207
class TestProbe:
    def test_backend_implements_scalable_protocol(self) -> None:
        assert isinstance(_diverging_backend(), ScalableDifferentialBackend)

    def test_probe_runs_and_emits_shape_metrics(self) -> None:
        spec_dict = {
            "name": "abl",
            "kind": "adapter_ablation",
            "prompts": ["q1", "q2"],
            "lambdas": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
            # Very permissive to tolerate the log-space blend of a
            # tiny synthetic fixture.
            "assert_linearity_gte": 0.3,
            "assert_overshoot_gte": 1.0,
        }
        probe, spec = build_probe(spec_dict)
        result = probe.run(spec, RunContext(backend=_diverging_backend()))

        assert result.verdict in (Verdict.PASS, Verdict.FAIL)
        assert "lambdas" in result.evidence
        assert "mean_divergence_per_lambda" in result.evidence
        divs = result.evidence["mean_divergence_per_lambda"]
        assert len(divs) == 6
        # λ=0 diverges from itself by 0, and the bulk of the curve should
        # grow toward ft as λ increases, so the tail dominates the head.
        assert divs[-2] >= divs[0]

    def test_skip_when_backend_not_scalable(self) -> None:
        class _NonScalable:
            def as_base(self):  # noqa: ANN202
                raise NotImplementedError

            def as_finetuned(self):  # noqa: ANN202
                raise NotImplementedError

        probe, spec = build_probe(
            {
                "name": "abl",
                "kind": "adapter_ablation",
                "prompts": ["q1"],
            }
        )
        result = probe.run(spec, RunContext(backend=_NonScalable()))  # type: ignore[arg-type]
        assert result.verdict == Verdict.SKIP
        assert "ScalableDifferentialBackend" in result.message

    def test_error_on_empty_prompts(self) -> None:
        probe, spec = build_probe({"name": "abl", "kind": "adapter_ablation", "prompts": []})
        ctx = RunContext(backend=_diverging_backend())
        assert probe.run(spec, ctx).verdict == Verdict.ERROR
263 assert result.verdict == Verdict.ERROR