1 """N2 AdapterAblation — the sway signature primitive.
2
3 Scales the LoRA additive term by λ ∈ {0, 0.25, 0.5, 0.75, 1.0, 1.25}
4 and measures the mean divergence from the base distribution at each
5 step. Fits a monotonic response curve; reports three shape metrics:
6
7 - **linearity**: R² of a linear fit on ``(λ, mean_div)``. High means
8 the adapter's effect scales predictably; low means it's "all or
9 nothing" (degenerate).
10 - **saturation_lambda**: the smallest λ at which divergence reaches
11 90% of the λ=1 value. Too low (<0.3) means the adapter fires at
12 partial strength — fragile. Too high (>1.0) means the adapter is
13 under-trained.
14 - **overshoot**: divergence at λ=1.25 divided by λ=1.0. >1.05 is the
15 healthy "pushing past 1 still moves the model" signal. An overshoot
16 below 1.0 suggests collapse.
17
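For orientation, the λ-scaled forward pass is the standard LoRA
algebra (stated here for reference, not defined in this file)::

    h = W x + λ · (α / r) · B A x

so λ=0 recovers the base model and λ=1 the fully-applied adapter.
Illustrative worked example: a sweep producing
``mean_div = [0.00, 0.11, 0.23, 0.34, 0.45, 0.49]`` over the default
grid gives R² ≈ 0.99 (linear), sat_λ = 1.0 (inside the default band),
and overshoot ≈ 0.49 / 0.45 ≈ 1.09 — a PASS on all three checks.
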
This is the single novel primitive that no generic eval harness
provides — sway's position next to the adapter math makes it possible.

Requires the backend to implement
:class:`~dlm_sway.core.scoring.ScalableDifferentialBackend`. The probe
SKIPs gracefully on backends that don't.

**On the missing ``ci_95`` column.** S14's bootstrap CI column is
populated for every *aggregating* probe — one whose ``raw`` is a
sample mean (or similar) over N observations that admit resampling.
``adapter_ablation`` is a curve fit: the raw metric is an R² on the
``(λ, divergence)`` sweep, not a per-prompt aggregate. Resampling
"residuals" would surface confidence in the *fit* rather than in the
underlying observations, which confuses the signal the probe reports.
The column renders as ``—`` by design; see F14 in the Audit 02
closure for the rationale.
"""

from __future__ import annotations

import math
from typing import Literal

import numpy as np
from pydantic import Field

from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize
from dlm_sway.core.scoring import ScalableDifferentialBackend
from dlm_sway.probes._divergence import Divergence, divergence
from dlm_sway.probes._zscore import (
    no_calibration_note,
    score_from_z,
    verdict_from_z,
    z_score,
    z_scores_by_rank,
)
from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank


class AdapterAblationSpec(ProbeSpec):
    kind: Literal["adapter_ablation"] = "adapter_ablation"
    prompts: list[str] = Field(default_factory=list)
    lambdas: list[float] = Field(
        default_factory=lambda: [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
        min_length=3,
    )
    divergence: Divergence = "js"
    top_k: int | None = None
    assert_linearity_gte: float = 0.85
    assert_saturation_between: tuple[float, float] = (0.3, 1.05)
    assert_overshoot_gte: float = 1.02
    assert_z_gte: float = 3.0
    """Z-score pass criterion against the null-adapter baseline, when one
    exists. Note: this probe usually opts out of calibration (the null
    proxy doesn't expose ``as_scaled_adapter``); the z-score path is
    retained only for shape consistency with the rest of the suite."""
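
# Illustrative wiring only — ``name`` comes from the shared ProbeSpec
# base (the probe reads ``spec.name`` and ``spec.weight`` below); the
# probe name and prompt here are made up for the example:
#
#   spec = AdapterAblationSpec(
#       name="lora/ablation-sweep",
#       prompts=["The capital of France is"],
#       divergence="js",
#   )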


class AdapterAblationProbe(Probe):
    kind = "adapter_ablation"
    spec_cls = AdapterAblationSpec
    category = "ablation"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        assert isinstance(spec, AdapterAblationSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        # Bind locally so the isinstance check narrows the type once for
        # the whole sweep (ctx.backend is typed as the base
        # DifferentialBackend; mypy keeps the narrowing on `scalable`).
        scalable = ctx.backend
        if not isinstance(scalable, ScalableDifferentialBackend):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    "backend does not implement ScalableDifferentialBackend — "
                    "adapter ablation requires LoRA-scale access"
                ),
            )

        top_k = spec.top_k if spec.top_k is not None else ctx.top_k

        # Reference distributions at the smallest λ (normally 0.0:
        # adapter scaled to zero → the base model), computed once per
        # prompt and reused across the whole sweep.
        lam_zero = min(spec.lambdas)
        with scalable.as_scaled_adapter(lam_zero) as ref:
            ref_dists = [ref.next_token_dist(p, top_k=top_k) for p in spec.prompts]

        per_lambda: list[float] = []
        for lam in spec.lambdas:
            divs_for_lam: list[float] = []
            with scalable.as_scaled_adapter(lam) as scaled:
                for prompt, ref_dist in zip(spec.prompts, ref_dists):
                    scaled_dist = scaled.next_token_dist(prompt, top_k=top_k)
                    divs_for_lam.append(
                        divergence(ref_dist, scaled_dist, kind=spec.divergence)
                    )
            per_lambda.append(float(np.mean(divs_for_lam)))

        lambdas_arr = np.asarray(spec.lambdas, dtype=np.float64)
        divs_arr = np.asarray(per_lambda, dtype=np.float64)

        linearity = _r_squared(lambdas_arr, divs_arr)
        saturation_lambda, sat_reason = _saturation_lambda(lambdas_arr, divs_arr)
        overshoot = _overshoot(lambdas_arr, divs_arr)

        # Pass when all three shape metrics land in their healthy bands.
        sat_lo, sat_hi = spec.assert_saturation_between
        ok_lin = linearity >= spec.assert_linearity_gte
        ok_sat = (
            saturation_lambda is not None
            and sat_lo <= saturation_lambda <= sat_hi
            and sat_reason in ("found", "non_monotonic")
        )
        ok_over = overshoot >= spec.assert_overshoot_gte

        stats = get_null_stats(ctx, spec.kind)
        z = z_score(linearity, stats)
        z_by_rank = z_scores_by_rank(linearity, get_null_stats_by_rank(ctx, spec.kind), sign=+1)
        verdict_z = verdict_from_z(z, spec.assert_z_gte)
        if verdict_z is not None:
            verdict = verdict_z
            score_val = score_from_z(z)
            score = score_val if score_val is not None else 0.0
        else:
            verdict = Verdict.PASS if (ok_lin and ok_sat and ok_over) else Verdict.FAIL
            # Heuristic blend: linearity carries most of the weight (0.4);
            # saturation and overshoot are shape checks (0.3 each).
            # over_score ramps linearly over overshoot ∈ [1.0, 1.2].
            lin_score = max(0.0, min(1.0, linearity / max(spec.assert_linearity_gte, 1e-6)))
            over_score = max(0.0, min(1.0, (overshoot - 1.0) / 0.2))
            sat_score = 1.0 if ok_sat else 0.3
            score = 0.4 * lin_score + 0.3 * sat_score + 0.3 * over_score

        sat_msg = (
            f"sat_λ={saturation_lambda:.2f} ({'in' if ok_sat else 'out of'} band)"
            if saturation_lambda is not None
            else f"saturation undetected ({sat_reason})"
        )
        base_msg = f"R²={linearity:.2f}, {sat_msg}, overshoot={overshoot:.2f}"
        if z is not None:
            message = f"{base_msg}, z={z:+.2f}σ vs null"
        else:
            message = f"{base_msg} {no_calibration_note(spec.kind)}"

        return safe_finalize(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=linearity,
            z_score=z,
            evidence={
                "lambdas": spec.lambdas,
                "mean_divergence_per_lambda": per_lambda,
                "linearity": linearity,
                "saturation_lambda": saturation_lambda,
                "saturation_reason": sat_reason,
                "overshoot": overshoot,
                "passed_linearity": ok_lin,
                "passed_saturation": ok_sat,
                "passed_overshoot": ok_over,
                "weight": spec.weight,
                "z_by_rank": z_by_rank,
            },
            message=message,
        )


def _r_squared(x: np.ndarray, y: np.ndarray) -> float:
    """Coefficient of determination for a linear fit of ``y`` on ``x``.
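
    Doctest (perfectly linear data, so the fit explains all variance):

    >>> import numpy as np
    >>> _r_squared(np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.25, 0.5]))
    1.0
    """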
    if x.size < 2:
        return 0.0
    xm = float(x.mean())
    ym = float(y.mean())
    denom = float(((x - xm) ** 2).sum())
    if denom == 0.0:
        return 0.0
    slope = float(((x - xm) * (y - ym)).sum()) / denom
    intercept = ym - slope * xm
    y_pred = slope * x + intercept
    ss_res = float(((y - y_pred) ** 2).sum())
    ss_tot = float(((y - ym) ** 2).sum())
    if ss_tot == 0.0:
        return 1.0
    return max(0.0, 1.0 - ss_res / ss_tot)


SaturationReason = Literal["found", "flat_curve", "non_monotonic", "below_floor"]


def _saturation_lambda(
    lambdas: np.ndarray, divs: np.ndarray
) -> tuple[float | None, SaturationReason]:
    """Smallest λ at which divergence reaches 90% of ``max(divs)``.

    Returns ``(value, reason)``:

    - ``(λ, "found")`` — saturation reached at the returned λ on a
      monotonically non-decreasing curve up to that point.
    - ``(λ, "non_monotonic")`` — saturation point identified, but the
      curve dipped or zigzagged on the way there; worth flagging,
      though the run loop currently accepts it.
    - ``(None, "flat_curve")`` — empty input, or every divergence
      value ≤ 0: the adapter produced no measurable signal (often a
      NaN or zero adapter).
    - ``(None, "below_floor")`` — defensive; shouldn't trigger with the
      max-based target but kept for future-proofing.

    The B3 fix searches the **full** λ range (not just λ ≤ 1.0) and
    uses ``max(divs)`` as the reference, so an overshoot at λ=1.25
    that dips at λ=1.0 still produces a meaningful saturation read.
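
    Doctest (hypothetical sweep; divergence saturates at λ=1.0):

    >>> import numpy as np
    >>> lam = np.array([0.0, 0.5, 1.0, 1.25])
    >>> dv = np.array([0.0, 0.30, 0.50, 0.55])
    >>> _saturation_lambda(lam, dv)
    (1.0, 'found')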
    """
    if lambdas.size == 0 or divs.size == 0:
        return None, "flat_curve"

    max_div = float(divs.max())
    if not math.isfinite(max_div) or max_div <= 0.0:
        return None, "flat_curve"

    target = 0.9 * max_div

    # Search the full curve, not just λ ≤ 1.0.
    saturating_idx = np.where(divs >= target)[0]
    if saturating_idx.size == 0:
        return None, "below_floor"

    smallest_idx = int(saturating_idx.min())
    sat_lambda = float(lambdas[smallest_idx])

    # Monotonicity advisory — divs should be non-decreasing up through
    # the saturation point. A dip is acceptable but signals shape noise.
    monotonic = bool(np.all(np.diff(divs[: smallest_idx + 1]) >= -1e-9))
    if not monotonic:
        return sat_lambda, "non_monotonic"

    return sat_lambda, "found"


def _overshoot(lambdas: np.ndarray, divs: np.ndarray) -> float:
    """``div(λ_max) / div(λ=1)``.

    Returns 1.0 when the sweep has no λ=1 point, when λ_max is 1.0
    itself (i.e. λ_max ≤ 1.0), or when ``div(λ=1)`` is not positive
    (division guard).
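
    Doctest (illustrative values, chosen to be exact in binary floats):

    >>> import numpy as np
    >>> _overshoot(np.array([0.0, 1.0, 1.25]), np.array([0.0, 0.5, 0.625]))
    1.25
    """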
    idx_max = int(np.argmax(lambdas))
    candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0]
    if candidates.size == 0:
        return 1.0
    idx1 = int(candidates[0])
    if idx_max == idx1:
        return 1.0
    d1 = float(divs[idx1])
    dmax = float(divs[idx_max])
    if d1 <= 0:
        return 1.0
    return dmax / d1