| 1 | """Deterministic-execution helper. |
| 2 | |
| 3 | Mirrors ``dlm.train.determinism.seed_everything`` so running the same |
| 4 | suite twice on the same host produces the same :class:`ProbeResult` |
| 5 | payloads. The dlm project treats determinism as a contract; sway takes |
| 6 | the same posture for scoring operations. |
| 7 | |
| 8 | Generation is allowed to use non-deterministic attention kernels when |
| 9 | ``temperature > 0``, because a deterministic sampled generation is a |
| 10 | contradiction. Scoring (logprobs, rolling logprobs, next-token dists) |
always runs under ``torch.use_deterministic_algorithms(True)``.
| 12 | """ |
| 13 | |
| 14 | from __future__ import annotations |
| 15 | |
| 16 | import os |
| 17 | import random |
| 18 | from dataclasses import dataclass |
| 19 | from typing import Literal |
| 20 | |
# Determinism tiers reported back to callers, strongest first:
# "strict" > "best_effort" > "loose". See ``seed_everything`` for meanings.
DeterminismClass = Literal["strict", "best_effort", "loose"]
| 22 | |
| 23 | |
@dataclass(frozen=True, slots=True)
class DeterminismSummary:
    """What seeding actually accomplished, for logging in the report."""

    # Achieved determinism tier ("strict" / "best_effort" / "loose").
    # Trailing underscore avoids the ``class`` keyword.
    class_: DeterminismClass
    # The seed that was applied to every RNG.
    seed: int
    # Human-readable caveats (e.g. "torch not installed"); empty when clean.
    notes: tuple[str, ...] = ()
| 31 | |
| 32 | |
def seed_everything(seed: int, *, strict: bool = True) -> DeterminismSummary:
    """Seed every RNG sway's probes touch and flip backend flags.

    Idempotent — safe to call repeatedly with the same seed.

    Parameters
    ----------
    seed:
        The seed. Callers typically use the value from ``sway.yaml``'s
        ``defaults.seed`` (default 0).
    strict:
        If ``True`` (the default), request deterministic CUDA algorithms
        and set ``CUBLAS_WORKSPACE_CONFIG``. Scoring probes need this;
        generation-only runs can set it ``False``.

    Returns
    -------
    :class:`DeterminismSummary` with a classification:

    - ``"strict"`` — deterministic algorithms active, no warnings.
    - ``"best_effort"`` — platform doesn't support full determinism
      (MPS, some CPU kernels), or the caller passed ``strict=False``.
    - ``"loose"`` — seeded but deterministic algorithms refused.
    """

    notes: list[str] = []
    clazz: DeterminismClass = "best_effort"

    # Env vars must come first — torch reads them at cuBLAS init.
    if strict:
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)

    # numpy is a hard dep; safe to seed unconditionally.
    import numpy as np

    np.random.seed(seed)

    try:
        import torch  # noqa: PLC0415 — lazy: torch is an optional extra.
    except ModuleNotFoundError:
        notes.append("torch not installed; seeded python + numpy only")
        return DeterminismSummary(class_="best_effort", seed=seed, notes=tuple(notes))

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Bug fix: "strict" was previously reported even with strict=False,
        # i.e. when deterministic algorithms were never requested. The tier
        # is only earned when the request below actually happens (and may
        # still be downgraded to "loose" if torch refuses it).
        if strict:
            clazz = "strict"
        else:
            notes.append("strict=False: deterministic CUDA algorithms not requested")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        notes.append("MPS: bit-identical across runs is best-effort")
    else:
        notes.append("CPU-only backend: strict determinism depends on BLAS impl")

    if strict:
        try:
            # warn_only=True: unsupported ops emit a warning at call time
            # instead of raising mid-probe.
            torch.use_deterministic_algorithms(True, warn_only=True)
            # Pin cuDNN to reproducible kernels to match the request above.
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        except Exception as exc:  # noqa: BLE001 — torch raises a naked Exception
            clazz = "loose"
            notes.append(f"deterministic algorithms refused: {exc}")

    return DeterminismSummary(class_=clazz, seed=seed, notes=tuple(notes))