| 1 |
"""Sprint 15 — `trainer.run()` lock validation + persistence wiring.""" |
| 2 |
|
| 3 |
from __future__ import annotations |
| 4 |
|
| 5 |
import logging |
| 6 |
from pathlib import Path |
| 7 |
from types import SimpleNamespace |
| 8 |
from typing import Any |
| 9 |
from unittest.mock import MagicMock |
| 10 |
|
| 11 |
import pytest |
| 12 |
|
| 13 |
from dlm.base_models import BASE_MODELS |
| 14 |
from dlm.doc.parser import ParsedDlm |
| 15 |
from dlm.doc.schema import DlmFrontmatter, TrainingConfig |
| 16 |
from dlm.doc.sections import Section, SectionType |
| 17 |
from dlm.lock import DlmLock, LockValidationError, load_lock, write_lock |
| 18 |
from dlm.lock.schema import CURRENT_LOCK_VERSION |
| 19 |
from dlm.store.manifest import Manifest, save_manifest |
| 20 |
from dlm.store.paths import for_dlm |
| 21 |
from dlm.train.trainer import run |
| 22 |
|
| 23 |
# Cheap parsed helper duplicated from test_trainer; kept local so a future |
| 24 |
# refactor in one test file doesn't ripple into the other. |
| 25 |
|
| 26 |
|
| 27 |
def _parsed(tmp_path: Path, dlm_id: str = "01TEST0" + "0" * 19) -> ParsedDlm:
    """Build a minimal ``ParsedDlm`` whose ``source_path`` is a real file.

    A placeholder ``doc.dlm`` is written under *tmp_path*. It is paired with
    frontmatter for the ``smollm2-135m`` base with ``seed=42`` and a single
    one-character prose section.
    """
    source = tmp_path / "doc.dlm"
    source.write_text("placeholder .dlm body\n", encoding="utf-8")

    frontmatter = DlmFrontmatter(
        dlm_id=dlm_id,
        base_model="smollm2-135m",
        training=TrainingConfig(seed=42),
    )
    only_section = Section(type=SectionType.PROSE, content="x")

    return ParsedDlm(
        frontmatter=frontmatter,
        sections=(only_section,),
        source_path=source,
    )
| 39 |
|
| 40 |
|
| 41 |
def _plan() -> SimpleNamespace: |
| 42 |
return SimpleNamespace( |
| 43 |
precision="bf16", |
| 44 |
attn_implementation="sdpa", |
| 45 |
use_qlora=False, |
| 46 |
quant_compute_dtype=None, |
| 47 |
micro_batch_size=1, |
| 48 |
grad_accum=1, |
| 49 |
effective_batch_size=1, |
| 50 |
gradient_checkpointing=False, |
| 51 |
est_peak_vram_gb=1.0, |
| 52 |
est_step_seconds=0.1, |
| 53 |
reason="test", |
| 54 |
to_dict=lambda: {"precision": "bf16"}, |
| 55 |
) |
| 56 |
|
| 57 |
|
| 58 |
def _mock_trainer_factory(**_: Any) -> MagicMock: |
| 59 |
sft = MagicMock() |
| 60 |
sft.state = SimpleNamespace(global_step=10, epoch=1.0, best_metric=0.5) |
| 61 |
sft.optimizer = SimpleNamespace(state_dict=lambda: {}) |
| 62 |
sft.lr_scheduler = SimpleNamespace(state_dict=lambda: {}) |
| 63 |
sft.scaler = None |
| 64 |
sft.control = SimpleNamespace(should_training_stop=False) |
| 65 |
sft.train.return_value = SimpleNamespace(training_loss=0.5) |
| 66 |
|
| 67 |
def _save_model(path: str) -> None: |
| 68 |
p = Path(path) |
| 69 |
p.mkdir(parents=True, exist_ok=True) |
| 70 |
(p / "adapter_config.json").write_text("{}") |
| 71 |
(p / "adapter_model.safetensors").write_bytes(b"\x00" * 32) |
| 72 |
|
| 73 |
sft.save_model.side_effect = _save_model |
| 74 |
return sft |
| 75 |
|
| 76 |
|
| 77 |
def _bootstrap_store(tmp_path: Path, dlm_id: str = "01TEST0" + "0" * 19):
    """Create the store for *dlm_id* under *tmp_path* and seed its manifest.

    Lays out the store directories, then saves a ``Manifest`` recording the
    ``smollm2-135m`` base. Returns the handle produced by ``for_dlm``.
    """
    paths = for_dlm(dlm_id, home=tmp_path)
    paths.ensure_layout()
    save_manifest(
        paths.manifest,
        Manifest(dlm_id=dlm_id, base_model="smollm2-135m"),
    )
    return paths
| 82 |
|
| 83 |
|
| 84 |
class TestFirstRunWritesLock:
    """A run against a store that has no prior lock must create one."""

    def test_fresh_run_creates_dlm_lock_with_run_id_1(self, tmp_path: Path) -> None:
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)

        run(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            trainer_factory=_mock_trainer_factory,
        )

        lock = load_lock(store.root)
        assert lock is not None
        # Version, identity, seed, and run counter are all pinned on first write.
        assert (lock.lock_version, lock.dlm_id, lock.seed, lock.last_run_id) == (
            CURRENT_LOCK_VERSION,
            parsed.frontmatter.dlm_id,
            42,
            1,
        )
| 98 |
|
| 99 |
|
| 100 |
class TestIgnoreModeSkipsLock:
    """``lock_mode="ignore"`` must leave the store without a lock."""

    def test_ignore_mode_doesnt_write_lock(self, tmp_path: Path) -> None:
        store = _bootstrap_store(tmp_path)

        run(
            store,
            _parsed(tmp_path),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            lock_mode="ignore",
            trainer_factory=_mock_trainer_factory,
        )

        assert load_lock(store.root) is None
| 115 |
|
| 116 |
|
| 117 |
class TestErrorSeverityAborts:
    """A lock mismatch the validator rejects must make ``run`` raise."""

    def test_base_revision_drift_raises(self, tmp_path: Path) -> None:
        from datetime import UTC, datetime

        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # Pre-seed a lock whose base_model_revision cannot equal spec.revision.
        # With no lock_mode override, run() is expected to refuse it.
        write_lock(
            store.root,
            DlmLock(
                lock_version=CURRENT_LOCK_VERSION,
                created_at=datetime(2026, 4, 1, tzinfo=UTC),
                dlm_id=parsed.frontmatter.dlm_id,
                dlm_sha256="0" * 64,
                base_model_revision="totally-different-revision",
                hardware_tier="cpu",
                seed=42,
                determinism_class="best-effort",
                last_run_id=1,
            ),
        )

        with pytest.raises(LockValidationError, match="base_model_revision"):
            run(store, parsed, spec, _plan(), trainer_factory=_mock_trainer_factory)
| 148 |
|
| 149 |
|
| 150 |
class TestUpdateModeOverrides:
    """``lock_mode="update"`` rewrites the lock instead of validating it."""

    def test_update_mode_bypasses_validation_and_writes(self, tmp_path: Path) -> None:
        from datetime import UTC, datetime

        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # A prior lock whose base_model_revision disagrees with spec.revision.
        stale = DlmLock(
            lock_version=CURRENT_LOCK_VERSION,
            created_at=datetime(2026, 4, 1, tzinfo=UTC),
            dlm_id=parsed.frontmatter.dlm_id,
            dlm_sha256="0" * 64,
            base_model_revision="totally-different-revision",
            hardware_tier="cpu",
            seed=42,
            determinism_class="best-effort",
            last_run_id=1,
        )
        write_lock(store.root, stale)

        # --update-lock must tolerate the drift rather than raise.
        run(
            store,
            parsed,
            spec,
            _plan(),
            lock_mode="update",
            trainer_factory=_mock_trainer_factory,
        )

        refreshed = load_lock(store.root)
        assert refreshed is not None
        assert refreshed.base_model_revision == spec.revision

    def test_update_mode_warns_and_recovers_from_broken_lock(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        store = _bootstrap_store(tmp_path)
        parsed = _parsed(tmp_path)
        spec = BASE_MODELS["smollm2-135m"]

        # Put non-JSON content where the store's lock file lives.
        (store.root / "dlm.lock").write_text("{not json", encoding="utf-8")

        caplog.set_level(logging.WARNING, logger="dlm.train.trainer")
        run(
            store,
            parsed,
            spec,
            _plan(),
            lock_mode="update",
            trainer_factory=_mock_trainer_factory,
        )

        refreshed = load_lock(store.root)
        assert refreshed is not None
        assert refreshed.base_model_revision == spec.revision
        assert "update-lock: ignoring unreadable prior dlm.lock" in caplog.text
assert "update-lock: ignoring unreadable prior dlm.lock" in caplog.text |