Python · 7850 bytes Raw Blame History
1 """Two-phase checkpoint commit — allocation, flip, crash safety."""
2
3 from __future__ import annotations
4
5 from pathlib import Path
6
7 import pytest
8
9 import dlm.train.checkpoint_commit as checkpoint_commit
10 from dlm.store.paths import for_dlm
11 from dlm.train.checkpoint_commit import (
12 _uniquify_rejected,
13 allocate_next_version,
14 commit_version,
15 fsync_dir,
16 list_pending_versions,
17 )
18 from dlm.train.integrity import NaNWeightsError
19
20
def _store(home: Path):
    """Return the "01TEST" store rooted under *home*, with its layout created on disk."""
    test_store = for_dlm("01TEST", home=home)
    test_store.ensure_layout()
    return test_store
25
26
class TestAllocation:
    """`allocate_next_version` creates the next `vNNNN` directory under the store."""

    def test_first_version_is_v0001(self, tmp_path: Path) -> None:
        store = _store(tmp_path)
        v1 = allocate_next_version(store)
        assert v1.name == "v0001"
        assert v1.exists()

    def test_allocates_sequentially(self, tmp_path: Path) -> None:
        """Each call claims the next number and never hands out an earlier one."""
        store = _store(tmp_path)
        allocated = [allocate_next_version(store) for _ in range(3)]
        # Pins the full sequence, including the first allocation; since every
        # name is distinct, no directory number is ever reused.
        assert [v.name for v in allocated] == ["v0001", "v0002", "v0003"]
        # Earlier allocations stay on disk while later ones are made.
        assert all(v.exists() for v in allocated)

    def test_ignores_non_version_dirs(self, tmp_path: Path) -> None:
        """Stray dirs (e.g., tmp/ or backup/) don't advance the counter."""
        store = _store(tmp_path)
        (store.adapter_versions / "scratch").mkdir()
        (store.adapter_versions / "v-not-a-number").mkdir()
        v1 = allocate_next_version(store)
        assert v1.name == "v0001"

    def test_ignores_non_dir_entries(self, tmp_path: Path) -> None:
        """A regular *file* with a version-like name doesn't advance the counter."""
        store = _store(tmp_path)
        (store.adapter_versions / "v0009").write_text("not a directory")
        v1 = allocate_next_version(store)
        assert v1.name == "v0001"
57
58
class TestCommitVersion:
    """`commit_version`: run a writer into a new version dir, then flip `current`.

    A writer that raises must never move `current`. A writer that raises
    `NaNWeightsError` gets its pending dir renamed to `vNNNN-rejected`.
    """

    def test_happy_path_flips_current(self, tmp_path: Path) -> None:
        store = _store(tmp_path)

        def writer(p: Path) -> None:
            (p / "adapter_config.json").write_text("{}")

        committed = commit_version(store, writer)
        assert committed.name == "v0001"
        assert store.resolve_current_adapter() == committed
        # The writer's output lives in the directory that became current.
        assert (committed / "adapter_config.json").read_text() == "{}"

    def test_writer_exception_preserves_current(self, tmp_path: Path) -> None:
        """Simulated crash-during-write: old current stays authoritative."""
        store = _store(tmp_path)

        # First, do a successful commit so `current` has something to preserve.
        v1 = commit_version(store, lambda p: (p / "ok").write_text("ok"))
        assert store.resolve_current_adapter() == v1

        # Now a writer that crashes part-way through its output.
        def bad_writer(p: Path) -> None:
            (p / "partial").write_text("half")
            raise RuntimeError("SIGKILL-like")

        with pytest.raises(RuntimeError, match="SIGKILL-like"):
            commit_version(store, bad_writer)

        # Current pointer is unchanged.
        assert store.resolve_current_adapter() == v1

    def test_pending_dir_remains_after_crash(self, tmp_path: Path) -> None:
        """Left on disk for post-mortem; next allocate skips over it."""
        store = _store(tmp_path)

        def bad_writer(p: Path) -> None:
            (p / "x").write_text("x")
            raise RuntimeError("crash")

        with pytest.raises(RuntimeError):
            commit_version(store, bad_writer)

        # v0001 should exist (even though it's not current).
        v1_path = store.adapter_version(1)
        assert v1_path.exists()

        # Next allocate goes to v0002.
        v2 = allocate_next_version(store)
        assert v2.name == "v0002"

    def test_nonfinite_writer_uniquify_failure_leaves_pending(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """If no rejected slot can be chosen, the pending dir stays where it was.

        NOTE(review): here the helper's RuntimeError surfaces instead of the
        writer's NaNWeightsError, whereas a failed rename (the next test)
        re-raises the NaNWeightsError. Confirm this asymmetry is intended.
        """
        store = _store(tmp_path)

        def bad_writer(p: Path) -> None:
            (p / "weights.safetensors").write_text("bad")
            raise NaNWeightsError(["adapter.lora_A"])

        def boom(_: Path) -> Path:
            raise RuntimeError("no rejected slot")

        # Patched on the module so commit_version's lookup picks up the stub.
        monkeypatch.setattr(checkpoint_commit, "_uniquify_rejected", boom)

        with pytest.raises(RuntimeError, match="no rejected slot"):
            commit_version(store, bad_writer)

        assert store.adapter_version(1).exists()
        assert store.resolve_current_adapter() is None

    def test_nonfinite_writer_rename_failure_still_reraises(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A failed rename-to-rejected must not mask the writer's NaNWeightsError."""
        store = _store(tmp_path)

        def bad_writer(p: Path) -> None:
            (p / "weights.safetensors").write_text("bad")
            raise NaNWeightsError(["adapter.lora_B"])

        def bad_rename(path: Path, target: Path) -> Path:
            raise OSError("rename blocked")

        # Blocks Path.rename for every path, for the duration of this test only.
        monkeypatch.setattr(Path, "rename", bad_rename)

        with pytest.raises(NaNWeightsError, match="NaN/inf"):
            commit_version(store, bad_writer)

        assert store.adapter_version(1).exists()
        assert store.resolve_current_adapter() is None

    def test_nonfinite_writer_renames_to_rejected_path(self, tmp_path: Path) -> None:
        """Non-finite weights: the pending dir is moved aside to `v0001-rejected`."""
        store = _store(tmp_path)

        def bad_writer(p: Path) -> None:
            (p / "weights.safetensors").write_text("bad")
            raise NaNWeightsError(["adapter.lora_B"])

        with pytest.raises(NaNWeightsError, match="NaN/inf"):
            commit_version(store, bad_writer)

        assert not store.adapter_version(1).exists()
        assert (store.adapter_versions / "v0001-rejected").exists()
        assert store.resolve_current_adapter() is None
161
162
class TestListPending:
    """`list_pending_versions` reports version dirs that were allocated but never made current."""

    def test_no_pending_when_all_committed(self, tmp_path: Path) -> None:
        store = _store(tmp_path)
        committed = commit_version(store, lambda d: (d / "a").write_text("a"))
        assert list_pending_versions(store) == []
        assert store.resolve_current_adapter() == committed

    def test_reports_pre_current_versions(self, tmp_path: Path) -> None:
        store = _store(tmp_path)
        # v0001 is allocated and abandoned; v0002 is committed and becomes current.
        orphan = allocate_next_version(store)
        commit_version(store, lambda d: (d / "a").write_text("a"))
        reported = [entry.name for entry in list_pending_versions(store)]
        assert reported == [orphan.name]

    def test_named_adapter_pending_versions_report_orphans(self, tmp_path: Path) -> None:
        store = _store(tmp_path)
        orphan = allocate_next_version(store, adapter_name="writer")
        commit_version(
            store,
            lambda d: (d / "a").write_text("a"),
            adapter_name="writer",
        )
        reported = [
            entry.name
            for entry in list_pending_versions(store, adapter_name="writer")
        ]
        assert reported == [orphan.name]

    def test_named_adapter_pending_versions_without_current(self, tmp_path: Path) -> None:
        store = _store(tmp_path)
        orphan = allocate_next_version(store, adapter_name="writer")
        # Nothing was ever committed, yet the lone allocation is still reported.
        assert list_pending_versions(store, adapter_name="writer") == [orphan]
192
193
class TestFsyncDir:
    """Smoke coverage for `fsync_dir`."""

    def test_fsync_no_error_on_real_dir(self, tmp_path: Path) -> None:
        """Its only effect is on durability, so the check is simply that no exception escapes."""
        existing_dir = tmp_path
        fsync_dir(existing_dir)
198
199
class TestRejectedPathAllocation:
    """`_uniquify_rejected` picks a free `<name>-rejected[-N]` sibling path."""

    def test_returns_first_available_suffix(self, tmp_path: Path) -> None:
        pending = tmp_path / "v0001"
        pending.mkdir()
        for occupied in ("v0001-rejected", "v0001-rejected-1"):
            (tmp_path / occupied).mkdir()
        expected = tmp_path / "v0001-rejected-2"
        assert _uniquify_rejected(pending) == expected

    def test_raises_after_1000_collisions(self, tmp_path: Path) -> None:
        pending = tmp_path / "v0001"
        pending.mkdir()
        # Occupy the bare name plus suffixes 1..999 (1000 paths in total).
        occupied = ["v0001-rejected"]
        occupied.extend(f"v0001-rejected-{i}" for i in range(1, 1000))
        for name in occupied:
            (tmp_path / name).mkdir()

        with pytest.raises(RuntimeError, match="after 1000 attempts"):
            _uniquify_rejected(pending)
216 _uniquify_rejected(pending)