Python · 4402 bytes Raw Blame History
1 """Tests for :mod:`dlm_sway.core.determinism`."""
2
3 from __future__ import annotations
4
5 import os
6 import random
7
8 import numpy as np
9
10 from dlm_sway.core.determinism import DeterminismSummary, seed_everything
11
12
class TestSeedEverything:
    """Unit tests for :func:`seed_everything` called directly."""

    # The three determinism classes a summary is allowed to report.
    _DETERMINISM_CLASSES = frozenset({"strict", "best_effort", "loose"})

    # Env var that strict mode is expected to pin to ":4096:8".
    _CUBLAS_KEY = "CUBLAS_WORKSPACE_CONFIG"

    @classmethod
    def _seed_with_cublas(cls, initial: str | None, *, strict: bool) -> str | None:
        """Seed with ``CUBLAS_WORKSPACE_CONFIG`` preset to *initial*; report the result.

        ``initial=None`` means "unset".  Returns the variable's value right
        after ``seed_everything(0, strict=strict)`` (``None`` if unset).

        The process environment is restored to exactly its pre-call state on
        exit.  Without this, a strict-mode call would leave the variable set
        for every later test in the session, and the host's own value would
        be lost.
        """
        saved = os.environ.pop(cls._CUBLAS_KEY, None)
        try:
            if initial is not None:
                os.environ[cls._CUBLAS_KEY] = initial
            seed_everything(0, strict=strict)
            return os.environ.get(cls._CUBLAS_KEY)
        finally:
            os.environ.pop(cls._CUBLAS_KEY, None)
            if saved is not None:
                os.environ[cls._CUBLAS_KEY] = saved

    def test_returns_summary(self) -> None:
        summary = seed_everything(0)
        assert isinstance(summary, DeterminismSummary)
        assert summary.seed == 0
        assert summary.class_ in self._DETERMINISM_CLASSES

    def test_idempotent_for_stdlib_random(self) -> None:
        seed_everything(42)
        a = [random.random() for _ in range(5)]
        seed_everything(42)
        b = [random.random() for _ in range(5)]
        assert a == b

    def test_idempotent_for_numpy(self) -> None:
        seed_everything(17)
        a = np.random.rand(5)
        seed_everything(17)
        b = np.random.rand(5)
        np.testing.assert_array_equal(a, b)

    def test_cublas_workspace_set_under_strict(self) -> None:
        assert self._seed_with_cublas(None, strict=True) == ":4096:8"

    def test_non_strict_does_not_set_cublas(self) -> None:
        # Non-strict mode must not leak the env var in either direction:
        # it must not create the variable when the host has none ...
        assert self._seed_with_cublas(None, strict=False) is None

    def test_non_strict_preserves_host_cublas_value(self) -> None:
        # ... and it must not overwrite a value the host already set.
        host_value = ":16:8"
        assert self._seed_with_cublas(host_value, strict=False) == host_value
48
49
class TestRunnerCallsSeedEverything:
    """The runner must seed every RNG before any probe runs (P09)."""

    @staticmethod
    def _make_spec(seed: int):
        """Build a minimal, probe-free spec whose ``defaults.seed`` is *seed*."""
        from dlm_sway.suite.spec import SwaySpec

        return SwaySpec.model_validate(
            {
                "version": 1,
                "models": {
                    "base": {"base": "b"},
                    "ft": {"base": "b", "adapter": "/tmp/a"},
                },
                "defaults": {"seed": seed},
                "suite": [],
            }
        )

    @staticmethod
    def _make_backend():
        """Build a dummy differential backend with default responses on both sides."""
        from dlm_sway.backends.dummy import DummyDifferentialBackend, DummyResponses

        return DummyDifferentialBackend(base=DummyResponses(), ft=DummyResponses())

    def test_runner_populates_determinism_field(self) -> None:
        from dlm_sway.suite.runner import run as run_suite

        result = run_suite(self._make_spec(7), self._make_backend())
        assert result.determinism is not None
        assert result.determinism.seed == 7
        assert result.determinism.class_ in {"strict", "best_effort", "loose"}

    def test_runner_seeds_before_first_probe(self, monkeypatch) -> None:
        """Reorder check: seed_everything must fire *before* the probe loop."""
        from dlm_sway.backends import dummy as dummy_mod
        from dlm_sway.suite import runner as runner_mod

        events: list[str] = []

        original_seed = runner_mod.seed_everything

        def recording_seed(seed, *args, **kwargs):
            # Forward every argument untouched. Hard-coding a ``strict``
            # default here could silently change the mode the runner
            # actually requested.
            events.append(f"seed={seed}")
            return original_seed(seed, *args, **kwargs)

        monkeypatch.setattr(runner_mod, "seed_everything", recording_seed)

        # Use the dummy preflight as a probe stand-in: the runner calls it
        # after seeding, so its event should land after the seed event.
        original_preflight = dummy_mod.DummyDifferentialBackend.preflight_finite_check

        def recording_preflight(self, *args, **kwargs):
            events.append("preflight")
            return original_preflight(self, *args, **kwargs)

        monkeypatch.setattr(
            dummy_mod.DummyDifferentialBackend,
            "preflight_finite_check",
            recording_preflight,
        )

        runner_mod.run(self._make_spec(11), self._make_backend())

        # Check presence first, so a missing event fails with a readable
        # message instead of a bare ValueError from ``list.index``.
        assert "seed=11" in events, f"runner never seeded with 11; got {events}"
        assert "preflight" in events, f"runner never ran preflight; got {events}"
        assert events.index("seed=11") < events.index("preflight"), (
            f"seed must fire before preflight; got {events}"
        )
122 )