tenseleyflow/documentlanguagemodel / 3220813


feat(train): integrity gates — NaNWeightsError + NaNEvalError

Authored by espadonne
SHA: 32208138141f22a0091cc1fc3979b61222de27eb
Parents: 4d0e290
Tree: 955550b

2 changed files

Status  File                         +    -
M       src/dlm/train/__init__.py    12   0
A       src/dlm/train/integrity.py   142  0
src/dlm/train/__init__.py  (modified)
@@ -21,6 +21,13 @@ from dlm.train.errors import (
     TrainingError,
     VersionDriftWarning,
 )
+from dlm.train.integrity import (
+    NaNEvalError,
+    NaNWeightsError,
+    assert_eval_finite,
+    assert_finite_adapter,
+    audit_trainable_finite,
+)
 from dlm.train.logger import Banner, StepLogger, log_path_for
 from dlm.train.oom_guard import format_oom_message, recommend_grad_accum
 from dlm.train.state_sidecar import (
@@ -40,6 +47,8 @@ __all__ = [
     "Banner",
     "DeterminismSummary",
     "DiskSpaceError",
+    "NaNEvalError",
+    "NaNWeightsError",
     "OOMError",
     "PinnedVersions",
     "ResumeIntegrityError",
@@ -53,6 +62,9 @@ __all__ = [
     "VERSIONS_FILENAME",
     "VersionDriftWarning",
     "allocate_next_version",
+    "assert_eval_finite",
+    "assert_finite_adapter",
+    "audit_trainable_finite",
     "capture_runtime_versions",
     "commit_version",
     "estimate_checkpoint_bytes",
src/dlm/train/integrity.py  (added)
@@ -0,0 +1,142 @@
+"""Post-training weight-integrity gate.
+
+Invariant
+---------
+
+Training must never persist an adapter whose weights contain NaN or
+infinite values. Downstream consumers (inference, export, evaluation)
+would then silently produce NaN logits — and the `dlm train` exit code
+would be 0, giving the user no signal that the run failed.
+
+This module walks the trainable parameters of a wrapped PEFT model and
+asserts every element is finite. The check runs inside the
+`_write_checkpoint` writer, between `save_model` and the caller's
+atomic pointer flip. If the check fails, the exception propagates up
+through `commit_version` — which by contract leaves the pending
+version directory on disk and does **not** flip `current.txt`. The
+trainer then renames the pending dir to a `-rejected` suffix so the
+next `allocate_next_version` skips it (the suffix makes the directory
+name unparseable as `vNNNN`) and the user can inspect the bad weights
+for postmortem.
+
+Only trainable parameters are checked. Frozen base-model weights are
+both huge and (by construction) unchanged by training, so validating
+them is wasted I/O without corresponding signal.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from dlm.train.errors import TrainingError
+
+if TYPE_CHECKING:
+    pass
+
+
+class NaNEvalError(TrainingError):
+    """The trainer's final eval_loss was NaN or infinite — refusing to commit.
+
+    Fires in the trainer orchestrator after `summarize_eval_state` and
+    before the two-phase commit. Redundant with `NaNWeightsError`: when
+    eval diverges, the weights are usually already corrupt, so both
+    gates catch the same failure from different angles. The redundancy
+    is intentional — catching it twice is cheap, catching it zero times
+    (the original bug) silently poisons downstream consumers.
+    """
+
+    def __init__(self, value: float) -> None:
+        super().__init__(
+            f"final eval_loss is non-finite ({value!r}) — refusing to commit "
+            "adapter. Lower LR, add warmup, or check dataset for NaN-producing inputs."
+        )
+        self.value = value
+
+
+def assert_eval_finite(log_history: list[Any]) -> None:
+    """Raise `NaNEvalError` if the last `eval_loss` entry is non-finite.
+
+    No-op when `log_history` is empty or contains no eval entries —
+    eval is optional, and runs without an `eval_dataset` legitimately
+    have no eval_loss to check.
+    """
+    import math
+
+    for entry in reversed(log_history):
+        if not isinstance(entry, dict):
+            continue
+        value = entry.get("eval_loss")
+        if isinstance(value, (int, float)):
+            if not math.isfinite(float(value)):
+                raise NaNEvalError(float(value))
+            return  # last eval was finite — good
+    # No eval entries — the caller's contract is "check iff eval ran".
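
A quick illustration of `assert_eval_finite` on hand-built histories. The dict-with-`eval_loss` shape mirrors what a Hugging Face-style trainer records in its `log_history`; the literal numbers below are made up for the example:

    from dlm.train import NaNEvalError, assert_eval_finite

    # Finite final eval: no-op.
    assert_eval_finite([{"loss": 1.9, "step": 50}, {"eval_loss": 1.72, "step": 100}])

    # No eval entries at all: also a no-op ("check iff eval ran").
    assert_eval_finite([{"loss": 1.9, "step": 50}])

    # Diverged eval: the gate refuses to let the commit proceed.
    try:
        assert_eval_finite([{"eval_loss": float("nan"), "step": 100}])
    except NaNEvalError as err:
        print(err.value)  # nan
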
+
+
+class NaNWeightsError(TrainingError):
+    """A trained adapter contains non-finite weights — refusing to persist.
+
+    `offending` is the list of parameter names that failed the check.
+    Truncated to the first 20 names for readability; the full list is
+    available via the `full_offending` attribute.
+    """
+
+    def __init__(self, offending: list[str]) -> None:
+        display = offending[:20]
+        suffix = f" (and {len(offending) - 20} more)" if len(offending) > 20 else ""
+        super().__init__(
+            f"adapter weights contain NaN/inf in {len(offending)} tensor(s) — "
+            f"refusing to persist. Offenders: {', '.join(display)}{suffix}"
+        )
+        self.offending = display
+        self.full_offending = offending
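
A small demonstration of the truncation contract, with made-up parameter names (and assuming `TrainingError` is an ordinary `Exception` subclass, so `str(err)` is the formatted message):

    from dlm.train import NaNWeightsError

    names = [f"base_model.layers.{i}.lora_A.weight" for i in range(25)]  # hypothetical names
    err = NaNWeightsError(names)
    assert len(err.offending) == 20        # only the first 20 go into the message
    assert len(err.full_offending) == 25   # complete list kept for programmatic use
    assert "(and 5 more)" in str(err)
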
+
+
+@dataclass(frozen=True)
+class FiniteCheckResult:
+    """Outcome of a finite-weight audit, used by unit tests + the gate."""
+
+    checked: int
+    offending: tuple[str, ...]
+
+    @property
+    def ok(self) -> bool:
+        return not self.offending
+
+
+def audit_trainable_finite(model: Any) -> FiniteCheckResult:
+    """Walk `model.named_parameters()` and flag non-finite trainable tensors.
+
+    Only `requires_grad=True` parameters are inspected — these are the
+    LoRA `lora_A` / `lora_B` / `modules_to_save` tensors that training
+    actually updates. Frozen base weights are skipped.
+
+    Returns a `FiniteCheckResult` even when everything is finite, so
+    callers can log the count of checked tensors.
+    """
+    import torch
+
+    offending: list[str] = []
+    checked = 0
+    for name, param in model.named_parameters():
+        if not getattr(param, "requires_grad", False):
+            continue
+        checked += 1
+        if not torch.isfinite(param).all():
+            offending.append(name)
+    return FiniteCheckResult(checked=checked, offending=tuple(offending))
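
The audit relies only on `named_parameters()` and `requires_grad`, so it can be exercised on any small `torch.nn.Module` without loading a real PEFT adapter. A minimal sketch; the toy module and the injected NaN are purely illustrative:

    import torch

    from dlm.train import audit_trainable_finite

    model = torch.nn.Linear(4, 2)
    model.bias.requires_grad_(False)        # "frozen" parameter: skipped by the audit
    with torch.no_grad():
        model.weight[0, 0] = float("nan")   # corrupt one trainable tensor

    result = audit_trainable_finite(model)
    print(result.checked)    # 1 (only the trainable weight was inspected)
    print(result.offending)  # ('weight',)
    print(result.ok)         # False
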
+
+
+def assert_finite_adapter(model: Any) -> None:
+    """Raise `NaNWeightsError` if any trainable parameter is non-finite.
+
+    Called from `_write_checkpoint` right after `sft.save_model()`. If
+    the weights are bad, the version directory on disk still contains
+    the saved (bad) adapter_model.safetensors — the trainer relies on
+    `commit_version` leaving the pending dir alone when the writer
+    raises, and then renames it to `{pending}-rejected` for postmortem.
+    """
+    result = audit_trainable_finite(model)
+    if not result.ok:
+        raise NaNWeightsError(list(result.offending))
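
To round out the flow described in the module docstring, here is a self-contained sketch of the failure path: a stand-in "adapter" with corrupted trainable weights, and a pending version directory that gets parked with a `-rejected` suffix when the gate fires. In the real trainer this is wired through `commit_version` and `sft.save_model()`, neither of which appears in this diff; the temp directory below is only a placeholder for that pending dir:

    import tempfile
    from pathlib import Path

    import torch

    from dlm.train import NaNWeightsError, assert_finite_adapter

    model = torch.nn.Linear(4, 2)
    with torch.no_grad():
        model.weight.fill_(float("inf"))    # simulate a diverged adapter tensor

    pending = Path(tempfile.mkdtemp()) / "v0007"
    pending.mkdir()
    # ... imagine save_model() had just written adapter_model.safetensors here ...

    try:
        assert_finite_adapter(model)
    except NaNWeightsError as err:
        rejected = pending.with_name(pending.name + "-rejected")
        pending.rename(rejected)            # keep the bad weights around for postmortem
        print(f"parked {rejected.name}; offenders: {err.full_offending}")
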