@@ -0,0 +1,193 @@ |
| | 1 | +"""N2 AdapterAblation — the sway signature primitive. |
| | 2 | + |
| | 3 | +Scales the LoRA additive term by λ ∈ {0, 0.25, 0.5, 0.75, 1.0, 1.25} |
| | 4 | +and measures the mean divergence from the base distribution at each |
| | 5 | +step. Fits a monotonic response curve; reports three shape metrics: |
| | 6 | + |
| | 7 | +- **linearity**: R² of a linear fit on ``(λ, mean_div)``. High means |
| | 8 | + the adapter's effect scales predictably; low means it's "all or |
| | 9 | + nothing" (degenerate). |
| | 10 | +- **saturation_lambda**: the smallest λ at which divergence reaches |
| | 11 | + 90% of the λ=1 value. Too low (<0.3) means the adapter fires at |
| | 12 | + partial strength — fragile. Too high (>1.0) means the adapter is |
| | 13 | + under-trained. |
- **overshoot**: divergence at λ=1.25 divided by λ=1.0. A value above 1
  (the default gate is 1.02) is the healthy "pushing past 1 still moves
  the model" signal. An overshoot below 1.0 suggests collapse.
| | 17 | + |
| | 18 | +This is the single novel primitive that no generic eval harness |
| | 19 | +provides — sway's position next to the adapter math makes it possible. |
| | 20 | + |
| | 21 | +Requires the backend to implement |
| | 22 | +:class:`~dlm_sway.core.scoring.ScalableDifferentialBackend`. Probes |
| | 23 | +SKIP gracefully on backends that don't. |
| | 24 | +""" |
| | 25 | + |
| | 26 | +from __future__ import annotations |
| | 27 | + |
| | 28 | +from typing import Literal |
| | 29 | + |
| | 30 | +import numpy as np |
| | 31 | +from pydantic import Field |
| | 32 | + |
| | 33 | +from dlm_sway.core.result import ProbeResult, Verdict |
| | 34 | +from dlm_sway.core.scoring import ScalableDifferentialBackend |
| | 35 | +from dlm_sway.probes._divergence import Divergence, divergence |
| | 36 | +from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| | 37 | + |
| | 38 | + |
class AdapterAblationSpec(ProbeSpec):
    """Configuration for the adapter-ablation probe.

    Declares the λ sweep grid, the divergence metric, and the three
    healthy-band thresholds (linearity, saturation, overshoot) that the
    probe's verdict is computed from.
    """

    kind: Literal["adapter_ablation"] = "adapter_ablation"
    # Prompts to sweep; the probe returns ERROR when this is empty.
    prompts: list[str] = Field(default_factory=list)
    # LoRA scale factors λ to evaluate; at least 3 points are needed to
    # fit a meaningful response curve.
    lambdas: list[float] = Field(
        default_factory=lambda: [0.0, 0.25, 0.5, 0.75, 1.0, 1.25],
        min_length=3,
    )
    # Divergence kind passed through to probes._divergence.divergence.
    divergence: Divergence = "js"
    # Per-probe top-k override; falls back to ctx.top_k when None.
    top_k: int | None = None
    # Minimum R² of the linear fit on (λ, mean divergence) to pass.
    assert_linearity_gte: float = 0.85
    # (low, high) band the 90%-saturation λ must land in to pass.
    assert_saturation_between: tuple[float, float] = (0.3, 1.05)
    # Minimum div(λ_max)/div(λ=1) ratio to pass.
    assert_overshoot_gte: float = 1.02
| | 51 | + |
| | 52 | + |
class AdapterAblationProbe(Probe):
    """Sweeps the LoRA scale λ and grades the divergence response curve.

    For each λ in the spec's grid, measures the mean divergence of the
    scaled model's next-token distribution from the λ=0 (base) reference
    across all prompts, then derives three shape metrics — linearity
    (R²), saturation λ, and overshoot — and passes only when all three
    land in their configured healthy bands.
    """

    kind = "adapter_ablation"
    spec_cls = AdapterAblationSpec
    category = "ablation"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Run the ablation sweep.

        Returns ERROR when no prompts are configured, SKIP when the
        backend cannot scale the adapter term, otherwise PASS/FAIL with
        the full curve recorded in ``evidence``.
        """
        assert isinstance(spec, AdapterAblationSpec)
        if not spec.prompts:
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.ERROR,
                score=None,
                message="no prompts provided",
            )
        if not isinstance(ctx.backend, ScalableDifferentialBackend):
            return ProbeResult(
                name=spec.name,
                kind=spec.kind,
                verdict=Verdict.SKIP,
                score=None,
                message=(
                    "backend does not implement ScalableDifferentialBackend — "
                    "adapter ablation requires LoRA-scale access"
                ),
            )

        top_k = spec.top_k if spec.top_k is not None else ctx.top_k

        # Reference distributions at the smallest λ (nominally 0 → base
        # model). These do not depend on the sweep λ, so compute them
        # once per prompt instead of once per (λ, prompt) pair as before
        # — with the default 6-point grid that is 6× fewer base-model
        # forward passes. NOTE(review): assumes next_token_dist is
        # deterministic for a fixed scale, which the original's
        # repeated recomputation also implicitly required.
        lam_zero = min(spec.lambdas)
        with ctx.backend.as_scaled_adapter(lam_zero) as ref:
            ref_dists = [ref.next_token_dist(p, top_k=top_k) for p in spec.prompts]

        # Mean divergence from the reference at each λ, prompts averaged.
        per_lambda: list[float] = []
        for lam in spec.lambdas:
            with ctx.backend.as_scaled_adapter(lam) as scaled:
                divs_for_lam = [
                    divergence(
                        ref_dist,
                        scaled.next_token_dist(prompt, top_k=top_k),
                        kind=spec.divergence,
                    )
                    for prompt, ref_dist in zip(spec.prompts, ref_dists)
                ]
            per_lambda.append(float(np.mean(divs_for_lam)))

        lambdas_arr = np.asarray(spec.lambdas, dtype=np.float64)
        divs_arr = np.asarray(per_lambda, dtype=np.float64)

        linearity = _r_squared(lambdas_arr, divs_arr)
        saturation_lambda = _saturation_lambda(lambdas_arr, divs_arr)
        overshoot = _overshoot(lambdas_arr, divs_arr)

        # Pass when all three shape metrics land in their healthy bands.
        sat_lo, sat_hi = spec.assert_saturation_between
        ok_lin = linearity >= spec.assert_linearity_gte
        ok_sat = saturation_lambda is not None and sat_lo <= saturation_lambda <= sat_hi
        ok_over = overshoot >= spec.assert_overshoot_gte
        verdict = Verdict.PASS if (ok_lin and ok_sat and ok_over) else Verdict.FAIL

        # Weighted soft score: linearity normalized against its gate,
        # overshoot mapped so 1.0→0 and 1.2→1, saturation pass/fail.
        lin_score = max(0.0, min(1.0, linearity / max(spec.assert_linearity_gte, 1e-6)))
        over_score = max(0.0, min(1.0, (overshoot - 1.0) / 0.2))
        sat_score = 1.0 if ok_sat else 0.3
        score = 0.4 * lin_score + 0.3 * sat_score + 0.3 * over_score

        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=verdict,
            score=score,
            raw=linearity,
            evidence={
                "lambdas": spec.lambdas,
                "mean_divergence_per_lambda": per_lambda,
                "linearity": linearity,
                "saturation_lambda": saturation_lambda,
                "overshoot": overshoot,
                "passed_linearity": ok_lin,
                "passed_saturation": ok_sat,
                "passed_overshoot": ok_over,
                "weight": spec.weight,
            },
            message=(
                f"R²={linearity:.2f}, sat_λ={saturation_lambda:.2f} "
                f"({'in' if ok_sat else 'out of'} band), overshoot={overshoot:.2f}"
                if saturation_lambda is not None
                else f"R²={linearity:.2f}, saturation undetected, overshoot={overshoot:.2f}"
            ),
        )
| | 138 | + |
| | 139 | + |
| | 140 | +def _r_squared(x: np.ndarray, y: np.ndarray) -> float: |
| | 141 | + """Coefficient of determination for a linear fit of ``y`` on ``x``.""" |
| | 142 | + if x.size < 2: |
| | 143 | + return 0.0 |
| | 144 | + xm = float(x.mean()) |
| | 145 | + ym = float(y.mean()) |
| | 146 | + denom = float(((x - xm) ** 2).sum()) |
| | 147 | + if denom == 0.0: |
| | 148 | + return 0.0 |
| | 149 | + slope = float(((x - xm) * (y - ym)).sum()) / denom |
| | 150 | + intercept = ym - slope * xm |
| | 151 | + y_pred = slope * x + intercept |
| | 152 | + ss_res = float(((y - y_pred) ** 2).sum()) |
| | 153 | + ss_tot = float(((y - ym) ** 2).sum()) |
| | 154 | + if ss_tot == 0.0: |
| | 155 | + return 1.0 |
| | 156 | + return max(0.0, 1.0 - ss_res / ss_tot) |
| | 157 | + |
| | 158 | + |
| | 159 | +def _saturation_lambda(lambdas: np.ndarray, divs: np.ndarray) -> float | None: |
| | 160 | + """Smallest λ ≤ 1.0 at which divergence reaches 90% of div(λ=1).""" |
| | 161 | + # Locate the index of λ=1.0 (or the closest entry ≤ 1.0). |
| | 162 | + candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0] |
| | 163 | + if candidates.size == 0: |
| | 164 | + # Fall back to the largest λ ≤ 1.0. |
| | 165 | + mask = lambdas <= 1.0 |
| | 166 | + if not mask.any(): |
| | 167 | + return None |
| | 168 | + idx1 = int(np.argmax(lambdas * mask)) |
| | 169 | + else: |
| | 170 | + idx1 = int(candidates[0]) |
| | 171 | + target = 0.9 * float(divs[idx1]) |
| | 172 | + if target <= 0: |
| | 173 | + return None |
| | 174 | + for lam, d in zip(lambdas[: idx1 + 1], divs[: idx1 + 1], strict=False): |
| | 175 | + if d >= target: |
| | 176 | + return float(lam) |
| | 177 | + return None |
| | 178 | + |
| | 179 | + |
| | 180 | +def _overshoot(lambdas: np.ndarray, divs: np.ndarray) -> float: |
| | 181 | + """``div(λ_max) / div(λ=1)``. Returns 1.0 if λ_max ≤ 1.0.""" |
| | 182 | + idx_max = int(np.argmax(lambdas)) |
| | 183 | + candidates = np.where(np.isclose(lambdas, 1.0, atol=1e-6))[0] |
| | 184 | + if candidates.size == 0: |
| | 185 | + return 1.0 |
| | 186 | + idx1 = int(candidates[0]) |
| | 187 | + if idx_max == idx1: |
| | 188 | + return 1.0 |
| | 189 | + d1 = float(divs[idx1]) |
| | 190 | + dmax = float(divs[idx_max]) |
| | 191 | + if d1 <= 0: |
| | 192 | + return 1.0 |
| | 193 | + return dmax / d1 |