Python · 6824 bytes Raw Blame History
1 """Sprint 15 — `trainer.run()` lock validation + persistence wiring."""
2
3 from __future__ import annotations
4
import logging
from datetime import UTC, datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest

from dlm.base_models import BASE_MODELS
from dlm.doc.parser import ParsedDlm
from dlm.doc.schema import DlmFrontmatter, TrainingConfig
from dlm.doc.sections import Section, SectionType
from dlm.lock import DlmLock, LockValidationError, load_lock, write_lock
from dlm.lock.schema import CURRENT_LOCK_VERSION
from dlm.store.manifest import Manifest, save_manifest
from dlm.store.paths import for_dlm
from dlm.train.trainer import run
22
23 # Cheap parsed helper duplicated from test_trainer; kept local so a future
24 # refactor in one test file doesn't ripple into the other.
25
26
def _parsed(tmp_path: Path, dlm_id: str = "01TEST0" + "0" * 19) -> ParsedDlm:
    """Build a minimal ``ParsedDlm`` whose source file lives in *tmp_path*.

    A placeholder ``doc.dlm`` is written to disk first so ``source_path``
    points at a real file. The returned document carries a single prose
    section, the ``smollm2-135m`` base model and a training seed of 42.
    """
    source = tmp_path / "doc.dlm"
    source.write_text("placeholder .dlm body\n", encoding="utf-8")

    frontmatter = DlmFrontmatter(
        dlm_id=dlm_id,
        base_model="smollm2-135m",
        training=TrainingConfig(seed=42),
    )
    only_section = Section(type=SectionType.PROSE, content="x")
    return ParsedDlm(
        frontmatter=frontmatter,
        sections=(only_section,),
        source_path=source,
    )
39
40
def _plan() -> SimpleNamespace:
    """Return a stand-in training plan exposing fixed attribute values.

    The namespace also offers a ``to_dict()`` callable that reports only
    the precision entry.
    """
    settings = {
        "precision": "bf16",
        "attn_implementation": "sdpa",
        "use_qlora": False,
        "quant_compute_dtype": None,
        "micro_batch_size": 1,
        "grad_accum": 1,
        "effective_batch_size": 1,
        "gradient_checkpointing": False,
        "est_peak_vram_gb": 1.0,
        "est_step_seconds": 0.1,
        "reason": "test",
    }
    # `to_dict` goes last so the attribute order matches a keyword-by-keyword
    # construction of the namespace.
    return SimpleNamespace(**settings, to_dict=lambda: {"precision": "bf16"})
56
57
def _mock_trainer_factory(**_: Any) -> MagicMock:
    """Return a ``MagicMock`` standing in for the SFT trainer object.

    All keyword arguments are accepted and ignored. The mock is primed
    with:

    * ``state`` — ``global_step=10``, ``epoch=1.0``, ``best_metric=0.5``.
    * ``optimizer`` / ``lr_scheduler`` — objects whose ``state_dict()``
      returns an empty dict.
    * ``scaler`` — ``None``.
    * ``control`` — ``should_training_stop=False``.
    * ``train()`` — returns an object with ``training_loss=0.5``.
    * ``save_model(path)`` — creates *path* (parents included) and writes
      a stub ``adapter_config.json`` plus a 32-byte
      ``adapter_model.safetensors`` into it.
    """
    sft = MagicMock()
    sft.state = SimpleNamespace(global_step=10, epoch=1.0, best_metric=0.5)
    sft.optimizer = SimpleNamespace(state_dict=lambda: {})
    sft.lr_scheduler = SimpleNamespace(state_dict=lambda: {})
    sft.scaler = None
    sft.control = SimpleNamespace(should_training_stop=False)
    sft.train.return_value = SimpleNamespace(training_loss=0.5)

    def _save_model(path: str) -> None:
        """Materialise stub adapter files under *path*."""
        p = Path(path)
        p.mkdir(parents=True, exist_ok=True)
        # Explicit encoding keeps the write locale-independent and matches
        # the other `write_text` calls in this module.
        (p / "adapter_config.json").write_text("{}", encoding="utf-8")
        (p / "adapter_model.safetensors").write_bytes(b"\x00" * 32)

    sft.save_model.side_effect = _save_model
    return sft
75
76
def _bootstrap_store(tmp_path: Path, dlm_id: str = "01TEST0" + "0" * 19):
    """Create a store for *dlm_id* homed at *tmp_path* and save a manifest.

    The store layout is ensured first; a manifest naming the
    ``smollm2-135m`` base model is then saved at the store's manifest
    location. The store object is returned.
    """
    dlm_store = for_dlm(dlm_id, home=tmp_path)
    dlm_store.ensure_layout()

    manifest_path = dlm_store.manifest
    initial_manifest = Manifest(dlm_id=dlm_id, base_model="smollm2-135m")
    save_manifest(manifest_path, initial_manifest)
    return dlm_store
82
83
class TestFirstRunWritesLock:
    """A first `run()` on a store without a lock must leave one behind."""

    def test_fresh_run_creates_dlm_lock_with_run_id_1(self, tmp_path: Path) -> None:
        """The lock written by a fresh run records version, id, seed and run 1."""
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)

        run(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            trainer_factory=_mock_trainer_factory,
        )

        lock = load_lock(store.root)
        assert lock is not None

        recorded = (lock.lock_version, lock.dlm_id, lock.seed, lock.last_run_id)
        expected = (CURRENT_LOCK_VERSION, parsed.frontmatter.dlm_id, 42, 1)
        assert recorded == expected
98
99
class TestIgnoreModeSkipsLock:
    """With ``lock_mode="ignore"`` a run must not persist any lock."""

    def test_ignore_mode_doesnt_write_lock(self, tmp_path: Path) -> None:
        """After an ignore-mode run, loading the lock yields ``None``."""
        store = _bootstrap_store(tmp_path)
        document = _parsed(tmp_path)
        base_spec = BASE_MODELS["smollm2-135m"]
        plan = _plan()

        run(
            store,
            document,
            base_spec,
            plan,
            trainer_factory=_mock_trainer_factory,
            lock_mode="ignore",
        )

        lock_after_run = load_lock(store.root)
        assert lock_after_run is None
114 assert load_lock(store.root) is None
115
116
class TestErrorSeverityAborts:
    """A lock mismatch on the base-model revision must abort the run."""

    def test_base_revision_drift_raises(self, tmp_path: Path) -> None:
        """`run()` raises `LockValidationError` naming `base_model_revision`."""
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # Seed the store with a lock whose recorded base_model_revision
        # differs from the real spec.revision — the validator must abort.
        # (`UTC` / `datetime` come from the module-level imports.)
        forged = DlmLock(
            lock_version=CURRENT_LOCK_VERSION,
            created_at=datetime(2026, 4, 1, tzinfo=UTC),
            dlm_id=parsed.frontmatter.dlm_id,
            dlm_sha256="0" * 64,
            base_model_revision="totally-different-revision",
            hardware_tier="cpu",
            seed=42,
            determinism_class="best-effort",
            last_run_id=1,
        )
        write_lock(store.root, forged)

        with pytest.raises(LockValidationError, match="base_model_revision"):
            run(
                store,
                parsed,
                spec,
                _plan(),
                trainer_factory=_mock_trainer_factory,
            )
147 )
148
149
class TestUpdateModeOverrides:
    """``lock_mode="update"`` rewrites the lock instead of validating it."""

    def test_update_mode_bypasses_validation_and_writes(self, tmp_path: Path) -> None:
        """A drifted prior lock is replaced by one carrying ``spec.revision``."""
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # Prior lock with a base-model revision that does not match the spec.
        # (`UTC` / `datetime` come from the module-level imports.)
        forged = DlmLock(
            lock_version=CURRENT_LOCK_VERSION,
            created_at=datetime(2026, 4, 1, tzinfo=UTC),
            dlm_id=parsed.frontmatter.dlm_id,
            dlm_sha256="0" * 64,
            base_model_revision="totally-different-revision",
            hardware_tier="cpu",
            seed=42,
            determinism_class="best-effort",
            last_run_id=1,
        )
        write_lock(store.root, forged)

        # --update-lock should NOT raise despite the base-revision drift.
        run(
            store,
            parsed,
            spec,
            _plan(),
            trainer_factory=_mock_trainer_factory,
            lock_mode="update",
        )

        updated = load_lock(store.root)
        assert updated is not None
        assert updated.base_model_revision == spec.revision

    def test_update_mode_warns_and_recovers_from_broken_lock(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """An unparseable prior lock is logged as a warning and overwritten."""
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # Plant a lock file whose content is not valid JSON.
        store_lock = store.root / "dlm.lock"
        store_lock.write_text("{not json", encoding="utf-8")

        # Capture WARNING-level records from the trainer's logger only.
        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
        run(
            store,
            parsed,
            spec,
            _plan(),
            trainer_factory=_mock_trainer_factory,
            lock_mode="update",
        )

        updated = load_lock(store.root)
        assert updated is not None
        assert updated.base_model_revision == spec.revision
        assert "update-lock: ignoring unreadable prior dlm.lock" in caplog.text
209 assert "update-lock: ignoring unreadable prior dlm.lock" in caplog.text