tenseleyflow/sway / e1af82b


sway(probes): N2 adapter_ablation — the signature lambda-scaled KL curve

Authored by espadonne
SHA      e1af82b44c0a04a881cb6130c17c2d48ba38cfcb
Parents  7f806df
Tree     1061d18

2 changed files

Status  File                                        +    -
A       src/dlm_sway/probes/adapter_ablation.py    193  0
A       tests/unit/test_probe_adapter_ablation.py  135  0
src/dlm_sway/probes/adapter_ablation.py (added)
@@ -0,0 +1,193 @@
+"""N2 AdapterAblation — the sway signature primitive.
+
+Scales the LoRA additive term by λ ∈ {0, 0.25, 0.5, 0.75, 1.0, 1.25}
+and measures the mean divergence from the base distribution at each
+step. Fits a linear trend to the response curve and reports three
+shape metrics:
+
+- **linearity**: R² of a linear fit on ``(λ, mean_div)``. High means
+  the adapter's effect scales predictably; low means it's "all or
+  nothing" (degenerate).
+- **saturation_lambda**: the smallest λ at which divergence reaches
+  90% of the λ=1 value. Too low (<0.3) means the adapter fires at
+  partial strength — fragile. Too high (>1.0) means the adapter is
+  under-trained.
+- **overshoot**: divergence at the largest λ (1.25 by default) divided
+  by divergence at λ=1.0. >1.05 is the healthy "pushing past 1 still
+  moves the model" signal. An overshoot below 1.0 suggests collapse.
+
+This is the single novel primitive that no generic eval harness
+provides — sway's position next to the adapter math makes it possible.
+
+Requires the backend to implement
+:class:`~dlm_sway.core.scoring.ScalableDifferentialBackend`. Probes
+SKIP gracefully on backends that don't.
+"""
+
+from __future__ import annotations
+
+from typing import Literal
+
+import numpy as np
+from pydantic import Field
+
+from dlm_sway.core.result import ProbeResult, Verdict
+from dlm_sway.core.scoring import ScalableDifferentialBackend
+from dlm_sway.probes._divergence import Divergence, divergence
+from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
+
+
+class AdapterAblationSpec(ProbeSpec):
+    kind: Literal["adapter_ablation"] = "adapter_ablation"
+    prompts: list[str] = Field(default_factory=list)
+    lambdas: list[float] = Field(
+        default_factory=lambda: [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
+        min_length=3,
+    )
+    divergence: Divergence = "js"
+    top_k: int | None = None
+    assert_linearity_gte: float = 0.85
+    assert_saturation_between: tuple[float, float] = (0.3, 1.05)
+    assert_overshoot_gte: float = 1.02
+
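+
+# A minimal configuration sketch, using the generic factory exercised in the
+# unit tests (prompt strings and `backend` here are placeholders):
+#
+#   probe, spec = build_probe({
+#       "name": "abl",
+#       "kind": "adapter_ablation",
+#       "prompts": ["q1", "q2"],
+#       "lambdas": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
+#   })
+#   result = probe.run(spec, RunContext(backend=backend))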
+
+class AdapterAblationProbe(Probe):
+    kind = "adapter_ablation"
+    spec_cls = AdapterAblationSpec
+    category = "ablation"
+
+    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
+        assert isinstance(spec, AdapterAblationSpec)
+        if not spec.prompts:
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.ERROR,
+                score=None,
+                message="no prompts provided",
+            )
+        if not isinstance(ctx.backend, ScalableDifferentialBackend):
+            return ProbeResult(
+                name=spec.name,
+                kind=spec.kind,
+                verdict=Verdict.SKIP,
+                score=None,
+                message=(
+                    "backend does not implement ScalableDifferentialBackend — "
+                    "adapter ablation requires LoRA-scale access"
+                ),
+            )
+
+        top_k = spec.top_k if spec.top_k is not None else ctx.top_k
+
+        # Reference distributions at the smallest λ (adapter scaled to zero
+        # → base). These are loop-invariant, so compute them once per prompt
+        # rather than once per (λ, prompt) pair.
+        lam_zero = min(spec.lambdas)
+        with ctx.backend.as_scaled_adapter(lam_zero) as ref:
+            ref_dists = [ref.next_token_dist(p, top_k=top_k) for p in spec.prompts]
+
+        per_lambda: list[float] = []
+        for lam in spec.lambdas:
+            divs_for_lam: list[float] = []
+            with ctx.backend.as_scaled_adapter(lam) as scaled:
+                for prompt, ref_dist in zip(spec.prompts, ref_dists, strict=True):
+                    scaled_dist = scaled.next_token_dist(prompt, top_k=top_k)
+                    divs_for_lam.append(divergence(ref_dist, scaled_dist, kind=spec.divergence))
+            per_lambda.append(float(np.mean(divs_for_lam)))
+
+        lambdas_arr = np.asarray(spec.lambdas, dtype=np.float64)
+        divs_arr = np.asarray(per_lambda, dtype=np.float64)
+
+        linearity = _r_squared(lambdas_arr, divs_arr)
+        saturation_lambda = _saturation_lambda(lambdas_arr, divs_arr)
+        overshoot = _overshoot(lambdas_arr, divs_arr)
+
+        # Pass when all three shape metrics land in their healthy bands.
+        sat_lo, sat_hi = spec.assert_saturation_between
+        ok_lin = linearity >= spec.assert_linearity_gte
+        ok_sat = saturation_lambda is not None and sat_lo <= saturation_lambda <= sat_hi
+        ok_over = overshoot >= spec.assert_overshoot_gte
+        verdict = Verdict.PASS if (ok_lin and ok_sat and ok_over) else Verdict.FAIL
+
+        lin_score = max(0.0, min(1.0, linearity / max(spec.assert_linearity_gte, 1e-6)))
+        over_score = max(0.0, min(1.0, (overshoot - 1.0) / 0.2))
+        sat_score = 1.0 if ok_sat else 0.3
+        score = 0.4 * lin_score + 0.3 * sat_score + 0.3 * over_score
+
+        return ProbeResult(
+            name=spec.name,
+            kind=spec.kind,
+            verdict=verdict,
+            score=score,
+            raw=linearity,
+            evidence={
+                "lambdas": spec.lambdas,
+                "mean_divergence_per_lambda": per_lambda,
+                "linearity": linearity,
+                "saturation_lambda": saturation_lambda,
+                "overshoot": overshoot,
+                "passed_linearity": ok_lin,
+                "passed_saturation": ok_sat,
+                "passed_overshoot": ok_over,
+                "weight": spec.weight,
+            },
+            message=(
+                f"R²={linearity:.2f}, sat_λ={saturation_lambda:.2f} "
+                f"({'in' if ok_sat else 'out of'} band), overshoot={overshoot:.2f}"
+                if saturation_lambda is not None
+                else f"R²={linearity:.2f}, saturation undetected, overshoot={overshoot:.2f}"
+            ),
+        )
+
+
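+# Worked example of the composite score computed in run() above, with
+# hypothetical inputs: linearity=0.90 against the default 0.85 gate,
+# saturation in band, overshoot=1.10:
+#   lin_score  = clip(0.90 / 0.85, 0, 1)        = 1.00
+#   over_score = clip((1.10 - 1.0) / 0.2, 0, 1) = 0.50
+#   sat_score  = 1.0
+#   score      = 0.4·1.00 + 0.3·1.0 + 0.3·0.50  = 0.85
+
+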
+def _r_squared(x: np.ndarray, y: np.ndarray) -> float:
+    """Coefficient of determination for a linear fit of ``y`` on ``x``."""
+    if x.size < 2:
+        return 0.0
+    xm = float(x.mean())
+    ym = float(y.mean())
+    denom = float(((x - xm) ** 2).sum())
+    if denom == 0.0:
+        return 0.0
+    slope = float(((x - xm) * (y - ym)).sum()) / denom
+    intercept = ym - slope * xm
+    y_pred = slope * x + intercept
+    ss_res = float(((y - y_pred) ** 2).sum())
+    ss_tot = float(((y - ym) ** 2).sum())
+    if ss_tot == 0.0:
+        return 1.0
+    return max(0.0, 1.0 - ss_res / ss_tot)
+
+
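+# For reference, the closed form implemented by _r_squared above, for the
+# least-squares line ŷ = slope·x + intercept:
+#
+#   R² = 1 − Σᵢ(yᵢ − ŷᵢ)² / Σᵢ(yᵢ − ȳ)²
+#
+# clamped at 0.0 so a worse-than-mean fit never reports a negative value.
+
+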
+def _saturation_lambda(lambdas: np.ndarray, divs: np.ndarray) -> float | None:
+    """Smallest λ ≤ 1.0 at which divergence reaches 90% of div(λ=1)."""
+    # Locate the index of λ=1.0 (or the closest entry ≤ 1.0).
+    candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0]
+    if candidates.size == 0:
+        # Fall back to the largest λ ≤ 1.0; mask entries above 1.0 with
+        # -inf so argmax cannot select them.
+        mask = lambdas <= 1.0
+        if not mask.any():
+            return None
+        idx1 = int(np.argmax(np.where(mask, lambdas, -np.inf)))
+    else:
+        idx1 = int(candidates[0])
+    target = 0.9 * float(divs[idx1])
+    if target <= 0:
+        return None
+    for lam, d in zip(lambdas[: idx1 + 1], divs[: idx1 + 1], strict=True):
+        if d >= target:
+            return float(lam)
+    return None
+
+
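+# Example (mirrors the unit test): lambdas = [0, 0.25, 0.5, 0.75, 1.0] and
+# divs = [0, 0.5, 0.8, 0.95, 1.0] give target = 0.9 · 1.0 = 0.9; the first
+# crossing is 0.95 at λ = 0.75, so _saturation_lambda returns 0.75.
+
+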
+def _overshoot(lambdas: np.ndarray, divs: np.ndarray) -> float:
+    """``div(λ_max) / div(λ=1)``.
+
+    Returns 1.0 when λ_max ≤ 1.0, when λ=1.0 is absent from the grid,
+    or when div(λ=1) is non-positive (the ratio is undefined there).
+    """
+    idx_max = int(np.argmax(lambdas))
+    candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0]
+    if candidates.size == 0:
+        return 1.0
+    idx1 = int(candidates[0])
+    if idx_max == idx1:
+        return 1.0
+    d1 = float(divs[idx1])
+    dmax = float(divs[idx_max])
+    if d1 <= 0:
+        return 1.0
+    return dmax / d1
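+
+# Example (mirrors the unit test): lambdas = [0, 0.5, 1.0, 1.25] and
+# divs = [0, 0.5, 1.0, 1.15] give overshoot = 1.15 / 1.0 = 1.15, i.e. the
+# healthy "pushing past 1 still moves the model" regime.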
tests/unit/test_probe_adapter_ablation.py (added)
@@ -0,0 +1,135 @@
+"""Tests for :mod:`dlm_sway.probes.adapter_ablation`.
+
+Uses the dummy backend's lam-interpolation implementation to exercise
+the full probe path without loading a real model.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+
+from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
+from dlm_sway.core.result import Verdict
+from dlm_sway.core.scoring import ScalableDifferentialBackend, TokenDist
+from dlm_sway.probes.adapter_ablation import (
+    _overshoot,
+    _r_squared,
+    _saturation_lambda,
+)
+from dlm_sway.probes.base import RunContext, build_probe
+
+
+class TestShapeMetrics:
+    def test_r_squared_perfect_linear(self) -> None:
+        x = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
+        y = 2 * x + 0.1
+        assert _r_squared(x, y) > 0.99
+
+    def test_r_squared_zero_slope_defined(self) -> None:
+        x = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
+        y = np.zeros_like(x)
+        # Flat y → ss_tot = 0 → defined as 1.0 (perfect fit).
+        assert _r_squared(x, y) == 1.0
+
+    def test_saturation_lambda_expected(self) -> None:
+        lambdas = np.asarray([0.0, 0.25, 0.5, 0.75, 1.0], dtype=np.float64)
+        divs = np.asarray([0.0, 0.5, 0.8, 0.95, 1.0], dtype=np.float64)
+        sat = _saturation_lambda(lambdas, divs)
+        assert sat == 0.75  # first λ where div ≥ 0.9 · div(λ=1): 0.95 ≥ 0.9
+
+    def test_overshoot_recovered(self) -> None:
+        lambdas = np.asarray([0.0, 0.5, 1.0, 1.25], dtype=np.float64)
+        divs = np.asarray([0.0, 0.5, 1.0, 1.15], dtype=np.float64)
+        assert _overshoot(lambdas, divs) == 1.15
+
+
+def _diverging_backend() -> DummyDifferentialBackend:
+    """Backend where base ≠ ft at a few prompts; distributions
+    interpolate smoothly under lam-blending in
+    ``DummyDifferentialBackend.as_scaled_adapter``."""
+    base = DummyResponses(
+        token_dists={
+            "q1": TokenDist(
+                token_ids=np.array([1, 2, 3], dtype=np.int64),
+                logprobs=np.log(np.array([0.9, 0.05, 0.05], dtype=np.float32)),
+                vocab_size=100,
+            ),
+            "q2": TokenDist(
+                token_ids=np.array([5, 6], dtype=np.int64),
+                logprobs=np.log(np.array([0.8, 0.2], dtype=np.float32)),
+                vocab_size=100,
+            ),
+        }
+    )
+    ft = DummyResponses(
+        token_dists={
+            "q1": TokenDist(
+                token_ids=np.array([1, 2, 3], dtype=np.int64),
+                logprobs=np.log(np.array([0.2, 0.4, 0.4], dtype=np.float32)),
+                vocab_size=100,
+            ),
+            "q2": TokenDist(
+                token_ids=np.array([5, 6], dtype=np.int64),
+                logprobs=np.log(np.array([0.3, 0.7], dtype=np.float32)),
+                vocab_size=100,
+            ),
+        }
+    )
+    return DummyDifferentialBackend(base=base, ft=ft)
+
+
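+# Note: as_scaled_adapter on the dummy backend blends base and ft in log
+# space (see the tolerance comment in test_probe_runs_and_emits_shape_metrics
+# below), so divergence grows smoothly but not exactly linearly in λ, hence
+# the permissive gates used there.
+
+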
+class TestProbe:
+    def test_backend_implements_scalable_protocol(self) -> None:
+        backend = _diverging_backend()
+        assert isinstance(backend, ScalableDifferentialBackend)
+
+    def test_probe_runs_and_emits_shape_metrics(self) -> None:
+        probe, spec = build_probe(
+            {
+                "name": "abl",
+                "kind": "adapter_ablation",
+                "prompts": ["q1", "q2"],
+                "lambdas": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
+                # Very permissive to tolerate the log-space blend of a
+                # tiny synthetic fixture.
+                "assert_linearity_gte": 0.3,
+                "assert_overshoot_gte": 1.0,
+            }
+        )
+        ctx = RunContext(backend=_diverging_backend())
+        result = probe.run(spec, ctx)
+        assert result.verdict in (Verdict.PASS, Verdict.FAIL)
+        assert "lambdas" in result.evidence
+        assert "mean_divergence_per_lambda" in result.evidence
+        assert len(result.evidence["mean_divergence_per_lambda"]) == 6
+        # Divergence should grow as λ moves from 0 toward ft: λ=0 has zero
+        # divergence from itself, so the λ=1 entry must sit above it.
+        divs = result.evidence["mean_divergence_per_lambda"]
+        assert divs[-2] >= divs[0]
+
+    def test_skip_when_backend_not_scalable(self) -> None:
+        class _NonScalable:
+            def as_base(self):  # noqa: ANN202
+                raise NotImplementedError
+
+            def as_finetuned(self):  # noqa: ANN202
+                raise NotImplementedError
+
+        probe, spec = build_probe(
+            {
+                "name": "abl",
+                "kind": "adapter_ablation",
+                "prompts": ["q1"],
+            }
+        )
+        ctx = RunContext(backend=_NonScalable())  # type: ignore[arg-type]
+        result = probe.run(spec, ctx)
+        assert result.verdict == Verdict.SKIP
+        assert "ScalableDifferentialBackend" in result.message
+
+    def test_error_on_empty_prompts(self) -> None:
+        backend = _diverging_backend()
+        probe, spec = build_probe({"name": "abl", "kind": "adapter_ablation", "prompts": []})
+        ctx = RunContext(backend=backend)
+        result = probe.run(spec, ctx)
+        assert result.verdict == Verdict.ERROR