Python · 4635 bytes Raw Blame History
1 """Top-level ``sway.yaml`` spec models.
2
Per-probe specs live next to their implementations in
:mod:`dlm_sway.probes`. This module owns the *outer* envelope —
``version``, ``models``, ``defaults``, ``suite``, ``dlm_source`` — and
keeps raw probe entries as plain dicts; binding them to registered probe
classes happens at runtime in :func:`dlm_sway.probes.base.build_probe`.
7 """
8
from __future__ import annotations

import math
from typing import Annotated, Any

from pydantic import BaseModel, ConfigDict, Field, field_validator

from dlm_sway.core.model import ModelSpec
from dlm_sway.core.result import DEFAULT_COMPONENT_WEIGHTS
17
# The single spec ``version`` this build accepts; ``SwaySpec.check_version``
# rejects any other value.
SUPPORTED_VERSION: int = 1
19
20
class SuiteModels(BaseModel):
    """The pair of model handles a suite refers to by name.

    Exactly two keys are accepted, ``base`` and ``ft`` (presumably the
    fine-tuned variant), each described by a
    :class:`~dlm_sway.core.model.ModelSpec`.
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    # Both handles are required; any other key is rejected via extra="forbid".
    base: ModelSpec
    ft: ModelSpec
28
29
class SuiteDefaults(BaseModel):
    """Shared defaults for the whole suite. Probes may override per-entry."""

    model_config = ConfigDict(extra="forbid", frozen=True)

    seed: int = 0
    top_k: int = 256
    differential: bool = True
    """If ``False``, the runner loads base + ft as two separate models
    instead of toggling on one. More memory-heavy; only useful when a
    backend can't do in-place toggling."""
    coverage_threshold: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
    """Minimum composite score for ``sway gate`` to pass."""
    concurrent_probes: Annotated[int, Field(ge=1, le=32)] = 1
    """Maximum number of independent probes to dispatch concurrently.

    **Default is 1 (sequential).** Values > 1 are respected only when
    the backend declares
    :attr:`~dlm_sway.core.scoring.DifferentialBackend.safe_for_concurrent_views`
    as ``True`` — which no shipped backend does in v0.1 (see
    ``.docs/design/backend-concurrency.md`` / B19). The flag exists so
    custom backends that *are* already concurrency-safe (e.g. a
    stateless hosted-API backend) can opt in without waiting for the
    HF backend fix; shipped backends treat it as a no-op."""
    score_weights: dict[str, float] | None = None
    """Per-category weight overrides for the composite score. ``None``
    uses :data:`dlm_sway.core.result.DEFAULT_COMPONENT_WEIGHTS`. Keys
    must be a subset of the known categories (``adherence``,
    ``attribution``, ``calibration``, ``ablation``, ``baseline``);
    unknown keys are rejected. Missing keys inherit the default weight
    so a user who only wants to re-weight one category doesn't have to
    respecify all of them. All values must be finite (no NaN/inf) and
    non-negative, and at least one weight of the merged result must be
    positive."""

    @field_validator("score_weights")
    @classmethod
    def _validate_weights(cls, v: dict[str, float] | None) -> dict[str, float] | None:
        """Validate user weight overrides and merge them over the defaults.

        Args:
            v: The raw ``score_weights`` mapping, or ``None``.

        Returns:
            ``None`` when no overrides were given; otherwise a full mapping
            of every known category to its weight, with ``v`` layered on
            top of ``DEFAULT_COMPONENT_WEIGHTS``.

        Raises:
            ValueError: If ``v`` has keys outside ``DEFAULT_COMPONENT_WEIGHTS``,
                contains a non-finite or negative weight, or the merged
                weights sum to zero.
        """
        if v is None:
            return v
        known = set(DEFAULT_COMPONENT_WEIGHTS)
        unknown = sorted(set(v) - known)
        if unknown:
            raise ValueError(
                f"score_weights contains unknown category keys: {unknown}. "
                f"Known categories: {sorted(known)}"
            )
        # NaN compares False against everything, so neither the negativity
        # check nor the positive-sum check below would catch it; +inf would
        # also pass. Either would poison the composite score.
        non_finite = sorted(k for k, w in v.items() if not math.isfinite(w))
        if non_finite:
            raise ValueError(
                f"score_weights values must be finite; got non-finite weights for {non_finite}"
            )
        if any(w < 0.0 for w in v.values()):
            raise ValueError("score_weights values must be non-negative")
        # Merge with defaults so partial overrides are ergonomic.
        merged = dict(DEFAULT_COMPONENT_WEIGHTS)
        merged.update(v)
        if sum(merged.values()) <= 0.0:
            raise ValueError("score_weights must have at least one positive weight")
        return merged
84
85
class SwaySpec(BaseModel):
    """Validated root document of a ``sway.yaml`` file.

    Only the outer envelope is modelled here. Probe entries under
    ``suite`` remain plain dicts, and the ``version`` number is checked
    separately by :meth:`check_version` rather than during pydantic
    validation.
    """

    model_config = ConfigDict(frozen=True, extra="forbid")

    version: int = 1
    models: SuiteModels
    # Omitting the ``defaults`` section yields a SuiteDefaults with every
    # field at its own declared default.
    defaults: SuiteDefaults = SuiteDefaults()
    suite: list[dict[str, Any]] = Field(default_factory=list)
    """Unvalidated probe entries. Each one is validated individually by
    :func:`dlm_sway.probes.base.build_probe`, which keeps the set of probe
    kinds an open registry instead of a closed discriminated union."""
    dlm_source: str | None = None
    """Optional path to a ``.dlm`` file. If set, the runner asks
    :mod:`dlm_sway.integrations.dlm.resolver` for its typed sections and
    exposes them to probes through :attr:`RunContext.sections`. ``sway
    autogen`` fills this in automatically."""

    def check_version(self) -> None:
        """Reject specs whose ``version`` this build does not support.

        Kept out of pydantic validation on purpose: the loader calls it
        after validation so the failure can be reported with the loader's
        source tag instead of as a pydantic stack.

        Raises:
            ValueError: If ``version`` differs from ``SUPPORTED_VERSION``.
        """
        found = self.version
        if found == SUPPORTED_VERSION:
            return
        raise ValueError(
            f"unsupported sway spec version: {found} (this build supports {SUPPORTED_VERSION})"
        )
114 )