| 1 | """Top-level ``sway.yaml`` spec models. |
| 2 | |
| 3 | Per-probe specs live next to their implementations in |
| 4 | :mod:`dlm_sway.probes`. This module owns the *outer* envelope — |
| 5 | ``version``, ``models``, ``defaults``, ``suite`` — plus the runtime |
| 6 | bind between raw probe dicts and registered probe classes. |
| 7 | """ |
| 8 | |
from __future__ import annotations

import math
from typing import Annotated, Any

from pydantic import BaseModel, ConfigDict, Field, field_validator

from dlm_sway.core.model import ModelSpec
from dlm_sway.core.result import DEFAULT_COMPONENT_WEIGHTS
| 17 | |
# The only ``sway.yaml`` spec version this build accepts;
# ``SwaySpec.check_version`` compares the parsed ``version`` against it.
SUPPORTED_VERSION = 1
| 19 | |
| 20 | |
class SuiteModels(BaseModel):
    """The pair of model handles a suite refers to by name.

    Both ``base`` and ``ft`` are required. Unknown keys are rejected
    (``extra="forbid"``) and validated instances are frozen.
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    # Spec for the model referenced as ``base``.
    base: ModelSpec
    # Spec for the model referenced as ``ft`` (presumably the fine-tuned
    # counterpart of ``base`` — confirm against the runner).
    ft: ModelSpec
| 28 | |
| 29 | |
class SuiteDefaults(BaseModel):
    """Shared defaults for the whole suite. Probes may override per-entry."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    seed: int = 0
    top_k: int = 256
    differential: bool = True
    """If ``False``, the runner loads base + ft as two separate models
    instead of toggling on one. More memory-heavy; only useful when a
    backend can't do in-place toggling."""
    coverage_threshold: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
    """Minimum composite score for ``sway gate`` to pass."""
    concurrent_probes: Annotated[int, Field(ge=1, le=32)] = 1
    """Maximum number of independent probes to dispatch concurrently.

    **Default is 1 (sequential).** Values > 1 are respected only when
    the backend declares
    :attr:`~dlm_sway.core.scoring.DifferentialBackend.safe_for_concurrent_views`
    as ``True`` — which no shipped backend does in v0.1 (see
    ``.docs/design/backend-concurrency.md`` / B19). The flag exists so
    custom backends that *are* already concurrency-safe (e.g. a
    stateless hosted-API backend) can opt in without waiting for the
    HF backend fix; shipped backends treat it as a no-op."""
    score_weights: dict[str, float] | None = None
    """Per-category weight overrides for the composite score. ``None``
    uses :data:`dlm_sway.core.result.DEFAULT_COMPONENT_WEIGHTS`. Keys
    must be a subset of the known categories (``adherence``,
    ``attribution``, ``calibration``, ``ablation``, ``baseline``);
    unknown keys are rejected. Missing keys inherit the default weight
    so a user who only wants to re-weight one category doesn't have to
    respecify all of them. All values must be finite and non-negative,
    and at least one must be positive."""

    @field_validator("score_weights")
    @classmethod
    def _validate_weights(cls, v: dict[str, float] | None) -> dict[str, float] | None:
        """Validate ``score_weights`` and fill in omitted categories.

        Args:
            v: The user-supplied overrides, or ``None`` when the field
                was omitted.

        Returns:
            ``None`` unchanged, otherwise a new dict holding every key of
            ``DEFAULT_COMPONENT_WEIGHTS`` with the overrides applied.

        Raises:
            ValueError: If a key is not a known category, a weight is
                negative, a weight is NaN or infinite, or the merged
                weights do not sum to a positive number.
        """
        if v is None:
            return v
        known = set(DEFAULT_COMPONENT_WEIGHTS)
        unknown = sorted(set(v) - known)
        if unknown:
            raise ValueError(
                f"score_weights contains unknown category keys: {unknown}. "
                f"Known categories: {sorted(known)}"
            )
        if any(w < 0.0 for w in v.values()):
            raise ValueError("score_weights values must be non-negative")
        # NaN compares False against everything, so it slips past both the
        # ``< 0.0`` check above and the ``<= 0.0`` sum check below; +inf
        # passes them too. Either would poison the composite score, so
        # reject them explicitly.
        if not all(math.isfinite(w) for w in v.values()):
            raise ValueError("score_weights values must be finite (no NaN or infinity)")
        # Merge with defaults so partial overrides are ergonomic.
        merged = {**DEFAULT_COMPONENT_WEIGHTS, **v}
        if sum(merged.values()) <= 0.0:
            raise ValueError("score_weights must have at least one positive weight")
        return merged
| 84 | |
| 85 | |
class SwaySpec(BaseModel):
    """Top-level document model for a ``sway.yaml`` file."""

    model_config = ConfigDict(frozen=True, extra="forbid")

    version: int = 1
    models: SuiteModels
    defaults: SuiteDefaults = SuiteDefaults()
    suite: list[dict[str, Any]] = Field(default_factory=list)
    """Probe entries kept as raw dicts. The probe registry validates each
    one individually through :func:`dlm_sway.probes.base.build_probe`,
    which keeps the set of probe kinds an open registry instead of a
    closed discriminated union."""
    dlm_source: str | None = None
    """Path to a ``.dlm`` file, if any. When set, the runner obtains typed
    sections from :mod:`dlm_sway.integrations.dlm.resolver` and passes
    them to probes through :attr:`RunContext.sections`. ``sway autogen``
    fills this in automatically."""

    def check_version(self) -> None:
        """Reject spec versions this build does not support.

        The loader calls this explicitly once validation has succeeded,
        so the failure is reported with a loader-source tag instead of a
        pydantic stack.

        Raises:
            ValueError: If ``version`` differs from ``SUPPORTED_VERSION``.
        """
        if self.version == SUPPORTED_VERSION:
            return
        raise ValueError(
            f"unsupported sway spec version: {self.version} (this build supports {SUPPORTED_VERSION})"
        )