1 """training_state.pt save/load/integrity/version-drift."""
2
3 from __future__ import annotations
4
5 import builtins
6 import hashlib
7 import io
8 import json
9 import logging
10 import random
11 import warnings
12 from pathlib import Path
13 from typing import Any
14
15 import numpy as np
16 import pytest
17 import torch
18
19 from dlm.train.errors import ResumeIntegrityError, VersionDriftWarning
20 from dlm.train.state_sidecar import (
21 RNG_SIDECAR_FILENAME,
22 STATE_FILENAME,
23 STATE_SHA_FILENAME,
24 TRAINING_RUN_FILENAME,
25 VERSIONS_FILENAME,
26 TrainingState,
27 _decode_python_random_state,
28 _encode_python_random_state,
29 capture_runtime_versions,
30 load_state,
31 save_state,
32 )
33
34
def _mock_state(*, use_qlora: bool = False) -> TrainingState:
    """Return a minimal, fully-populated TrainingState for sidecar tests.

    Values are deliberately small and deterministic so assertions can pin
    them exactly (global_step=10, epoch=0.5, a 40-char fake revision).
    """
    fake_revision = "a" * 40
    pinned = {"torch": torch.__version__}
    return TrainingState(
        optimizer_state_dict={"lr": 1e-4},
        scheduler_state_dict={"step": 5},
        scaler_state_dict=None,
        torch_rng_state=torch.get_rng_state(),
        cuda_rng_state=None,
        numpy_rng_state=None,
        python_random_state=random.getstate(),
        global_step=10,
        epoch=0.5,
        best_val_loss=0.9,
        dlm_manifest_hash=None,
        base_model_revision=fake_revision,
        pinned_versions=pinned,
        use_qlora=use_qlora,
    )
52
53
class TestRoundTrip:
    """Happy-path save → load round-trips through the sidecar files."""

    def test_save_writes_three_files(self, tmp_path: Path) -> None:
        """A save emits the state blob plus its sha256 and versions sidecars."""
        save_state(tmp_path, _mock_state())
        assert (tmp_path / STATE_FILENAME).exists()
        assert (tmp_path / STATE_SHA_FILENAME).exists()
        assert (tmp_path / VERSIONS_FILENAME).exists()

    def test_load_returns_state(self, tmp_path: Path) -> None:
        """Scalar fields survive the round-trip unchanged."""
        save_state(tmp_path, _mock_state())
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        assert loaded["global_step"] == 10
        assert loaded["base_model_revision"] == "a" * 40

    def test_pinned_versions_json_sidecar_readable(self, tmp_path: Path) -> None:
        """`pinned_versions.json` is JSON for human grep-ability."""
        # json is already imported at module level — no local re-import needed.
        save_state(tmp_path, _mock_state())
        content = json.loads((tmp_path / VERSIONS_FILENAME).read_text())
        assert "torch" in content
74
75
class TestIntegrity:
    """Tamper/corruption detection on the state blob and its sha sidecar."""

    def test_missing_state_file_raises(self, tmp_path: Path) -> None:
        """An empty directory has no state file at all."""
        with pytest.raises(ResumeIntegrityError, match="missing training state"):
            load_state(tmp_path, runtime_versions={})

    def test_missing_sha_file_raises(self, tmp_path: Path) -> None:
        """State present but sha sidecar deleted → refuse to load."""
        save_state(tmp_path, _mock_state())
        sha_path = tmp_path / STATE_SHA_FILENAME
        sha_path.unlink()
        with pytest.raises(ResumeIntegrityError, match="sha256 sidecar"):
            load_state(tmp_path, runtime_versions={})

    def test_corrupted_state_raises(self, tmp_path: Path) -> None:
        """Overwriting the blob makes the digest disagree with the sidecar."""
        save_state(tmp_path, _mock_state())
        state_path = tmp_path / STATE_FILENAME
        state_path.write_bytes(b"tampered-bytes")
        with pytest.raises(ResumeIntegrityError, match="sha256 mismatch"):
            load_state(tmp_path, runtime_versions={})

    def test_corrupted_sha_raises(self, tmp_path: Path) -> None:
        """A wrong (but well-formed) digest in the sidecar is also rejected."""
        save_state(tmp_path, _mock_state())
        bogus_digest = "0" * 64 + "\n"
        (tmp_path / STATE_SHA_FILENAME).write_text(bogus_digest)
        with pytest.raises(ResumeIntegrityError, match="sha256 mismatch"):
            load_state(tmp_path, runtime_versions={})
98
99
class TestVersionDrift:
    """VersionDriftWarning semantics for pinned vs. runtime versions."""

    def test_matching_versions_no_warning(self, tmp_path: Path) -> None:
        """Identical versions must load with zero warnings of any kind."""
        save_state(tmp_path, _mock_state())
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # promote any warning to a failure
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_differing_version_emits_warning(self, tmp_path: Path) -> None:
        """A changed torch version triggers a drift warning naming the package."""
        save_state(tmp_path, _mock_state())
        drifted = {"torch": "99.99.99"}
        with pytest.warns(VersionDriftWarning, match="torch:"):
            load_state(tmp_path, runtime_versions=drifted)

    def test_gaining_a_package_is_not_drift(self, tmp_path: Path) -> None:
        """Saved had no `trl` pinned; current runtime knows it → no drift.

        Gaining capability isn't drift — there was no prior state to
        diverge from. Only losing a pinned package is (see M6 test).
        """
        save_state(tmp_path, _mock_state())
        richer_runtime = {"torch": torch.__version__, "trl": "1.2.0"}
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            load_state(tmp_path, runtime_versions=richer_runtime)

    def test_losing_pinned_package_is_drift(self, tmp_path: Path) -> None:
        """Audit-04 M6: saved had `bitsandbytes="0.43.1"`, runtime has None.

        This matters for the QLoRA-on-CUDA → resumed-on-Apple-Silicon
        case; under the old logic it was silently skipped.
        """
        # Pin bitsandbytes in the saved state, then resume without it.
        state = _mock_state()
        state["pinned_versions"] = {
            "torch": torch.__version__,
            "bitsandbytes": "0.43.1",
        }
        save_state(tmp_path, state)

        lossy_runtime = {"torch": torch.__version__, "bitsandbytes": None}
        with pytest.warns(VersionDriftWarning, match="bitsandbytes.*0\\.43\\.1.*missing"):
            load_state(tmp_path, runtime_versions=lossy_runtime)

    def test_losing_pinned_package_missing_key_is_drift(self, tmp_path: Path) -> None:
        """Same as above but the runtime dict omits the key entirely."""
        state = _mock_state()
        state["pinned_versions"] = {"torch": torch.__version__, "bitsandbytes": "0.43.1"}
        save_state(tmp_path, state)

        with pytest.warns(VersionDriftWarning, match="bitsandbytes.*missing"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})
154
155
class TestTrainingRunSidecar:
    """Audit-05 M1: explicit use_qlora flag persisted alongside the adapter."""

    def test_training_run_json_written_on_save(self, tmp_path: Path) -> None:
        """Saving with QLoRA enabled writes the flag into training_run.json."""
        # json is already imported at module level — no local re-import needed.
        save_state(tmp_path, _mock_state(use_qlora=True))
        training_run = tmp_path / TRAINING_RUN_FILENAME
        assert training_run.exists()
        data = json.loads(training_run.read_text())
        assert data["use_qlora"] is True

    def test_training_run_defaults_false_when_lora(self, tmp_path: Path) -> None:
        """A plain-LoRA save records use_qlora=False explicitly, not missing."""
        save_state(tmp_path, _mock_state(use_qlora=False))
        data = json.loads((tmp_path / TRAINING_RUN_FILENAME).read_text())
        assert data["use_qlora"] is False
174
175
class TestRngSidecar:
    """Audit-11 B7: numpy + python RNG round-trip via JSON sidecar.

    The v1 layout stored these inside the torch payload and required
    `weights_only=False` on load — an RCE vector. v2 moves them to a
    JSON sidecar so the torch payload loads safely under
    `weights_only=True`.
    """

    def test_numpy_rng_round_trip(self, tmp_path: Path) -> None:
        """`numpy.random.get_state()` must round-trip exactly — any drift
        means loss curves diverge on resume even with matched torch RNG."""
        source = np.random.RandomState(seed=12345)
        source.random_sample(100)  # move well past the seed-init state
        before: tuple[Any, ...] = source.get_state(legacy=True)  # type: ignore[assignment]

        mock = _mock_state()
        mock["numpy_rng_state"] = before
        save_state(tmp_path, mock)

        reloaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        after: tuple[Any, ...] = reloaded["numpy_rng_state"]

        # Element 1 is the Mersenne key array — compare with numpy's helper;
        # the remaining elements are plain scalars/strings.
        assert after[0] == before[0]
        np.testing.assert_array_equal(after[1], before[1])
        for idx in (2, 3, 4):
            assert after[idx] == before[idx]

    def test_numpy_rng_round_trip_draws_match(self, tmp_path: Path) -> None:
        """Behavioral check: after `set_state(restored)`, the next draws
        match what the original generator would have produced."""
        gen = np.random.RandomState(seed=99)
        gen.random_sample(50)
        snapshot = gen.get_state()
        want = gen.random_sample(20)

        mock = _mock_state()
        mock["numpy_rng_state"] = snapshot
        save_state(tmp_path, mock)

        reloaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        fresh = np.random.RandomState()
        fresh.set_state(reloaded["numpy_rng_state"])
        np.testing.assert_array_equal(fresh.random_sample(20), want)

    def test_python_random_round_trip(self, tmp_path: Path) -> None:
        """`random.getstate()` must round-trip so replay sampling matches."""
        gen = random.Random(7)
        for _ in range(50):
            gen.random()
        snapshot = gen.getstate()
        want = [gen.random() for _ in range(10)]

        mock = _mock_state()
        mock["python_random_state"] = snapshot
        save_state(tmp_path, mock)

        reloaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        fresh = random.Random()
        fresh.setstate(reloaded["python_random_state"])
        assert [fresh.random() for _ in range(10)] == want

    def test_rng_sidecar_is_valid_json(self, tmp_path: Path) -> None:
        """The sidecar is parseable JSON with the expected top-level keys."""
        mock = _mock_state()
        mock["numpy_rng_state"] = np.random.RandomState(seed=1).get_state()
        save_state(tmp_path, mock)

        payload = json.loads((tmp_path / RNG_SIDECAR_FILENAME).read_text())
        assert payload["_rng_sidecar_version"] == 2
        numpy_entry = payload["numpy_rng_state"]
        assert numpy_entry is not None
        assert "state_hex" in numpy_entry
        assert payload["python_random_state"] is not None

    def test_missing_rng_sidecar_on_v2_payload_raises(self, tmp_path: Path) -> None:
        """Deleting `training_state.rng.json` after a v2 save must be
        refused — silently substituting None breaks determinism."""
        save_state(tmp_path, _mock_state())
        sidecar_path = tmp_path / RNG_SIDECAR_FILENAME
        sidecar_path.unlink()

        with pytest.raises(ResumeIntegrityError, match="requires training_state.rng.json"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_malformed_rng_sidecar_raises(self, tmp_path: Path) -> None:
        """Unparseable sidecar content is an integrity error, not a crash."""
        save_state(tmp_path, _mock_state())
        (tmp_path / RNG_SIDECAR_FILENAME).write_text("{not valid json")

        with pytest.raises(ResumeIntegrityError, match="cannot read RNG sidecar"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_torch_payload_loads_under_weights_only_true(self, tmp_path: Path) -> None:
        """Direct verification the payload never needs `weights_only=False`.

        The point of the v2 refactor: tampered pickled bytes cannot
        execute arbitrary code on resume because the loader itself
        refuses to deserialize non-allowlisted types.
        """
        mock = _mock_state()
        mock["numpy_rng_state"] = np.random.RandomState(seed=1).get_state()
        save_state(tmp_path, mock)

        raw = (tmp_path / STATE_FILENAME).read_bytes()
        payload = torch.load(io.BytesIO(raw), weights_only=True)
        assert payload["_state_sidecar_version"] == 2
        # numpy ndarrays must never ride in the torch payload — they
        # belong to the JSON sidecar.
        assert "numpy_rng_state" not in payload

    def test_python_random_none_helpers_round_trip(self) -> None:
        """None passes through both encode and decode helpers untouched."""
        assert _encode_python_random_state(None) is None
        assert _decode_python_random_state(None) is None
289
290
class TestLegacyV1Compat:
    """Audit-11 B7: one-release back-compat for pre-B7 sidecars.

    Prior releases torch.save'd the full state dict (including numpy
    ndarrays) under `weights_only=False`. v2's reader retries with the
    legacy loader + logs a migration warning so existing checkpoints
    keep resuming through the transition.
    """

    def _write_v1_sidecar(self, directory: Path) -> None:
        """Emit a v1-shape blob (no version marker, numpy array inline)."""
        payload = {
            "optimizer_state_dict": {"lr": 1e-4},
            "scheduler_state_dict": {"step": 5},
            "scaler_state_dict": None,
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": None,
            "numpy_rng_state": np.random.RandomState(seed=1).get_state(),
            "python_random_state": random.getstate(),
            "global_step": 10,
            "epoch": 0.5,
            "best_val_loss": 0.9,
            "dlm_manifest_hash": None,
            "base_model_revision": "a" * 40,
            "pinned_versions": {"torch": torch.__version__},
            "use_qlora": False,
        }
        buf = io.BytesIO()
        torch.save(payload, buf)
        blob = buf.getvalue()
        # Write a matching sha sidecar so the integrity check passes and
        # the loader reaches the format-dispatch logic under test.
        (directory / STATE_FILENAME).write_bytes(blob)
        (directory / STATE_SHA_FILENAME).write_text(hashlib.sha256(blob).hexdigest() + "\n")

    def test_v1_payload_loads_with_migration_warning(
        self, tmp_path: Path, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A v1 blob still loads, and the migration path logs a warning."""
        self._write_v1_sidecar(tmp_path)

        with caplog.at_level(logging.WARNING, logger="dlm.train.state_sidecar"):
            loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})

        assert loaded["global_step"] == 10
        assert loaded["numpy_rng_state"] is not None
        assert any("legacy v1 format" in rec.message for rec in caplog.records)

    def test_v1_payload_does_not_require_rng_sidecar(self, tmp_path: Path) -> None:
        """The legacy path carries RNG inline, so the JSON sidecar
        must not be required for v1 blobs."""
        self._write_v1_sidecar(tmp_path)
        # No RNG_SIDECAR_FILENAME written — this must still load.
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        assert loaded["global_step"] == 10

    def test_double_failed_torch_load_raises_integrity_error(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Both the safe load and the legacy retry failing → integrity error."""
        save_state(tmp_path, _mock_state())

        calls = {"count": 0}

        def fake_load(*args: Any, **kwargs: Any) -> Any:
            # First call simulates the weights-only path failing; the
            # second simulates the legacy fallback failing too.
            calls["count"] += 1
            if calls["count"] == 1:
                raise RuntimeError("weights-only failed")
            raise RuntimeError("legacy failed")

        # monkeypatch undoes this automatically at teardown — even if the
        # assertion below fails — so no manual restore is needed (the old
        # manual restore wouldn't have run on failure anyway).
        monkeypatch.setattr(torch, "load", fake_load)
        with pytest.raises(ResumeIntegrityError, match="legacy load also failed"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_missing_sidecar_version_defaults_rng_to_none(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A payload with no version marker and no inline RNG fields yields
        None for both RNG slots rather than raising."""
        save_state(tmp_path, _mock_state())

        def fake_load(*args: Any, **kwargs: Any) -> dict[str, Any]:
            # v1-shaped payload with the RNG keys deliberately absent.
            return {
                "optimizer_state_dict": {"lr": 1e-4},
                "scheduler_state_dict": {"step": 5},
                "scaler_state_dict": None,
                "torch_rng_state": torch.get_rng_state(),
                "cuda_rng_state": None,
                "global_step": 10,
                "epoch": 0.5,
                "best_val_loss": 0.9,
                "dlm_manifest_hash": None,
                "base_model_revision": "a" * 40,
                "pinned_versions": {"torch": torch.__version__},
                "use_qlora": False,
            }

        # Restored automatically by the monkeypatch fixture at teardown.
        monkeypatch.setattr(torch, "load", fake_load)
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})

        assert loaded["numpy_rng_state"] is None
        assert loaded["python_random_state"] is None
391
392
class TestCaptureRuntimeVersions:
    """Shape of the runtime-version snapshot used for drift detection."""

    def test_torch_key_populated(self) -> None:
        """torch is always importable in this suite, so its version is a str."""
        snapshot = capture_runtime_versions()
        assert "torch" in snapshot
        assert isinstance(snapshot["torch"], str)

    def test_bitsandbytes_key_present_even_if_none(self) -> None:
        """Explicit `None` so the key survives a JSON round-trip + drift check."""
        snapshot = capture_runtime_versions()
        # Value may be None on Apple Silicon (bnb not installed) — but key exists.
        assert "bitsandbytes" in snapshot

    def test_sway_key_present_even_if_none(self) -> None:
        """Same shape as bitsandbytes: key always present, `None` when sway is
        not installed in this venv. Records which probe harness produced the
        reports that drove the run."""
        snapshot = capture_runtime_versions()
        assert "sway" in snapshot

    def test_missing_import_returns_none(self, monkeypatch) -> None:
        """Forcing the bitsandbytes import to fail yields None, not a crash."""
        original_import = builtins.__import__

        def blocking_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "bitsandbytes":
                raise ImportError("forced missing package")
            return original_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", blocking_import)
        snapshot = capture_runtime_versions()
        assert snapshot["bitsandbytes"] is None