"""Tests for training_state.pt save/load, integrity checking, and version-drift detection."""

from __future__ import annotations

import builtins
import hashlib
import io
import json
import logging
import random
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import pytest
import torch

from dlm.train.errors import ResumeIntegrityError, VersionDriftWarning
from dlm.train.state_sidecar import (
    RNG_SIDECAR_FILENAME,
    STATE_FILENAME,
    STATE_SHA_FILENAME,
    TRAINING_RUN_FILENAME,
    VERSIONS_FILENAME,
    TrainingState,
    _decode_python_random_state,
    _encode_python_random_state,
    capture_runtime_versions,
    load_state,
    save_state,
)
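
# On-disk layout exercised below (names resolve from the constants above;
# listed here for orientation, not as an exhaustive contract):
#
#   <ckpt_dir>/
#     STATE_FILENAME         torch payload; v2 loads under weights_only=True
#     STATE_SHA_FILENAME     hex sha256 of the torch payload, newline-terminated
#     VERSIONS_FILENAME      pinned package versions, plain JSON
#     TRAINING_RUN_FILENAME  run metadata (e.g. the use_qlora flag), plain JSON
#     RNG_SIDECAR_FILENAME   numpy/python RNG state, plain JSON (v2 layout only)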


def _mock_state(*, use_qlora: bool = False) -> TrainingState:
    return TrainingState(
        optimizer_state_dict={"lr": 1e-4},
        scheduler_state_dict={"step": 5},
        scaler_state_dict=None,
        torch_rng_state=torch.get_rng_state(),
        cuda_rng_state=None,
        numpy_rng_state=None,
        python_random_state=random.getstate(),
        global_step=10,
        epoch=0.5,
        best_val_loss=0.9,
        dlm_manifest_hash=None,
        base_model_revision="a" * 40,
        pinned_versions={"torch": torch.__version__},
        use_qlora=use_qlora,
    )


class TestRoundTrip:
    def test_save_writes_three_files(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        assert (tmp_path / STATE_FILENAME).exists()
        assert (tmp_path / STATE_SHA_FILENAME).exists()
        assert (tmp_path / VERSIONS_FILENAME).exists()

    def test_load_returns_state(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        assert loaded["global_step"] == 10
        assert loaded["base_model_revision"] == "a" * 40

    def test_pinned_versions_json_sidecar_readable(self, tmp_path: Path) -> None:
        """`pinned_versions.json` is JSON for human grep-ability."""
        save_state(tmp_path, _mock_state())
        content = json.loads((tmp_path / VERSIONS_FILENAME).read_text())
        assert "torch" in content


class TestIntegrity:
    def test_missing_state_file_raises(self, tmp_path: Path) -> None:
        with pytest.raises(ResumeIntegrityError, match="missing training state"):
            load_state(tmp_path, runtime_versions={})

    def test_missing_sha_file_raises(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        (tmp_path / STATE_SHA_FILENAME).unlink()
        with pytest.raises(ResumeIntegrityError, match="sha256 sidecar"):
            load_state(tmp_path, runtime_versions={})

    def test_corrupted_state_raises(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        (tmp_path / STATE_FILENAME).write_bytes(b"tampered-bytes")
        with pytest.raises(ResumeIntegrityError, match="sha256 mismatch"):
            load_state(tmp_path, runtime_versions={})

    def test_corrupted_sha_raises(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        (tmp_path / STATE_SHA_FILENAME).write_text("0" * 64 + "\n")
        with pytest.raises(ResumeIntegrityError, match="sha256 mismatch"):
            load_state(tmp_path, runtime_versions={})
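
    # The check these four tests pin down (a sketch, not the real code):
    #   blob = (ckpt_dir / STATE_FILENAME).read_bytes()
    #   expected = (ckpt_dir / STATE_SHA_FILENAME).read_text().strip()
    #   if hashlib.sha256(blob).hexdigest() != expected:
    #       raise ResumeIntegrityError("sha256 mismatch ...")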


class TestVersionDrift:
    def test_matching_versions_no_warning(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # any warning fails
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_differing_version_emits_warning(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        with pytest.warns(VersionDriftWarning, match="torch:"):
            load_state(tmp_path, runtime_versions={"torch": "99.99.99"})

    def test_gaining_a_package_is_not_drift(self, tmp_path: Path) -> None:
        """Saved had no `trl` pinned; current runtime knows it → no drift.

        Gaining capability isn't drift — there was no prior state to
        diverge from. Only losing a pinned package is (see M6 test).
        """
        save_state(tmp_path, _mock_state())
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            load_state(
                tmp_path,
                runtime_versions={"torch": torch.__version__, "trl": "1.2.0"},
            )

    def test_losing_pinned_package_is_drift(self, tmp_path: Path) -> None:
        """Audit-04 M6: saved had `bitsandbytes="0.43.1"`, runtime has None.

        This matters for the QLoRA-on-CUDA → resumed-on-Apple-Silicon
        case; under the old logic it was silently skipped.
        """
        # Build a mock state whose pinned_versions declares bitsandbytes.
        state = _mock_state()
        state["pinned_versions"] = {
            "torch": torch.__version__,
            "bitsandbytes": "0.43.1",
        }
        save_state(tmp_path, state)

        with pytest.warns(VersionDriftWarning, match="bitsandbytes.*0\\.43\\.1.*missing"):
            load_state(
                tmp_path,
                runtime_versions={"torch": torch.__version__, "bitsandbytes": None},
            )

    def test_losing_pinned_package_missing_key_is_drift(self, tmp_path: Path) -> None:
        """Same as above but the runtime dict omits the key entirely."""
        state = _mock_state()
        state["pinned_versions"] = {"torch": torch.__version__, "bitsandbytes": "0.43.1"}
        save_state(tmp_path, state)

        with pytest.warns(VersionDriftWarning, match="bitsandbytes.*missing"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})
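
    # Drift rule these tests pin down (summary; the real check lives in
    # load_state): for every package pinned at save time, a differing runtime
    # version, a runtime value of None, or a missing runtime key all count as
    # drift and emit VersionDriftWarning. Packages present only at runtime are
    # ignored, since there is no saved pin to diverge from.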


class TestTrainingRunSidecar:
    """Audit-05 M1: explicit use_qlora flag persisted alongside the adapter."""

    def test_training_run_json_written_on_save(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state(use_qlora=True))
        training_run = tmp_path / TRAINING_RUN_FILENAME
        assert training_run.exists()
        data = json.loads(training_run.read_text())
        assert data["use_qlora"] is True

    def test_training_run_defaults_false_when_lora(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state(use_qlora=False))
        data = json.loads((tmp_path / TRAINING_RUN_FILENAME).read_text())
        assert data["use_qlora"] is False
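
    # Minimal training_run.json implied by these assertions (other fields, if
    # any, are not pinned down here):
    #   {"use_qlora": true}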


class TestRngSidecar:
    """Audit-11 B7: numpy + python RNG round-trip via JSON sidecar.

    The v1 layout stored these inside the torch payload and required
    `weights_only=False` on load — an RCE vector. v2 moves them to a
    JSON sidecar so the torch payload loads safely under
    `weights_only=True`.
    """

    def test_numpy_rng_round_trip(self, tmp_path: Path) -> None:
        """`numpy.random.get_state()` must round-trip exactly — any drift
        means loss curves diverge on resume even with matched torch RNG."""
        rng = np.random.RandomState(seed=12345)
        rng.random_sample(100)  # advance the state past seed init
        original_state: tuple[Any, ...] = rng.get_state(legacy=True)  # type: ignore[assignment]

        state = _mock_state()
        state["numpy_rng_state"] = original_state
        save_state(tmp_path, state)

        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        restored: tuple[Any, ...] = loaded["numpy_rng_state"]

        assert restored[0] == original_state[0]
        np.testing.assert_array_equal(restored[1], original_state[1])
        assert restored[2] == original_state[2]
        assert restored[3] == original_state[3]
        assert restored[4] == original_state[4]
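        # The five legacy MT19937 state fields compared above, in order:
        # algorithm name, 624-entry uint32 key array, position, has_gauss,
        # cached_gaussian.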

    def test_numpy_rng_round_trip_draws_match(self, tmp_path: Path) -> None:
        """Behavioral check: after `set_state(restored)`, the next draws
        match what the original generator would have produced."""
        rng = np.random.RandomState(seed=99)
        rng.random_sample(50)
        original_state = rng.get_state()
        expected_draws = rng.random_sample(20)

        state = _mock_state()
        state["numpy_rng_state"] = original_state
        save_state(tmp_path, state)

        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        resumed = np.random.RandomState()
        resumed.set_state(loaded["numpy_rng_state"])
        np.testing.assert_array_equal(resumed.random_sample(20), expected_draws)

    def test_python_random_round_trip(self, tmp_path: Path) -> None:
        """`random.getstate()` must round-trip so replay sampling matches."""
        rng = random.Random(7)
        for _ in range(50):
            rng.random()
        original_state = rng.getstate()
        expected = [rng.random() for _ in range(10)]

        state = _mock_state()
        state["python_random_state"] = original_state
        save_state(tmp_path, state)

        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        resumed = random.Random()
        resumed.setstate(loaded["python_random_state"])
        got = [resumed.random() for _ in range(10)]
        assert got == expected

    def test_rng_sidecar_is_valid_json(self, tmp_path: Path) -> None:
        """The sidecar is parseable JSON with the expected top-level keys."""
        state = _mock_state()
        state["numpy_rng_state"] = np.random.RandomState(seed=1).get_state()
        save_state(tmp_path, state)

        sidecar = json.loads((tmp_path / RNG_SIDECAR_FILENAME).read_text())
        assert sidecar["_rng_sidecar_version"] == 2
        assert sidecar["numpy_rng_state"] is not None
        assert "state_hex" in sidecar["numpy_rng_state"]
        assert sidecar["python_random_state"] is not None

    def test_missing_rng_sidecar_on_v2_payload_raises(self, tmp_path: Path) -> None:
        """Deleting `training_state.rng.json` after a v2 save must be
        refused — silently substituting None breaks determinism."""
        save_state(tmp_path, _mock_state())
        (tmp_path / RNG_SIDECAR_FILENAME).unlink()

        with pytest.raises(ResumeIntegrityError, match="requires training_state.rng.json"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_malformed_rng_sidecar_raises(self, tmp_path: Path) -> None:
        save_state(tmp_path, _mock_state())
        (tmp_path / RNG_SIDECAR_FILENAME).write_text("{not valid json")

        with pytest.raises(ResumeIntegrityError, match="cannot read RNG sidecar"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_torch_payload_loads_under_weights_only_true(self, tmp_path: Path) -> None:
        """Direct verification the payload never needs `weights_only=False`.

        The point of the v2 refactor: tampered pickled bytes cannot
        execute arbitrary code on resume because the loader itself
        refuses to deserialize non-allowlisted types.
        """
        state = _mock_state()
        state["numpy_rng_state"] = np.random.RandomState(seed=1).get_state()
        save_state(tmp_path, state)

        blob = (tmp_path / STATE_FILENAME).read_bytes()
        payload = torch.load(io.BytesIO(blob), weights_only=True)
        assert payload["_state_sidecar_version"] == 2
        # numpy ndarrays shouldn't appear in the torch payload — they
        # live in the JSON sidecar.
        assert "numpy_rng_state" not in payload

    def test_python_random_none_helpers_round_trip(self) -> None:
        assert _encode_python_random_state(None) is None
        assert _decode_python_random_state(None) is None


class TestLegacyV1Compat:
    """Audit-11 B7: one-release back-compat for pre-B7 sidecars.

    Prior releases torch.save'd the full state dict (including numpy
    ndarrays) under `weights_only=False`. v2's reader retries with the
    legacy loader and logs a migration warning so existing checkpoints
    keep resuming through the transition.
    """

    def _write_v1_sidecar(self, directory: Path) -> None:
        """Emit a v1-shape blob (no version marker, numpy array inline)."""
        payload = {
            "optimizer_state_dict": {"lr": 1e-4},
            "scheduler_state_dict": {"step": 5},
            "scaler_state_dict": None,
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": None,
            "numpy_rng_state": np.random.RandomState(seed=1).get_state(),
            "python_random_state": random.getstate(),
            "global_step": 10,
            "epoch": 0.5,
            "best_val_loss": 0.9,
            "dlm_manifest_hash": None,
            "base_model_revision": "a" * 40,
            "pinned_versions": {"torch": torch.__version__},
            "use_qlora": False,
        }
        buf = io.BytesIO()
        torch.save(payload, buf)
        blob = buf.getvalue()
        (directory / STATE_FILENAME).write_bytes(blob)
        (directory / STATE_SHA_FILENAME).write_text(hashlib.sha256(blob).hexdigest() + "\n")

    def test_v1_payload_loads_with_migration_warning(
        self, tmp_path: Path, caplog: pytest.LogCaptureFixture
    ) -> None:
        self._write_v1_sidecar(tmp_path)

        with caplog.at_level(logging.WARNING, logger="dlm.train.state_sidecar"):
            loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})

        assert loaded["global_step"] == 10
        assert loaded["numpy_rng_state"] is not None
        assert any("legacy v1 format" in rec.message for rec in caplog.records)

    def test_v1_payload_does_not_require_rng_sidecar(self, tmp_path: Path) -> None:
        """The legacy path carries RNG inline, so the JSON sidecar
        must not be required for v1 blobs."""
        self._write_v1_sidecar(tmp_path)
        # No RNG_SIDECAR_FILENAME written — this must still load.
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})
        assert loaded["global_step"] == 10

    def test_double_failed_torch_load_raises_integrity_error(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        save_state(tmp_path, _mock_state())

        calls = {"count": 0}

        def fake_load(*args: Any, **kwargs: Any) -> Any:
            calls["count"] += 1
            if calls["count"] == 1:
                raise RuntimeError("weights-only failed")
            raise RuntimeError("legacy failed")

        # monkeypatch restores the real torch.load at teardown, so no manual
        # save/restore is needed.
        monkeypatch.setattr(torch, "load", fake_load)
        with pytest.raises(ResumeIntegrityError, match="legacy load also failed"):
            load_state(tmp_path, runtime_versions={"torch": torch.__version__})

    def test_missing_sidecar_version_defaults_rng_to_none(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        save_state(tmp_path, _mock_state())

        def fake_load(*args: Any, **kwargs: Any) -> dict[str, Any]:
            # A v1-era payload with no "_state_sidecar_version" marker and no
            # RNG entries at all.
            return {
                "optimizer_state_dict": {"lr": 1e-4},
                "scheduler_state_dict": {"step": 5},
                "scaler_state_dict": None,
                "torch_rng_state": torch.get_rng_state(),
                "cuda_rng_state": None,
                "global_step": 10,
                "epoch": 0.5,
                "best_val_loss": 0.9,
                "dlm_manifest_hash": None,
                "base_model_revision": "a" * 40,
                "pinned_versions": {"torch": torch.__version__},
                "use_qlora": False,
            }

        monkeypatch.setattr(torch, "load", fake_load)
        loaded = load_state(tmp_path, runtime_versions={"torch": torch.__version__})

        assert loaded["numpy_rng_state"] is None
        assert loaded["python_random_state"] is None


class TestCaptureRuntimeVersions:
    def test_torch_key_populated(self) -> None:
        versions = capture_runtime_versions()
        assert "torch" in versions
        assert isinstance(versions["torch"], str)

    def test_bitsandbytes_key_present_even_if_none(self) -> None:
        """Explicit `None` so the key survives a JSON round-trip + drift check."""
        versions = capture_runtime_versions()
        # Value may be None on Apple Silicon (bnb not installed) — but key exists.
        assert "bitsandbytes" in versions

    def test_sway_key_present_even_if_none(self) -> None:
        """Same shape as bitsandbytes: key always present, `None` when sway is
        not installed in this venv. Records which probe harness produced the
        reports that drove the run."""
        versions = capture_runtime_versions()
        assert "sway" in versions

    def test_missing_import_returns_none(self, monkeypatch: pytest.MonkeyPatch) -> None:
        real_import = builtins.__import__

        def fake_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "bitsandbytes":
                raise ImportError("forced missing package")
            return real_import(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", fake_import)
        versions = capture_runtime_versions()
        assert versions["bitsandbytes"] is None
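
    # A plausible implementation shape, consistent with the four tests above
    # (illustrative only; the real function lives in dlm.train.state_sidecar
    # and may probe more packages):
    #
    #   def capture_runtime_versions() -> dict[str, str | None]:
    #       versions: dict[str, str | None] = {"torch": torch.__version__}
    #       for name in ("bitsandbytes", "sway"):
    #           try:
    #               versions[name] = __import__(name).__version__
    #           except ImportError:
    #               versions[name] = None
    #       return versions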