@@ -2,56 +2,73 @@ |
| 2 | | 2 | |
| 3 | Every numeric primitive reports its raw metric *and* a z-score against a | 3 | Every numeric primitive reports its raw metric *and* a z-score against a |
| 4 | null-adapter distribution. This probe is the runtime engine that | 4 | null-adapter distribution. This probe is the runtime engine that |
| 5 | -establishes that distribution — running each configured primitive | 5 | +establishes that distribution — it builds random-init "null" adapters |
| 6 | -against a series of random-init-style "null" adapters (structurally | 6 | +(structurally identical to the real adapter but with weights drawn from |
| 7 | -identical to the real adapter but with weights indistinguishable from | 7 | +a Gaussian) and measures how much signal they produce. |
| 8 | -noise) and caching the resulting ``(mean, std, n)`` per primitive kind. | 8 | + |
| 9 | - | 9 | +The resulting ``(mean, std, n)`` per kind is attached to this probe's |
| 10 | -The heavy lifting — materializing random-init LoRAs on the loaded model | 10 | +``evidence["null_stats"]``. The runner picks it up and threads it into |
| 11 | -and running probes with them — lives in the HF backend (later | 11 | +:attr:`RunContext.null_stats`, where every downstream probe can read it |
| 12 | -milestone). For now this module ships the spec + the lookup API that | 12 | +and turn a raw metric into a z-score. |
| 13 | -probes will use to z-score their results once stats are populated. | 13 | + |
| | 14 | +Backends that don't implement :class:`~dlm_sway.core.scoring.NullCalibratedBackend` |
| | 15 | +cause this probe to :attr:`Verdict.SKIP` — downstream probes fall back |
| | 16 | +to their fixed thresholds in that case. |
| 14 | """ | 17 | """ |
| 15 | | 18 | |
| 16 | from __future__ import annotations | 19 | from __future__ import annotations |
| 17 | | 20 | |
| | 21 | +import statistics |
| 18 | from typing import Literal | 22 | from typing import Literal |
| 19 | | 23 | |
| 20 | from pydantic import Field | 24 | from pydantic import Field |
| 21 | | 25 | |
| 22 | from dlm_sway.core.result import ProbeResult, Verdict | 26 | from dlm_sway.core.result import ProbeResult, Verdict |
| | 27 | +from dlm_sway.core.scoring import NullCalibratedBackend |
| | 28 | +from dlm_sway.probes._divergence import divergence |
| 23 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext | 29 | from dlm_sway.probes.base import Probe, ProbeSpec, RunContext |
| 24 | | 30 | |
| 25 | | 31 | |
| 26 | class NullAdapterSpec(ProbeSpec): | 32 | class NullAdapterSpec(ProbeSpec): |
| 27 | """Spec for ``kind: null_adapter``. | 33 | """Spec for ``kind: null_adapter``. |
| 28 | | 34 | |
| 29 | - This is a meta-probe: it doesn't test the adapter, it calibrates | 35 | + Authors place this probe **first** in the suite so its output |
| 30 | - *other* probes. Place it first in the suite so its output is in | 36 | + populates :attr:`RunContext.null_stats` before subsequent probes |
| 31 | - :attr:`~dlm_sway.probes.base.RunContext.null_stats` when later | 37 | + consult it. |
| 32 | - probes run. | | |
| 33 | """ | 38 | """ |
| 34 | | 39 | |
| 35 | kind: Literal["null_adapter"] = "null_adapter" | 40 | kind: Literal["null_adapter"] = "null_adapter" |
| 36 | runs: int = Field(default=3, ge=1, le=10) | 41 | runs: int = Field(default=3, ge=1, le=10) |
| 37 | """Number of independent null adapters to evaluate. Three is the | 42 | """Number of independent null adapters to evaluate. Three is the |
| 38 | - smallest that gives a usable std estimate; more is better but quickly | 43 | + smallest that yields a usable std; more is better but quickly |
| 39 | dominates suite runtime.""" | 44 | dominates suite runtime.""" |
| 40 | - rank: int | None = None | 45 | + prompts: list[str] = Field(default_factory=list) |
| 41 | - """LoRA rank for the null adapter. ``None`` → match the real adapter.""" | 46 | + """Prompt set for null calibration. Keep small — calibration runs |
| 42 | - alpha: int | None = None | 47 | + ``runs × len(prompts)`` forward passes. 4–8 prompts is typical. |
| 43 | - """LoRA alpha. ``None`` → match the real adapter.""" | 48 | + If empty, a minimal built-in prompt set is used so the probe |
| | 49 | + always produces stats.""" |
| 44 | init_scale: float = 0.02 | 50 | init_scale: float = 0.02 |
| 45 | - """Standard deviation of the zero-mean Gaussian used to init | 51 | + """Stddev of the zero-mean Gaussian used to fill lora_A/lora_B.""" |
| 46 | - lora_A/lora_B. Matches typical post-init scale.""" | 52 | + seed_base: int = 1000 |
| | 53 | + """First seed; successive runs use ``seed_base + run_idx``.""" |
| | 54 | + |
| | 55 | + |
| | 56 | +_DEFAULT_PROMPTS: tuple[str, ...] = ( |
| | 57 | + "The quick brown fox", |
| | 58 | + "Once upon a time", |
| | 59 | + "In this document we explain", |
| | 60 | + "The key takeaway is", |
| | 61 | + "An important point to remember", |
| | 62 | +) |
| 47 | | 63 | |
| 48 | | 64 | |
| 49 | class NullAdapterProbe(Probe): | 65 | class NullAdapterProbe(Probe): |
| 50 | - """Populate ``ctx.null_stats``; report a :attr:`Verdict.SKIP` verdict itself. | 66 | + """Populate ``ctx.null_stats``; report a :attr:`Verdict.PASS` verdict itself. |
| 51 | | 67 | |
| 52 | - The probe never fails on its own terms — its *job* is calibration, | 68 | + The probe never fails on its own terms — its *job* is calibration. |
| 53 | - not judgment. Downstream probes consult | 69 | + Downstream probes pick up :attr:`RunContext.null_stats` keyed by |
| 54 | - :meth:`get_null_stats` to turn their raw metric into a z-score. | 70 | + probe kind (``delta_kl``, ``adapter_ablation`` …) and use the |
| | 71 | + populated mean/std to z-score their own raw metrics. |
| 55 | """ | 72 | """ |
| 56 | | 73 | |
| 57 | kind = "null_adapter" | 74 | kind = "null_adapter" |
@@ -59,22 +76,61 @@ class NullAdapterProbe(Probe): |
| 59 | category = "baseline" | 76 | category = "baseline" |
| 60 | | 77 | |
| 61 | def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult: | 78 | def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult: |
| 62 | - # Concrete null-adapter materialization is backend-specific. For | | |
| 63 | - # the HF backend it will build random-init LoRAs with matched | | |
| 64 | - # rank/alpha. That path is wired in a later milestone; this probe | | |
| 65 | - # currently reports SKIP so suite composition stays stable. | | |
| 66 | - del ctx # unused until HF-level materialization lands | | |
| 67 | assert isinstance(spec, NullAdapterSpec) | 79 | assert isinstance(spec, NullAdapterSpec) |
| | 80 | + if not isinstance(ctx.backend, NullCalibratedBackend): |
| | 81 | + return ProbeResult( |
| | 82 | + name=spec.name, |
| | 83 | + kind=spec.kind, |
| | 84 | + verdict=Verdict.SKIP, |
| | 85 | + score=None, |
| | 86 | + message=( |
| | 87 | + "backend does not implement NullCalibratedBackend — " |
| | 88 | + "numeric probes will fall back to fixed thresholds" |
| | 89 | + ), |
| | 90 | + ) |
| | 91 | + prompts = list(spec.prompts) or list(_DEFAULT_PROMPTS) |
| | 92 | + |
| | 93 | + per_seed_means: list[float] = [] |
| | 94 | + for run_idx in range(spec.runs): |
| | 95 | + seed = spec.seed_base + run_idx |
| | 96 | + per_prompt: list[float] = [] |
| | 97 | + for prompt in prompts: |
| | 98 | + with ctx.backend.as_base() as base_view: |
| | 99 | + base_dist = base_view.next_token_dist(prompt, top_k=ctx.top_k) |
| | 100 | + with ctx.backend.as_null_adapter(seed, init_scale=spec.init_scale) as null_view: |
| | 101 | + null_dist = null_view.next_token_dist(prompt, top_k=ctx.top_k) |
| | 102 | + per_prompt.append(divergence(base_dist, null_dist, kind="js")) |
| | 103 | + per_seed_means.append(statistics.fmean(per_prompt) if per_prompt else 0.0) |
| | 104 | + |
| | 105 | + mean = statistics.fmean(per_seed_means) |
| | 106 | + std = statistics.pstdev(per_seed_means) if len(per_seed_means) > 1 else 0.0 |
| | 107 | + |
| | 108 | + # Publish per-kind stats. delta_kl is the primary kind; other |
| | 109 | + # divergence-based probes (adapter_ablation) share this scale. |
| | 110 | + null_stats = { |
| | 111 | + "delta_kl": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)}, |
| | 112 | + "adapter_ablation": {"mean": mean, "std": max(std, 1e-6), "n": float(spec.runs)}, |
| | 113 | + } |
| | 114 | + |
| 68 | return ProbeResult( | 115 | return ProbeResult( |
| 69 | name=spec.name, | 116 | name=spec.name, |
| 70 | kind=spec.kind, | 117 | kind=spec.kind, |
| 71 | - verdict=Verdict.SKIP, | 118 | + verdict=Verdict.PASS, |
| 72 | - score=None, | 119 | + score=1.0, |
| | 120 | + raw=mean, |
| | 121 | + evidence={ |
| | 122 | + "null_stats": null_stats, |
| | 123 | + "per_seed_mean_js": per_seed_means, |
| | 124 | + "init_scale": spec.init_scale, |
| | 125 | + "runs": spec.runs, |
| | 126 | + "num_prompts": len(prompts), |
| | 127 | + "weight": spec.weight, |
| | 128 | + }, |
| 73 | message=( | 129 | message=( |
| 74 | - "null-adapter calibration pending — downstream probes will fall back to " | 130 | + f"null JS divergence μ={mean:.4f} ± {std:.4f} " |
| 75 | - "fixed thresholds until the backend-level materialization lands" | 131 | + f"(over {spec.runs} seeds × {len(prompts)} prompts) — " |
| | 132 | + f"downstream probes will z-score against this baseline" |
| 76 | ), | 133 | ), |
| 77 | - evidence={"runs": spec.runs, "rank": spec.rank, "alpha": spec.alpha}, | | |
| 78 | ) | 134 | ) |
| 79 | | 135 | |
| 80 | | 136 | |
@@ -82,7 +138,7 @@ def get_null_stats(ctx: RunContext, probe_kind: str) -> dict[str, float] | None: |
| 82 | """Look up null-adapter stats for ``probe_kind``. | 138 | """Look up null-adapter stats for ``probe_kind``. |
| 83 | | 139 | |
| 84 | Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for | 140 | Returns ``{"mean": …, "std": …, "n": …}`` when calibration ran for |
| 85 | - this kind, else ``None``. Probes should treat ``None`` as "fall back | 141 | + this kind, else ``None``. Probes treat ``None`` as "fall back to the |
| 86 | - to the fixed threshold from your spec." | 142 | + fixed threshold from your spec." |
| 87 | """ | 143 | """ |
| 88 | return ctx.null_stats.get(probe_kind) | 144 | return ctx.null_stats.get(probe_kind) |