1 """Post-training finite-weights + finite-eval gates.
2
3 Both gates were added after a MPS + tiny-data + no-warmup `dlm train`
4 persisted a fully NaN adapter and silently promoted it to `current.txt`.
5 Downstream consumers then produced NaN logits; the user got exit 0 and
6 no signal. These tests guard against that regression.
7 """

from __future__ import annotations

import math
from pathlib import Path

import pytest
import torch
import torch.nn as nn

from dlm.train.checkpoint_commit import commit_version
from dlm.train.integrity import (
    NaNEvalError,
    NaNWeightsError,
    assert_eval_finite,
    assert_finite_adapter,
    audit_trainable_finite,
)


class _TinyLoRAModel(nn.Module):
    """Minimal module with trainable lora_A/lora_B tensors."""

    def __init__(self, *, nan: bool = False, inf: bool = False) -> None:
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(4, 8))
        self.lora_B = nn.Parameter(torch.zeros(8, 4))
        # A frozen "base" param — must be ignored by the audit even if non-finite.
        self.base = nn.Parameter(torch.full((4, 4), float("nan")), requires_grad=False)
        if nan:
            with torch.no_grad():
                self.lora_A[0, 0] = float("nan")
        if inf:
            with torch.no_grad():
                self.lora_B[1, 1] = float("inf")


class TestAuditTrainableFinite:
    def test_clean_model_reports_ok(self) -> None:
        result = audit_trainable_finite(_TinyLoRAModel())
        assert result.ok
        assert result.checked == 2  # lora_A + lora_B
        assert result.offending == ()

    def test_nan_trainable_param_is_flagged(self) -> None:
        result = audit_trainable_finite(_TinyLoRAModel(nan=True))
        assert not result.ok
        assert "lora_A" in result.offending[0]

    def test_inf_trainable_param_is_flagged(self) -> None:
        result = audit_trainable_finite(_TinyLoRAModel(inf=True))
        assert not result.ok
        assert "lora_B" in result.offending[0]

    def test_frozen_nan_base_is_ignored(self) -> None:
        # The frozen `base` tensor is NaN-filled but requires_grad=False;
        # the audit must skip it (we only check what training updates).
        result = audit_trainable_finite(_TinyLoRAModel())
        assert result.ok
        assert all("base" not in name for name in result.offending)


class TestAssertFiniteAdapter:
    def test_clean_model_does_not_raise(self) -> None:
        assert_finite_adapter(_TinyLoRAModel())  # no raise

    def test_nan_model_raises(self) -> None:
        with pytest.raises(NaNWeightsError) as exc:
            assert_finite_adapter(_TinyLoRAModel(nan=True))
        assert "NaN/inf" in str(exc.value)
        assert len(exc.value.full_offending) == 1


class TestAssertEvalFinite:
    def test_empty_log_history_no_raise(self) -> None:
        assert_eval_finite([])

    def test_no_eval_entries_no_raise(self) -> None:
        # log_history may contain train-only entries — the contract is
        # "check iff eval ran", so no eval entries means nothing to check.
        assert_eval_finite([{"loss": 2.0, "step": 1}, {"loss": 1.5, "step": 2}])

    def test_non_dict_entries_ignored(self) -> None:
        assert_eval_finite([{"loss": 2.0, "step": 1}, "not-a-dict"])

    def test_finite_eval_does_not_raise(self) -> None:
        assert_eval_finite([{"eval_loss": 1.8, "step": 10}])

    def test_nan_eval_raises(self) -> None:
        with pytest.raises(NaNEvalError) as exc:
            assert_eval_finite([{"eval_loss": float("nan"), "step": 10}])
        assert math.isnan(exc.value.value)

    def test_inf_eval_raises(self) -> None:
        with pytest.raises(NaNEvalError):
            assert_eval_finite([{"eval_loss": float("inf"), "step": 10}])

    def test_last_eval_is_authoritative(self) -> None:
        # Walks from the tail — the last eval entry (NaN here) wins even
        # if an earlier one was finite.
        with pytest.raises(NaNEvalError):
            assert_eval_finite(
                [
                    {"eval_loss": 1.5, "step": 10},
                    {"eval_loss": float("nan"), "step": 20},
                ]
            )

    def test_only_walks_until_first_eval_entry(self) -> None:
        # If the final eval is finite, the gate is satisfied — earlier
        # NaN eval entries don't matter (they were historical).
        assert_eval_finite(
            [
                {"eval_loss": float("nan"), "step": 10},
                {"eval_loss": 1.3, "step": 20},
            ]
        )
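
    def test_trailing_train_entries_do_not_mask_last_eval(self) -> None:
        # Sketch of the tail-walk contract described above: this assumes
        # train-only entries (no `eval_loss`) after the last eval entry
        # are skipped, so the NaN eval still trips the gate.
        with pytest.raises(NaNEvalError):
            assert_eval_finite(
                [
                    {"eval_loss": float("nan"), "step": 10},
                    {"loss": 1.2, "step": 11},
                ]
            )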


class _FakeStore:
    """Minimal StorePath stand-in for commit_version's integration test."""

    def __init__(self, root: Path) -> None:
        self.adapter_versions = root / "versions"
        self.adapter_versions.mkdir(parents=True)
        self._current: Path | None = None

    def adapter_version(self, n: int) -> Path:
        return self.adapter_versions / f"v{n:04d}"

    def set_current_adapter(self, path: Path) -> None:
        self._current = path


class TestCommitVersionRenamesOnNaN:
    def test_rejected_dir_created_and_current_not_flipped(self, tmp_path: Path) -> None:
        store = _FakeStore(tmp_path)

        def writer(pending: Path) -> None:
            # Simulate: save weights (touch a file), then gate fails.
            (pending / "adapter_model.safetensors").write_bytes(b"bogus")
            raise NaNWeightsError(["base_model.lora_A.weight"])

        with pytest.raises(NaNWeightsError):
            commit_version(store, writer)  # type: ignore[arg-type]

        # The rejected dir exists with the saved (bad) weights preserved.
        rejected = tmp_path / "versions" / "v0001-rejected"
        assert rejected.exists()
        assert (rejected / "adapter_model.safetensors").exists()
        # The plain v0001 dir no longer exists (was renamed).
        assert not (tmp_path / "versions" / "v0001").exists()
        # current.txt was never flipped.
        assert store._current is None

    def test_next_version_skips_rejected_name(self, tmp_path: Path) -> None:
        # After a rejected commit, the next allocate_next_version should
        # still pick v0001 (the `-rejected` suffix makes the old one
        # unparseable as a version number).
        from dlm.train.checkpoint_commit import allocate_next_version

        store = _FakeStore(tmp_path)
        (store.adapter_versions / "v0001-rejected").mkdir()
        next_dir = allocate_next_version(store)  # type: ignore[arg-type]
        assert next_dir.name == "v0001"
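

class TestCommitVersionHappyPath:
    # A hedged sketch of the success path: it assumes that when the writer
    # raises nothing, commit_version finalizes the pending dir as v0001 and
    # promotes it via set_current_adapter. Adjust the assertions if
    # commit_version's success-path contract differs.
    def test_clean_commit_flips_current(self, tmp_path: Path) -> None:
        store = _FakeStore(tmp_path)

        def writer(pending: Path) -> None:
            # No gate failure; the saved weights are "clean".
            (pending / "adapter_model.safetensors").write_bytes(b"ok")

        commit_version(store, writer)  # type: ignore[arg-type]
        final = tmp_path / "versions" / "v0001"
        assert final.exists()
        assert store._current == final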