| 1 | """N2 AdapterAblation — the sway signature primitive. |
| 2 | |
| 3 | Scales the LoRA additive term by λ ∈ {0, 0.25, 0.5, 0.75, 1.0, 1.25} |
| 4 | and measures the mean divergence from the base distribution at each |
| 5 | step. Fits a monotonic response curve; reports three shape metrics: |
| 6 | |
| 7 | - **linearity**: R² of a linear fit on ``(λ, mean_div)``. High means |
| 8 | the adapter's effect scales predictably; low means it's "all or |
| 9 | nothing" (degenerate). |
| 10 | - **saturation_lambda**: the smallest λ at which divergence reaches |
| 11 | 90% of the λ=1 value. Too low (<0.3) means the adapter fires at |
| 12 | partial strength — fragile. Too high (>1.0) means the adapter is |
| 13 | under-trained. |
| 14 | - **overshoot**: divergence at λ=1.25 divided by λ=1.0. >1.05 is the |
| 15 | healthy "pushing past 1 still moves the model" signal. An overshoot |
| 16 | below 1.0 suggests collapse. |
| 17 | |
| 18 | This is the single novel primitive that no generic eval harness |
| 19 | provides — sway's position next to the adapter math makes it possible. |
| 20 | |
| 21 | Requires the backend to implement |
| 22 | :class:`~dlm_sway.core.scoring.ScalableDifferentialBackend`. Probes |
| 23 | SKIP gracefully on backends that don't. |
| 24 | |
| 25 | **On the missing ``ci_95`` column.** S14's bootstrap CI column is |
| 26 | populated for every *aggregating* probe — ones whose ``raw`` is a |
| 27 | sample mean (or similar) over N observations that admit resampling. |
| 28 | ``adapter_ablation`` is a curve-fit: the raw metric is an R² on the |
| 29 | ``(λ, divergence)`` sweep, not a per-prompt aggregate. Resampling |
| 30 | "residuals" would surface confidence on the *fit* rather than on the |
| 31 | underlying observations, which confuses the signal the probe |
| 32 | reports. The column renders as ``—`` by design; see F14 in the |
| 33 | Audit 02 closure for the rationale. |
| 34 | """ |
| 35 | |
| 36 | from __future__ import annotations |
| 37 | |
| 38 | import math |
| 39 | from typing import Literal |
| 40 | |
| 41 | import numpy as np |
| 42 | from pydantic import Field |
| 43 | |
| 44 | from dlm_sway.core.result import ProbeResult, Verdict, safe_finalize |
| 45 | from dlm_sway.core.scoring import ScalableDifferentialBackend |
| 46 | from dlm_sway.probes._divergence import Divergence, divergence |
| 47 | from dlm_sway.probes._zscore import ( |
| 48 | no_calibration_note, |
| 49 | score_from_z, |
| 50 | verdict_from_z, |
| 51 | z_score, |
| 52 | z_scores_by_rank, |
| 53 | ) |
| 54 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 55 | from dlm_sway.probes.null_adapter import get_null_stats, get_null_stats_by_rank |
| 56 | |
| 57 | |
| 58 | class AdapterAblationSpec(ProbeSpec): |
| 59 | kind: Literal["adapter_ablation"] = "adapter_ablation" |
| 60 | prompts: list[str] = Field(default_factory=list) |
| 61 | lambdas: list[float] = Field( |
| 62 | default_factory=lambda: [0.0, 0.25, 0.5, 0.75, 1.0, 1.25], |
| 63 | min_length=3, |
| 64 | ) |
| 65 | divergence: Divergence = "js" |
| 66 | top_k: int | None = None |
| 67 | assert_linearity_gte: float = 0.85 |
| 68 | assert_saturation_between: tuple[float, float] = (0.3, 1.05) |
| 69 | assert_overshoot_gte: float = 1.02 |
| 70 | assert_z_gte: float = 3.0 |
| 71 | """Z-score pass criterion against the null-adapter baseline, when it |
| 72 | exists. Note: this probe usually opts out of calibration (the null |
| 73 | proxy doesn't expose ``as_scaled_adapter``); the z-score path is |
| 74 | retained only for shape consistency with the rest of the suite.""" |
| 75 | |
| 76 | |
| 77 | class AdapterAblationProbe(Probe): |
| 78 | kind = "adapter_ablation" |
| 79 | spec_cls = AdapterAblationSpec |
| 80 | category = "ablation" |
| 81 | |
| 82 | def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult: |
| 83 | assert isinstance(spec, AdapterAblationSpec) |
| 84 | if not spec.prompts: |
| 85 | return ProbeResult( |
| 86 | name=spec.name, |
| 87 | kind=spec.kind, |
| 88 | verdict=Verdict.ERROR, |
| 89 | score=None, |
| 90 | message="no prompts provided", |
| 91 | ) |
| 92 | # Local binding so mypy keeps the ScalableDifferentialBackend |
| 93 | # narrowing across the loop below (require_backend's return |
| 94 | # type is the base DifferentialBackend; we narrow once here). |
| 95 | scalable = ctx.backend |
| 96 | if not isinstance(scalable, ScalableDifferentialBackend): |
| 97 | return ProbeResult( |
| 98 | name=spec.name, |
| 99 | kind=spec.kind, |
| 100 | verdict=Verdict.SKIP, |
| 101 | score=None, |
| 102 | message=( |
| 103 | "backend does not implement ScalableDifferentialBackend — " |
| 104 | "adapter ablation requires LoRA-scale access" |
| 105 | ), |
| 106 | ) |
| 107 | |
| 108 | top_k = spec.top_k if spec.top_k is not None else ctx.top_k |
| 109 | |
| 110 | # Reference distribution at λ=0 (adapter scaled to zero → base). |
| 111 | lam_zero = min(spec.lambdas) |
| 112 | per_lambda: list[float] = [] |
| 113 | for lam in spec.lambdas: |
| 114 | divs_for_lam: list[float] = [] |
| 115 | for prompt in spec.prompts: |
| 116 | with scalable.as_scaled_adapter(lam_zero) as ref: |
| 117 | ref_dist = ref.next_token_dist(prompt, top_k=top_k) |
| 118 | with scalable.as_scaled_adapter(lam) as scaled: |
| 119 | scaled_dist = scaled.next_token_dist(prompt, top_k=top_k) |
| 120 | divs_for_lam.append(divergence(ref_dist, scaled_dist, kind=spec.divergence)) |
| 121 | per_lambda.append(float(np.mean(divs_for_lam))) |
| 122 | |
| 123 | lambdas_arr = np.asarray(spec.lambdas, dtype=np.float64) |
| 124 | divs_arr = np.asarray(per_lambda, dtype=np.float64) |
| 125 | |
| 126 | linearity = _r_squared(lambdas_arr, divs_arr) |
| 127 | saturation_lambda, sat_reason = _saturation_lambda(lambdas_arr, divs_arr) |
| 128 | overshoot = _overshoot(lambdas_arr, divs_arr) |
| 129 | |
| 130 | # Pass when all three shape metrics land in their healthy bands. |
| 131 | sat_lo, sat_hi = spec.assert_saturation_between |
| 132 | ok_lin = linearity >= spec.assert_linearity_gte |
| 133 | ok_sat = ( |
| 134 | saturation_lambda is not None |
| 135 | and sat_lo <= saturation_lambda <= sat_hi |
| 136 | and sat_reason in ("found", "non_monotonic") |
| 137 | ) |
| 138 | ok_over = overshoot >= spec.assert_overshoot_gte |
| 139 | |
| 140 | stats = get_null_stats(ctx, spec.kind) |
| 141 | z = z_score(linearity, stats) |
| 142 | z_by_rank = z_scores_by_rank(linearity, get_null_stats_by_rank(ctx, spec.kind), sign=+1) |
| 143 | verdict_z = verdict_from_z(z, spec.assert_z_gte) |
| 144 | if verdict_z is not None: |
| 145 | verdict = verdict_z |
| 146 | score_val = score_from_z(z) |
| 147 | score = score_val if score_val is not None else 0.0 |
| 148 | else: |
| 149 | verdict = Verdict.PASS if (ok_lin and ok_sat and ok_over) else Verdict.FAIL |
| 150 | lin_score = max(0.0, min(1.0, linearity / max(spec.assert_linearity_gte, 1e-6))) |
| 151 | over_score = max(0.0, min(1.0, (overshoot - 1.0) / 0.2)) |
| 152 | sat_score = 1.0 if ok_sat else 0.3 |
| 153 | score = 0.4 * lin_score + 0.3 * sat_score + 0.3 * over_score |
| 154 | |
| 155 | sat_msg = ( |
| 156 | f"sat_λ={saturation_lambda:.2f} ({'in' if ok_sat else 'out of'} band)" |
| 157 | if saturation_lambda is not None |
| 158 | else f"saturation undetected ({sat_reason})" |
| 159 | ) |
| 160 | base_msg = f"R²={linearity:.2f}, {sat_msg}, overshoot={overshoot:.2f}" |
| 161 | if z is not None: |
| 162 | message = f"{base_msg}, z={z:+.2f}σ vs null" |
| 163 | else: |
| 164 | message = f"{base_msg} {no_calibration_note(spec.kind)}" |
| 165 | |
| 166 | return safe_finalize( |
| 167 | name=spec.name, |
| 168 | kind=spec.kind, |
| 169 | verdict=verdict, |
| 170 | score=score, |
| 171 | raw=linearity, |
| 172 | z_score=z, |
| 173 | evidence={ |
| 174 | "lambdas": spec.lambdas, |
| 175 | "mean_divergence_per_lambda": per_lambda, |
| 176 | "linearity": linearity, |
| 177 | "saturation_lambda": saturation_lambda, |
| 178 | "saturation_reason": sat_reason, |
| 179 | "overshoot": overshoot, |
| 180 | "passed_linearity": ok_lin, |
| 181 | "passed_saturation": ok_sat, |
| 182 | "passed_overshoot": ok_over, |
| 183 | "weight": spec.weight, |
| 184 | "z_by_rank": z_by_rank, |
| 185 | }, |
| 186 | message=message, |
| 187 | ) |
| 188 | |
| 189 | |
| 190 | def _r_squared(x: np.ndarray, y: np.ndarray) -> float: |
| 191 | """Coefficient of determination for a linear fit of ``y`` on ``x``.""" |
| 192 | if x.size < 2: |
| 193 | return 0.0 |
| 194 | xm = float(x.mean()) |
| 195 | ym = float(y.mean()) |
| 196 | denom = float(((x - xm) ** 2).sum()) |
| 197 | if denom == 0.0: |
| 198 | return 0.0 |
| 199 | slope = float(((x - xm) * (y - ym)).sum()) / denom |
| 200 | intercept = ym - slope * xm |
| 201 | y_pred = slope * x + intercept |
| 202 | ss_res = float(((y - y_pred) ** 2).sum()) |
| 203 | ss_tot = float(((y - ym) ** 2).sum()) |
| 204 | if ss_tot == 0.0: |
| 205 | return 1.0 |
| 206 | return max(0.0, 1.0 - ss_res / ss_tot) |
| 207 | |
| 208 | |
| 209 | SaturationReason = Literal["found", "flat_curve", "non_monotonic", "below_floor"] |
| 210 | |
| 211 | |
| 212 | def _saturation_lambda( |
| 213 | lambdas: np.ndarray, divs: np.ndarray |
| 214 | ) -> tuple[float | None, SaturationReason]: |
| 215 | """Smallest λ at which divergence reaches 90% of ``max(divs)``. |
| 216 | |
| 217 | Returns ``(value, reason)``: |
| 218 | |
| 219 | - ``("found", λ)`` — saturation reached at the returned λ on a |
| 220 | monotonically-non-decreasing curve up to that point. |
| 221 | - ``("non_monotonic", λ)`` — saturation point identified but the |
| 222 | curve dipped or zigzagged on the way; probe should emit a WARN. |
| 223 | - ``("flat_curve", None)`` — every divergence value ≤ 0; adapter |
| 224 | produced no measurable signal (often: NaN / zero adapter). |
| 225 | - ``("below_floor", None)`` — defensive; shouldn't trigger with the |
| 226 | max-based target but kept for future-proofing. |
| 227 | |
| 228 | The B3 fix searches the **full** λ range (not just λ ≤ 1.0) and |
| 229 | uses ``max(divs)`` as the reference, so an overshoot at λ=1.25 |
| 230 | that dips at λ=1.0 still produces a meaningful saturation read. |
| 231 | """ |
| 232 | if lambdas.size == 0 or divs.size == 0: |
| 233 | return None, "flat_curve" |
| 234 | |
| 235 | max_div = float(divs.max()) |
| 236 | if not math.isfinite(max_div) or max_div <= 0.0: |
| 237 | return None, "flat_curve" |
| 238 | |
| 239 | target = 0.9 * max_div |
| 240 | |
| 241 | # Search the full curve, not just ≤ 1.0. |
| 242 | saturating_idx = np.where(divs >= target)[0] |
| 243 | if saturating_idx.size == 0: |
| 244 | return None, "below_floor" |
| 245 | |
| 246 | smallest_idx = int(saturating_idx.min()) |
| 247 | sat_lambda = float(lambdas[smallest_idx]) |
| 248 | |
| 249 | # Monotonicity advisory — divs should be non-decreasing up through |
| 250 | # the saturation point. A dip is acceptable but signals shape noise. |
| 251 | monotonic = bool(np.all(np.diff(divs[: smallest_idx + 1]) >= -1e-9)) |
| 252 | if not monotonic: |
| 253 | return sat_lambda, "non_monotonic" |
| 254 | |
| 255 | return sat_lambda, "found" |
| 256 | |
| 257 | |
| 258 | def _overshoot(lambdas: np.ndarray, divs: np.ndarray) -> float: |
| 259 | """``div(λ_max) / div(λ=1)``. Returns 1.0 if λ_max ≤ 1.0.""" |
| 260 | idx_max = int(np.argmax(lambdas)) |
| 261 | candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0] |
| 262 | if candidates.size == 0: |
| 263 | return 1.0 |
| 264 | idx1 = int(candidates[0]) |
| 265 | if idx_max == idx1: |
| 266 | return 1.0 |
| 267 | d1 = float(divs[idx1]) |
| 268 | dmax = float(divs[idx_max]) |
| 269 | if d1 <= 0: |
| 270 | return 1.0 |
| 271 | return dmax / d1 |