| 1 | """Probe abstract base + per-kind registry. |
| 2 | |
| 3 | The registry is the extension point. Adding a new probe means: |
| 4 | |
| 5 | 1. Subclass :class:`ProbeSpec` with a unique ``kind`` field (Literal). |
| 6 | 2. Subclass :class:`Probe` setting ``kind`` and ``spec_cls``. |
3. Import the probe module at least once: ``Probe.__init_subclass__``
   registers each subclass that defines its own ``kind``.
| 9 | |
| 10 | The runner uses :func:`build_probe` to map each raw spec dict to a |
| 11 | ``(Probe, ProbeSpec)`` pair. Validation errors are turned into |
| 12 | :class:`~dlm_sway.core.errors.SpecValidationError` with the probe name |
| 13 | as the source so error messages localize to the offending entry. |
| 14 | """ |
| 15 | |
| 16 | from __future__ import annotations |
| 17 | |
| 18 | from abc import ABC, abstractmethod |
| 19 | from collections.abc import Mapping |
| 20 | from dataclasses import dataclass, field |
| 21 | from typing import Any, ClassVar |
| 22 | |
| 23 | from pydantic import BaseModel, ConfigDict, ValidationError |
| 24 | |
| 25 | from dlm_sway.core.errors import SpecValidationError |
| 26 | from dlm_sway.core.result import ProbeResult |
| 27 | from dlm_sway.core.scoring import DifferentialBackend |
| 28 | from dlm_sway.core.sections import Section |
| 29 | |
| 30 | |
class ProbeSpec(BaseModel):
    """Fields shared by every probe entry in ``sway.yaml``.

    Each concrete probe subclasses this model and narrows ``kind``. The
    model configuration rejects unknown keys (``extra="forbid"``) and makes
    validated specs immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    # Unique within a suite; shown in the report.
    name: str
    # Discriminator; must equal the ``kind`` of a registered Probe subclass.
    kind: str
    # When False, the runner records a SKIP verdict instead of running the probe.
    enabled: bool = True
    # Weight of this probe inside its scoring component (adherence / attribution / ...).
    weight: float = 1.0
| 44 | |
| 45 | |
| 46 | @dataclass(frozen=True, slots=True) |
| 47 | class RunContext: |
| 48 | """What a probe can read beyond its own spec. |
| 49 | |
| 50 | Probes should receive exactly what they need and nothing more; fat |
| 51 | contexts encourage coupling between unrelated probes. |
| 52 | |
| 53 | Attributes |
| 54 | ---------- |
| 55 | backend: |
| 56 | The differential backend holding base + fine-tuned views. |
| 57 | seed: |
| 58 | Seed for deterministic probe RNGs (paraphrase sampling, etc). |
| 59 | top_k: |
| 60 | Default truncation for next-token distributions. |
| 61 | sections: |
| 62 | Optional list of typed sections (populated by the .dlm bridge; |
| 63 | ``None`` when sway is invoked against bare HF+PEFT). |
| 64 | doc_text: |
| 65 | Raw document text, if available. |
| 66 | null_stats: |
| 67 | Null-adapter baseline stats for z-score calibration, keyed by |
| 68 | probe *kind*. Populated by the runner after it's executed the |
| 69 | ``null_adapter`` probe (if configured). Typed as a read-only |
| 70 | :class:`~collections.abc.Mapping` and constructed from a |
| 71 | ``MappingProxyType`` so probes can't accidentally mutate the |
| 72 | stats other probes will consume. |
| 73 | |
| 74 | When ``null_adapter`` was run with multiple ``rank_multipliers``, |
| 75 | this field carries the 1.0x group (or the first multiplier when |
| 76 | 1.0 isn't present) for back-compat with probes that consume a |
| 77 | single calibration level. For the full rank profile, consult |
| 78 | :attr:`null_stats_by_rank`. |
| 79 | null_stats_by_rank: |
| 80 | Null-adapter stats across every rank multiplier the |
| 81 | ``null_adapter`` probe calibrated. Outer key is the canonical |
| 82 | rank label (``"rank_1.00"``, ``"rank_0.50"``, …); inner key is |
| 83 | the probe kind. Empty when a single-rank calibration ran — |
| 84 | probes that render a rank profile should check length > 1 |
| 85 | before surfacing it. |
| 86 | downstream_kinds: |
| 87 | Tuple of probe kinds that appear *after* the current probe in |
| 88 | the suite. Populated by the runner before each probe runs; |
| 89 | ``NullAdapterProbe`` consults it to decide which probe kinds |
| 90 | to calibrate per-kind null stats for. |
| 91 | """ |
| 92 | |
| 93 | backend: DifferentialBackend | None = None |
| 94 | """The model-scoring backend. Required for every probe with |
| 95 | ``needs_backend=True`` (the default). Pre-run probes |
| 96 | (``needs_backend=False``, e.g. S25 ``gradient_ghost``) tolerate |
| 97 | ``None`` here so the runner can skip backend construction |
| 98 | entirely when only pre-flight probes are scheduled. |
| 99 | |
| 100 | Existing probes access ``self.require_backend`` instead of |
| 101 | ``backend`` directly — the property narrows the type for mypy |
| 102 | and gives a clear runtime error if the runner ever passes |
| 103 | ``None`` to a probe that needs the backend. |
| 104 | """ |
| 105 | seed: int = 0 |
| 106 | top_k: int = 256 |
| 107 | sections: tuple[Section, ...] | None = None |
| 108 | doc_text: str | None = None |
| 109 | null_stats: Mapping[str, Mapping[str, float]] = field(default_factory=dict) |
| 110 | null_stats_by_rank: Mapping[str, Mapping[str, Mapping[str, float]]] = field( |
| 111 | default_factory=dict |
| 112 | ) |
| 113 | downstream_kinds: tuple[str, ...] = field(default_factory=tuple) |
| 114 | |
| 115 | @property |
| 116 | def require_backend(self) -> DifferentialBackend: |
| 117 | """Return :attr:`backend`, asserting non-None. |
| 118 | |
| 119 | Probes with ``needs_backend=True`` (default) call this to |
| 120 | narrow the type from ``DifferentialBackend | None`` to |
| 121 | ``DifferentialBackend``. The runner contract guarantees |
| 122 | non-None when scheduling backend-dependent probes; this |
| 123 | accessor turns a runner bug into a clear error rather than |
| 124 | a confusing AttributeError on ``None.as_base()``. |
| 125 | """ |
| 126 | if self.backend is None: |
| 127 | raise RuntimeError( |
| 128 | "RunContext.backend is None — probe requires a backend " |
| 129 | "(needs_backend=True) but the runner did not provide one. " |
| 130 | "If this is a pre-run probe, set needs_backend=False on " |
| 131 | "the Probe subclass." |
| 132 | ) |
| 133 | return self.backend |
| 134 | |
| 135 | |
#: Registry of concrete probes, keyed by their ``kind`` string. Entries are
#: added by ``Probe.__init_subclass__`` and read by ``registry()`` and
#: ``build_probe()``.
_REGISTRY: dict[str, type[Probe]] = dict()


#: Short, content-neutral, model-agnostic prompts used as sentinel inputs by
#: per-probe ``calibrate_spec()`` overrides. Any sane base model should
#: assign each of them a finite logprob.
SENTINEL_PROMPTS: tuple[str, ...] = tuple(
    [
        "The capital of",
        "Once upon a time",
        "An interesting fact is",
        "The next step is to",
    ]
)


#: Fixed reference text for stylistic-fingerprint calibration: two
#: paragraphs of mid-length, mid-complexity prose, so that every
#: fingerprint dimension is exercised without degenerating.
SENTINEL_DOC: str = (
    "This is a brief reference paragraph. It contains several short "
    "sentences. The vocabulary is plain and the punctuation density "
    "is moderate. A second paragraph follows the first.\n\n"
    "Each clause is concise. Together they sample the fingerprint "
    "dimensions a probe needs to compute a meaningful style shift."
)
| 160 | |
| 161 | |
class Probe(ABC):
    """Base class for every probe; the runner makes one instance per spec entry.

    A subclass that declares its own ``kind`` is added to the module
    registry as soon as its ``class`` statement executes (see
    ``__init_subclass__``). A subclass that merely inherits ``kind``, such as
    an abstract intermediate base, is not registered.

    Class attributes
    ----------------
    kind:
        The value of the ``kind`` field in ``sway.yaml`` that selects this
        probe.
    spec_cls:
        The pydantic model class that validates this probe's spec entry.
    category:
        One of ``adherence``, ``attribution``, ``calibration``,
        ``ablation``, ``baseline``; drives composite scoring. Defaults to
        ``"adherence"``.
    needs_backend:
        Whether the probe ever reads ``ctx.backend``. Defaults to ``True``:
        every probe shipped through Sprint 24 reads next-token
        distributions or logprobs from a backend view, so the runner builds
        the backend before any probe runs. Set it to ``False`` for
        **pre-run diagnostic probes** that only consume on-disk artifacts
        (e.g. S25's ``gradient_ghost``, which loads
        ``training_state.pt``). When *every* scheduled probe has
        ``needs_backend=False``, the runner skips backend construction
        entirely, with no model load, no GPU memory and no cold-start
        latency. That lets pre-flight verdicts short-circuit a slow suite
        when the adapter is obviously broken.
    batch_score:
        Whether ``run()`` scores its prompts through the backend's batched
        :meth:`~dlm_sway.core.scoring.ScoringBackend.next_token_dist_batch`
        path (S23). With ``True``, the probe issues a single batched call
        per view, typically inside a ``with backend.as_base()`` block,
        instead of looping prompt by prompt. Backends with real batching
        (HF) amortize kernel-launch overhead; backends without it (dummy,
        MLX for now) behave identically via the Protocol's default loop.
        This is opt-in and only appropriate when every prompt goes through
        a uniform ``next_token_dist`` call (delta_kl, cluster_kl). Probes
        that parameterize per-prompt context (prompt_collapse) or toggle
        the backend per prompt (adapter_ablation) keep ``False``; they need
        bespoke batching deferred to a follow-up sprint.
    """

    kind: ClassVar[str]
    spec_cls: ClassVar[type[ProbeSpec]]
    category: ClassVar[str] = "adherence"
    needs_backend: ClassVar[bool] = True
    batch_score: ClassVar[bool] = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        if "kind" not in vars(cls):
            # This exact class declares no ``kind`` (abstract intermediates,
            # or subclasses inheriting a parent's kind), so there is nothing
            # to register.
            return
        probe_kind = cls.kind
        previous = _REGISTRY.get(probe_kind)
        if previous is not None:
            raise ValueError(f"duplicate probe kind {probe_kind!r}: {previous!r} vs {cls!r}")
        _REGISTRY[probe_kind] = cls

    @abstractmethod
    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Execute the probe described by ``spec`` against ``ctx``."""

    @classmethod
    def calibrate_spec(cls, ctx: RunContext) -> ProbeSpec | None:
        """Build a small spec used for null-adapter calibration of this kind.

        ``NullAdapterProbe`` calls this once for each kind listed in
        ``ctx.downstream_kinds`` to collect that kind's null distribution.
        A ``None`` return opts the kind out of calibration. Downstream
        probes of that kind then fall back to their fixed-threshold paths
        and report ``(no calibration)``.

        This default opts out. Calibratable probes override it to return a
        cheap spec (typically 4 sentinel prompts) that is safe to run
        against
        :class:`~dlm_sway.probes._null_proxy.NullCalibrationBackendProxy`.
        """
        del ctx  # unused by the default implementation
        return None
| 237 | |
| 238 | |
def registry() -> dict[str, type[Probe]]:
    """Return a snapshot of the probe registry as a new ``dict``.

    The result is a shallow copy taken at call time. Mutating it does not
    add or remove registered probes, and probes registered later are not
    reflected in it.
    """
    snapshot: dict[str, type[Probe]] = {**_REGISTRY}
    return snapshot
| 242 | |
| 243 | |
def _entry_source(raw: object) -> str:
    """Label a probe entry for error messages: its ``name``, else ``"<unknown>"``."""
    if isinstance(raw, Mapping):
        name = raw.get("name")
        if name is not None:
            return str(name)
    return "<unknown>"


def build_probe(raw: dict[str, Any]) -> tuple[Probe, ProbeSpec]:
    """Validate one raw YAML probe entry and instantiate its probe.

    Parameters
    ----------
    raw:
        A single entry from the suite's probe list. It is expected to be a
        mapping containing ``kind`` plus the fields required by that kind's
        spec model.

    Returns
    -------
    tuple[Probe, ProbeSpec]
        A fresh instance of the registered probe class and the validated
        spec.

    Raises
    ------
    SpecValidationError
        If ``raw`` is not a mapping, ``kind`` is missing or not a string,
        ``kind`` is not registered, or the spec model rejects the entry.
        ``source`` is set to the entry's ``name``, or ``"<unknown>"`` when
        there is none.
    """
    source = _entry_source(raw)
    # A YAML entry such as ``- delta_kl`` parses to a bare string rather
    # than a mapping. Report it as a spec error instead of crashing on
    # ``.get`` with an AttributeError.
    if not isinstance(raw, Mapping):
        raise SpecValidationError(
            f"probe entry must be a mapping, got {type(raw).__name__}",
            source=source,
        )
    kind = raw.get("kind")
    if not isinstance(kind, str):
        raise SpecValidationError("probe entry missing string 'kind' field", source=source)
    probe_cls = _REGISTRY.get(kind)
    if probe_cls is None:
        known = ", ".join(sorted(_REGISTRY))
        raise SpecValidationError(
            f"unknown probe kind {kind!r} (registered: {known})",
            source=source,
        )
    try:
        spec = probe_cls.spec_cls.model_validate(raw)
    except ValidationError as exc:
        raise SpecValidationError(str(exc), source=source) from exc
    return probe_cls(), spec
| 264 | |
| 265 | |
def validate_all_probes(suite: list[dict[str, Any]]) -> None:
    """Run :func:`build_probe` against every entry and collect all errors.

    The runner / CLI calls this before constructing the backend, which is
    the slow, network-touching step. A typo in any ``kind:`` field therefore
    surfaces before any model is loaded. *Every* bad entry is reported in a
    single error message, so the user does not have to fix one, re-run, fix
    the next and re-run again (B7).

    Parameters
    ----------
    suite:
        The raw probe entries from the spec. Entries are expected to be
        mappings; any other value is reported as an invalid entry rather
        than aborting the pass.

    Raises
    ------
    SpecValidationError
        If at least one entry is invalid. The message lists one bullet per
        bad entry, labelled by its ``name`` or by ``entry #<index>``.
    """
    errors: list[str] = []
    for idx, raw in enumerate(suite):
        if not isinstance(raw, Mapping):
            # A bare scalar or list has no ``.get``. Report it here so one
            # malformed entry cannot crash the whole collection pass with an
            # AttributeError.
            errors.append(
                f" - entry #{idx}: probe entry must be a mapping, got {type(raw).__name__}"
            )
            continue
        try:
            build_probe(raw)
        except SpecValidationError as exc:
            label = raw.get("name") or f"entry #{idx}"
            errors.append(f" - {label}: {exc}")
    if errors:
        raise SpecValidationError("spec contains invalid probe entries:\n" + "\n".join(errors))