@@ -0,0 +1,193 @@ |
| 1 | +"""N2 AdapterAblation — the sway signature primitive. |
| 2 | + |
| 3 | +Scales the LoRA additive term by λ ∈ {0, 0.25, 0.5, 0.75, 1.0, 1.25} |
| 4 | +and measures the mean divergence from the base distribution at each |
| 5 | +step. Fits a monotonic response curve; reports three shape metrics: |
| 6 | + |
| 7 | +- **linearity**: R² of a linear fit on ``(λ, mean_div)``. High means |
| 8 | + the adapter's effect scales predictably; low means it's "all or |
| 9 | + nothing" (degenerate). |
| 10 | +- **saturation_lambda**: the smallest λ at which divergence reaches |
| 11 | + 90% of the λ=1 value. Too low (<0.3) means the adapter fires at |
| 12 | + partial strength — fragile. Too high (>1.0) means the adapter is |
| 13 | + under-trained. |
| 14 | +- **overshoot**: divergence at λ=1.25 divided by λ=1.0. >1.05 is the |
| 15 | + healthy "pushing past 1 still moves the model" signal. An overshoot |
| 16 | + below 1.0 suggests collapse. |
| 17 | + |
| 18 | +This is the single novel primitive that no generic eval harness |
| 19 | +provides — sway's position next to the adapter math makes it possible. |
| 20 | + |
| 21 | +Requires the backend to implement |
| 22 | +:class:`~dlm_sway.core.scoring.ScalableDifferentialBackend`. Probes |
| 23 | +SKIP gracefully on backends that don't. |
| 24 | +""" |
| 25 | + |
| 26 | +from __future__ import annotations |
| 27 | + |
| 28 | +from typing import Literal |
| 29 | + |
| 30 | +import numpy as np |
| 31 | +from pydantic import Field |
| 32 | + |
| 33 | +from dlm_sway.core.result import ProbeResult, Verdict |
| 34 | +from dlm_sway.core.scoring import ScalableDifferentialBackend |
| 35 | +from dlm_sway.probes._divergence import Divergence, divergence |
| 36 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 37 | + |
| 38 | + |
class AdapterAblationSpec(ProbeSpec):
    """Configuration for the λ-sweep adapter-ablation probe."""

    kind: Literal["adapter_ablation"] = "adapter_ablation"
    # Prompts to evaluate; the probe returns ERROR if this is left empty.
    prompts: list[str] = Field(default_factory=list)
    # LoRA scale factors to sweep. At least 3 points are required so the
    # response-curve metrics (linearity/saturation/overshoot) are meaningful.
    lambdas: list[float] = Field(
        default_factory=lambda: [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
        min_length=3,
    )
    # Divergence measure between the λ=min reference and each scaled run.
    divergence: Divergence = "js"
    # Top-k truncation for next-token distributions; None → fall back to ctx.top_k.
    top_k: int | None = None
    # Healthy-band thresholds for the three shape metrics (see module docstring).
    assert_linearity_gte: float = 0.85
    assert_saturation_between: tuple[float, float] = (0.3, 1.05)
    assert_overshoot_gte: float = 1.02
| 51 | + |
| 52 | + |
class AdapterAblationProbe(Probe):
    """λ-sweep ablation: scale the LoRA term, fit the divergence response curve."""

    kind = "adapter_ablation"
    spec_cls = AdapterAblationSpec
    category = "ablation"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Sweep λ over ``spec.lambdas``, measure mean divergence from the
        λ=min reference per λ, and grade the three shape metrics
        (linearity, saturation_lambda, overshoot) against the spec's bands.

        Returns ERROR when no prompts are configured and SKIP when the
        backend cannot scale the adapter.
        """
        assert isinstance(spec, AdapterAblationSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        if not isinstance(ctx.backend, ScalableDifferentialBackend):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    "backend does not implement ScalableDifferentialBackend — "
                    "adapter ablation requires LoRA-scale access"
                ),
            )

        top_k = spec.top_k if spec.top_k is not None else ctx.top_k

        # Reference distributions at λ=min (adapter scaled to zero → base).
        # They are invariant across the λ sweep, so compute them once per
        # prompt instead of once per (λ, prompt) pair as before — saves
        # (len(lambdas) - 1) * len(prompts) backend calls.
        # NOTE(review): assumes next_token_dist is deterministic for a given
        # (scale, prompt); the original per-λ recomputation relied on the
        # same property for its λ=0 divergence to be zero.
        lam_zero = min(spec.lambdas)
        with ctx.backend.as_scaled_adapter(lam_zero) as ref:
            ref_dists = [ref.next_token_dist(p, top_k=top_k) for p in spec.prompts]

        # Mean divergence from the reference at each λ, in spec order.
        per_lambda: list[float] = []
        for lam in spec.lambdas:
            with ctx.backend.as_scaled_adapter(lam) as scaled:
                divs_for_lam = [
                    divergence(
                        ref_dist,
                        scaled.next_token_dist(prompt, top_k=top_k),
                        kind=spec.divergence,
                    )
                    for prompt, ref_dist in zip(spec.prompts, ref_dists, strict=True)
                ]
            per_lambda.append(float(np.mean(divs_for_lam)))

        lambdas_arr = np.asarray(spec.lambdas, dtype=np.float64)
        divs_arr = np.asarray(per_lambda, dtype=np.float64)

        linearity = _r_squared(lambdas_arr, divs_arr)
        saturation_lambda = _saturation_lambda(lambdas_arr, divs_arr)
        overshoot = _overshoot(lambdas_arr, divs_arr)

        # Pass when all three shape metrics land in their healthy bands.
        sat_lo, sat_hi = spec.assert_saturation_between
        ok_lin = linearity >= spec.assert_linearity_gte
        ok_sat = saturation_lambda is not None and sat_lo <= saturation_lambda <= sat_hi
        ok_over = overshoot >= spec.assert_overshoot_gte
        verdict = Verdict.PASS if (ok_lin and ok_sat and ok_over) else Verdict.FAIL

        # Soft score: linearity and overshoot contribute continuously;
        # saturation gives fixed credit for being in/out of band.
        lin_score = max(0.0, min(1.0, linearity / max(spec.assert_linearity_gte, 1e-6)))
        over_score = max(0.0, min(1.0, (overshoot - 1.0) / 0.2))
        sat_score = 1.0 if ok_sat else 0.3
        score = 0.4 * lin_score + 0.3 * sat_score + 0.3 * over_score

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=linearity,
            evidence={
                "lambdas": spec.lambdas,
                "mean_divergence_per_lambda": per_lambda,
                "linearity": linearity,
                "saturation_lambda": saturation_lambda,
                "overshoot": overshoot,
                "passed_linearity": ok_lin,
                "passed_saturation": ok_sat,
                "passed_overshoot": ok_over,
                "weight": spec.weight,
            },
            message=(
                f"R²={linearity:.2f}, sat_λ={saturation_lambda:.2f} "
                f"({'in' if ok_sat else 'out of'} band), overshoot={overshoot:.2f}"
                if saturation_lambda is not None
                else f"R²={linearity:.2f}, saturation undetected, overshoot={overshoot:.2f}"
            ),
        )
| 137 | + ) |
| 138 | + |
| 139 | + |
| 140 | +def _r_squared(x: np.ndarray, y: np.ndarray) -> float: |
| 141 | + """Coefficient of determination for a linear fit of ``y`` on ``x``.""" |
| 142 | + if x.size < 2: |
| 143 | + return 0.0 |
| 144 | + xm = float(x.mean()) |
| 145 | + ym = float(y.mean()) |
| 146 | + denom = float(((x - xm) ** 2).sum()) |
| 147 | + if denom == 0.0: |
| 148 | + return 0.0 |
| 149 | + slope = float(((x - xm) * (y - ym)).sum()) / denom |
| 150 | + intercept = ym - slope * xm |
| 151 | + y_pred = slope * x + intercept |
| 152 | + ss_res = float(((y - y_pred) ** 2).sum()) |
| 153 | + ss_tot = float(((y - ym) ** 2).sum()) |
| 154 | + if ss_tot == 0.0: |
| 155 | + return 1.0 |
| 156 | + return max(0.0, 1.0 - ss_res / ss_tot) |
| 157 | + |
| 158 | + |
| 159 | +def _saturation_lambda(lambdas: np.ndarray, divs: np.ndarray) -> float | None: |
| 160 | + """Smallest λ ≤ 1.0 at which divergence reaches 90% of div(λ=1).""" |
| 161 | + # Locate the index of λ=1.0 (or the closest entry ≤ 1.0). |
| 162 | + candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0] |
| 163 | + if candidates.size == 0: |
| 164 | + # Fall back to the largest λ ≤ 1.0. |
| 165 | + mask = lambdas <= 1.0 |
| 166 | + if not mask.any(): |
| 167 | + return None |
| 168 | + idx1 = int(np.argmax(lambdas * mask)) |
| 169 | + else: |
| 170 | + idx1 = int(candidates[0]) |
| 171 | + target = 0.9 * float(divs[idx1]) |
| 172 | + if target <= 0: |
| 173 | + return None |
| 174 | + for lam, d in zip(lambdas[: idx1 + 1], divs[: idx1 + 1], strict=False): |
| 175 | + if d >= target: |
| 176 | + return float(lam) |
| 177 | + return None |
| 178 | + |
| 179 | + |
| 180 | +def _overshoot(lambdas: np.ndarray, divs: np.ndarray) -> float: |
| 181 | + """``div(λ_max) / div(λ=1)``. Returns 1.0 if λ_max ≤ 1.0.""" |
| 182 | + idx_max = int(np.argmax(lambdas)) |
| 183 | + candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0] |
| 184 | + if candidates.size == 0: |
| 185 | + return 1.0 |
| 186 | + idx1 = int(candidates[0]) |
| 187 | + if idx_max == idx1: |
| 188 | + return 1.0 |
| 189 | + d1 = float(divs[idx1]) |
| 190 | + dmax = float(divs[idx_max]) |
| 191 | + if d1 <= 0: |
| 192 | + return 1.0 |
| 193 | + return dmax / d1 |