Python · 7183 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.suite.runner`.
2
3 Uses the dummy backend + ad-hoc probe classes so nothing real is loaded.
4 """
5
6 from __future__ import annotations
7
8 from typing import Literal
9
10 import numpy as np
11 import pytest
12
13 from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses
14 from dlm_sway.core.errors import ProbeError
15 from dlm_sway.core.result import ProbeResult, Verdict
16 from dlm_sway.probes.base import Probe, ProbeSpec, RunContext
17 from dlm_sway.suite.runner import run
18 from dlm_sway.suite.spec import SwaySpec
19
20
class _PassSpec(ProbeSpec):
    """Spec for the stub probe that always reports PASS."""

    kind: Literal["__runner_pass"] = "__runner_pass"


class _PassProbe(Probe):
    """Stub probe that unconditionally returns a PASS result scored 0.9.

    It ignores the run context entirely, so runner tests can check ordering
    and result plumbing without loading anything real.
    """

    category = "adherence"
    spec_cls = _PassSpec
    kind = "__runner_pass"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Return a fixed PASS result named after *spec*; *ctx* is unused."""
        outcome = Verdict.PASS
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=outcome,
            score=0.9,
        )
32
33
class _FailSpec(ProbeSpec):
    """Spec for the stub probe that always reports FAIL."""

    kind: Literal["__runner_fail"] = "__runner_fail"


class _FailProbe(Probe):
    """Stub probe that unconditionally returns a FAIL result scored 0.1.

    It ignores the run context entirely, so runner tests can check ordering
    and result plumbing without loading anything real.
    """

    category = "attribution"
    spec_cls = _FailSpec
    kind = "__runner_fail"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Return a fixed FAIL result named after *spec*; *ctx* is unused."""
        outcome = Verdict.FAIL
        return ProbeResult(
            name=spec.name,
            kind=spec.kind,
            verdict=outcome,
            score=0.1,
        )
45
46
class _RaiseSpec(ProbeSpec):
    """Spec for the stub probe whose ``run`` raises :class:`ProbeError`."""

    kind: Literal["__runner_raise"] = "__runner_raise"


class _RaiseProbe(Probe):
    """Stub probe that always fails with the library's own :class:`ProbeError`.

    It drives the runner test that expects an anticipated probe failure to be
    recorded as an ERROR verdict rather than propagated.
    """

    spec_cls = _RaiseSpec
    kind = "__runner_raise"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Raise ``ProbeError(spec.kind, "kaboom")`` unconditionally."""
        failure = ProbeError(spec.kind, "kaboom")
        raise failure
57
58
class _UnexpectedSpec(ProbeSpec):
    """Spec for the stub probe whose ``run`` raises a non-library exception."""

    kind: Literal["__runner_unexpected"] = "__runner_unexpected"


class _UnexpectedProbe(Probe):
    """Stub probe that always raises a plain ``ValueError``.

    It drives the runner test that expects an arbitrary exception (not a
    :class:`ProbeError`) to be recorded as an ERROR verdict.
    """

    spec_cls = _UnexpectedSpec
    kind = "__runner_unexpected"

    def run(self, spec: ProbeSpec, ctx: RunContext) -> ProbeResult:
        """Raise ``ValueError("surprise")`` unconditionally."""
        failure = ValueError("surprise")
        raise failure
69
70
@pytest.fixture
def backend() -> DummyDifferentialBackend:
    """Provide a differential backend whose base and ft views are default dummies."""
    base_view = DummyResponses()
    ft_view = DummyResponses()
    return DummyDifferentialBackend(base=base_view, ft=ft_view)
74
75
def _spec(*entries: dict) -> SwaySpec:
    """Validate a minimal version-1 :class:`SwaySpec` whose suite is *entries*.

    The ``models`` section holds placeholder values only; the tests pass a
    dummy backend to ``run`` separately.
    """
    models = {
        "base": {"base": "b"},
        "ft": {"base": "b", "adapter": "/tmp/a"},
    }
    raw = {"version": 1, "models": models, "suite": [*entries]}
    return SwaySpec.model_validate(raw)
87
88
class TestRunner:
    """Behavior of :func:`run` against the healthy dummy backend."""

    def test_runs_each_probe_in_order(self, backend: DummyDifferentialBackend) -> None:
        """Probes are executed and reported in suite declaration order."""
        suite = _spec(
            {"name": "p1", "kind": "__runner_pass"},
            {"name": "p2", "kind": "__runner_fail"},
        )
        outcome = run(suite, backend)
        names = [probe.name for probe in outcome.probes]
        assert names == ["p1", "p2"]
        verdicts = [probe.verdict for probe in outcome.probes]
        assert verdicts == [Verdict.PASS, Verdict.FAIL]

    def test_disabled_probe_records_skip(self, backend: DummyDifferentialBackend) -> None:
        """A probe with ``enabled: False`` is recorded as SKIP, not executed."""
        entry = {"name": "p1", "kind": "__runner_pass", "enabled": False}
        skipped = run(_spec(entry), backend).probes[0]
        assert skipped.verdict == Verdict.SKIP
        assert "disabled" in skipped.message

    def test_probeerror_becomes_error_verdict(self, backend: DummyDifferentialBackend) -> None:
        """A ``ProbeError`` raised by a probe surfaces as an ERROR verdict."""
        errored = run(_spec({"name": "oops", "kind": "__runner_raise"}), backend).probes[0]
        assert errored.verdict == Verdict.ERROR
        assert "kaboom" in errored.message

    def test_unexpected_exception_becomes_error_verdict(
        self, backend: DummyDifferentialBackend
    ) -> None:
        """Any other exception also becomes ERROR, with its type in the message."""
        entry = {"name": "oops", "kind": "__runner_unexpected"}
        errored = run(_spec(entry), backend).probes[0]
        assert errored.verdict == Verdict.ERROR
        assert "ValueError" in errored.message

    def test_wall_seconds_populated(self, backend: DummyDifferentialBackend) -> None:
        """Suite-level and per-probe timings are filled in and non-negative."""
        outcome = run(_spec({"name": "p1", "kind": "__runner_pass"}), backend)
        assert outcome.wall_seconds >= 0
        assert outcome.probes[0].duration_s >= 0

    def test_null_adapter_passes_on_null_calibrated_backend(
        self, backend: DummyDifferentialBackend
    ) -> None:
        """The null_adapter probe passes and exposes its stats on the suite result."""
        # Dummy backend implements NullCalibratedBackend, so calibration runs.
        calibration_entry = {
            "name": "null",
            "kind": "null_adapter",
            "runs": 2,
            # Requested explicitly so calibration runs even without downstream probes.
            "calibrate_kinds": ["delta_kl"],
        }
        outcome = run(_spec(calibration_entry), backend)
        null_probe = outcome.probes[0]
        assert null_probe.kind == "null_adapter"
        assert null_probe.verdict == Verdict.PASS
        # The suite's null_stats bubbles up onto the result.
        assert "delta_kl" in outcome.null_stats
144
145
class TestPreflightGate:
    """The S01 preflight gate: a NaN-producing backend aborts the suite.

    No probe runs; the SuiteResult contains a single synthetic ERROR
    probe explaining the abort.
    """

    @staticmethod
    def _nan_backend() -> DummyDifferentialBackend:
        """Build a backend whose ft view returns a NaN logprob for "preflight".

        The base view is a default dummy, so a preflight failure can only
        come from the ft side.
        """
        from dlm_sway.core.scoring import TokenDist

        nan_dist = TokenDist(
            token_ids=np.array([1, 2], dtype=np.int64),
            logprobs=np.array([np.nan, -0.5], dtype=np.float32),
            vocab_size=100,
        )
        ft = DummyResponses(token_dists={"preflight": nan_dist})
        return DummyDifferentialBackend(base=DummyResponses(), ft=ft)

    def test_preflight_failure_aborts_suite(self) -> None:
        """A non-finite ft distribution aborts the suite before any probe runs."""
        spec = _spec(
            {"name": "p1", "kind": "__runner_pass"},
            {"name": "p2", "kind": "__runner_pass"},
        )
        result = run(spec, self._nan_backend())

        # Exactly one synthetic ERROR probe; no configured probes ran.
        assert len(result.probes) == 1
        synthetic = result.probes[0]
        assert synthetic.kind == "preflight"
        assert synthetic.verdict == Verdict.ERROR
        assert "preflight failed" in synthetic.message
        # The NaN was seeded on the ft side, and the message should say so.
        assert "ft view" in synthetic.message
        # Neither configured probe appears among the results.
        assert {"p1", "p2"}.isdisjoint(probe.name for probe in result.probes)

    def test_skip_preflight_flag_runs_suite_anyway(self) -> None:
        """``skip_preflight=True`` bypasses the gate, even on an unhealthy backend."""
        spec = _spec({"name": "p1", "kind": "__runner_pass"})
        result = run(spec, self._nan_backend(), skip_preflight=True)
        # The configured probe really ran: it is the only result, and it carries
        # the stub's own PASS verdict rather than a synthetic ERROR.
        assert len(result.probes) == 1
        assert result.probes[0].name == "p1"
        assert result.probes[0].verdict == Verdict.PASS

    def test_finite_backend_preflight_passes_through(
        self, backend: DummyDifferentialBackend
    ) -> None:
        """A healthy backend passes preflight without a synthetic probe being added."""
        spec = _spec({"name": "p1", "kind": "__runner_pass"})
        result = run(spec, backend)
        # No synthetic preflight probe injected; configured probe ran.
        assert len(result.probes) == 1
        assert result.probes[0].name == "p1"