Python · 3325 bytes Raw Blame History
1 """Deterministic-execution helper.
2
3 Mirrors ``dlm.train.determinism.seed_everything`` so running the same
4 suite twice on the same host produces the same :class:`ProbeResult`
5 payloads. The dlm project treats determinism as a contract; sway takes
6 the same posture for scoring operations.
7
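Generation is allowed to use non-deterministic attention kernels when
``temperature > 0``, because a deterministic sampled generation is a
contradiction. Scoring (logprobs, rolling logprobs, next-token dists)
always runs with :func:`torch.use_deterministic_algorithms` enabled, via
``seed_everything(seed, strict=True)``. That flag is set with
``warn_only=True``, so an op that has no deterministic implementation
emits a warning instead of raising.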
12 """
13
14 from __future__ import annotations
15
16 import os
17 import random
18 from dataclasses import dataclass
19 from typing import Literal
20
# Classification labels that ``seed_everything`` reports via
# ``DeterminismSummary.class_``.
DeterminismClass = Literal[
    "strict",  # deterministic algorithms requested successfully on CUDA
    "best_effort",  # platform cannot promise full determinism (MPS / CPU / no torch)
    "loose",  # RNGs seeded, but deterministic algorithms are not in effect
]
22
23
24 @dataclass(frozen=True, slots=True)
25 class DeterminismSummary:
26 """What seeding actually accomplished, for logging in the report."""
27
28 class_: DeterminismClass
29 seed: int
30 notes: tuple[str, ...] = ()
31
32
def seed_everything(seed: int, *, strict: bool = True) -> DeterminismSummary:
    """Seed every RNG sway's probes touch and flip backend flags.

    Idempotent — safe to call repeatedly with the same seed.

    Parameters
    ----------
    seed:
        The seed. Callers typically use the value from ``sway.yaml``'s
        ``defaults.seed`` (default 0).
    strict:
        If ``True`` (the default), set ``CUBLAS_WORKSPACE_CONFIG`` (only
        when it is not already set), request deterministic algorithms via
        ``torch.use_deterministic_algorithms(True, warn_only=True)`` and
        disable cuDNN benchmarking. Scoring probes need this;
        generation-only runs can set it ``False``.

    Returns
    -------
    :class:`DeterminismSummary` with a classification:

    - ``"strict"`` — CUDA backend and deterministic algorithms were
      requested successfully. Because ``warn_only=True`` is used, an op
      without a deterministic implementation warns rather than raises.
    - ``"best_effort"`` — platform doesn't support full determinism
      (MPS, some CPU kernels), or torch is not installed and only
      python + numpy were seeded.
    - ``"loose"`` — seeded, but deterministic algorithms are not in
      effect: either torch refused them or ``strict=False``.
    """

    notes: list[str] = []

    # Env vars must come first — torch reads them at cuBLAS init.
    if strict:
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)

    # numpy is a hard dep; safe to seed unconditionally.
    import numpy as np

    np.random.seed(seed)

    try:
        import torch  # noqa: PLC0415 — lazy: torch is an optional extra.
    except ModuleNotFoundError:
        notes.append("torch not installed; seeded python + numpy only")
        return DeterminismSummary(class_="best_effort", seed=seed, notes=tuple(notes))

    torch.manual_seed(seed)

    # Classify by backend first; this is the best we can claim *if*
    # deterministic algorithms end up enabled below.
    clazz: DeterminismClass
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        clazz = "strict"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        clazz = "best_effort"
        notes.append("MPS: bit-identical across runs is best-effort")
    else:
        clazz = "best_effort"
        notes.append("CPU-only backend: strict determinism depends on BLAS impl")

    if not strict:
        # Deterministic algorithms were never requested, so the run is only
        # seeded. Reporting "strict" (or a platform-level "best_effort")
        # here would overstate the guarantee.
        notes.append("deterministic algorithms not requested (strict=False)")
        return DeterminismSummary(class_="loose", seed=seed, notes=tuple(notes))

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cudnn.benchmark = False
    except Exception as exc:  # noqa: BLE001 — torch raises a naked Exception
        clazz = "loose"
        notes.append(f"deterministic algorithms refused: {exc}")

    return DeterminismSummary(class_=clazz, seed=seed, notes=tuple(notes))
97 return DeterminismSummary(class_=clazz, seed=seed, notes=tuple(notes))