| 1 |
"""`dpo_phase.run()` end-to-end with a mocked DPOTrainer. |
| 2 |
|
| 3 |
Mirrors `test_trainer.py`'s factory-seam pattern: we pass a MagicMock |
| 4 |
factory so `run()` exercises preflight → lock → log → commit → |
| 5 |
manifest → state-sidecar without importing HF/TRL or torch. |
| 6 |
""" |
| 7 |
|
| 8 |
from __future__ import annotations |
| 9 |
|
| 10 |
from dataclasses import replace |
| 11 |
from pathlib import Path |
| 12 |
from types import SimpleNamespace |
| 13 |
from typing import Any |
| 14 |
from unittest.mock import MagicMock |
| 15 |
|
| 16 |
import pytest |
| 17 |
|
| 18 |
import dlm.train.preference.dpo_phase as dpo_phase |
| 19 |
from dlm.base_models import BASE_MODELS |
| 20 |
from dlm.doc.parser import ParsedDlm |
| 21 |
from dlm.doc.schema import DlmFrontmatter, PreferenceConfig, TrainingConfig |
| 22 |
from dlm.doc.sections import Section, SectionType |
| 23 |
from dlm.store.manifest import Manifest, save_manifest |
| 24 |
from dlm.store.paths import for_dlm |
| 25 |
from dlm.train.preference.dpo_phase import run |
| 26 |
from dlm.train.state_sidecar import STATE_FILENAME, STATE_SHA_FILENAME |
| 27 |
|
| 28 |
|
| 29 |
def _parsed_with_preferences() -> ParsedDlm:
    """Build an in-memory parsed `.dlm` with DPO enabled.

    The document carries exactly one PREFERENCE section holding a single
    prompt / chosen / rejected triple, and its frontmatter pins
    `TrainingConfig(seed=42)` with `PreferenceConfig(enabled=True)`.
    """
    preference_section = Section(
        type=SectionType.PREFERENCE,
        content="### Prompt\nq?\n### Chosen\nc.\n### Rejected\nr.\n",
    )
    training = TrainingConfig(
        seed=42,
        preference=PreferenceConfig(enabled=True),
    )
    frontmatter = DlmFrontmatter(
        dlm_id="01KABCD" + "0" * 19,
        base_model="smollm2-135m",
        training=training,
    )
    return ParsedDlm(frontmatter=frontmatter, sections=(preference_section,))
| 39 |
|
| 40 |
|
| 41 |
def _plan() -> SimpleNamespace:
    """Return a duck-typed stand-in for the plan object `run()` receives.

    Describes a tiny bf16 / SDPA / non-QLoRA configuration with a batch of
    one and no gradient checkpointing. `to_dict()` returns a fixed two-key
    summary rather than serializing the attributes.
    """
    fields: dict[str, Any] = {
        "precision": "bf16",
        "attn_implementation": "sdpa",
        "use_qlora": False,
        "quant_compute_dtype": None,
        "micro_batch_size": 1,
        "grad_accum": 1,
        "effective_batch_size": 1,
        "gradient_checkpointing": False,
        "est_peak_vram_gb": 1.0,
        "est_step_seconds": 0.1,
        "reason": "test",
    }
    return SimpleNamespace(
        **fields,
        to_dict=lambda: {"precision": "bf16", "phase": "dpo"},
    )
| 56 |
|
| 57 |
|
| 58 |
def _mock_factory(**_: Any) -> MagicMock:
    """Stand-in for the DPO trainer factory; every keyword argument is ignored.

    Returns a MagicMock shaped like a DPOTrainer that has just finished:
    `train()` reports a training loss of 0.42 and `save_model(...)` writes a
    minimal adapter (`adapter_config.json` + `adapter_model.safetensors`)
    to disk so downstream commit/manifest steps find real files.
    """
    dpo = MagicMock()
    # Post-training state as plain objects rather than auto-created MagicMock
    # attributes, so whatever reads them (presumably the state-sidecar
    # writer — TODO confirm) gets concrete, serializable values.
    dpo.state = SimpleNamespace(global_step=15, epoch=1.0, best_metric=None)
    dpo.optimizer = SimpleNamespace(state_dict=lambda: {"lr": 5e-6})
    dpo.lr_scheduler = SimpleNamespace(state_dict=lambda: {"step": 15})
    dpo.scaler = None
    # Not asked to stop early; tests assert `result.early_stopped is False`.
    dpo.control = SimpleNamespace(should_training_stop=False)

    dpo.train.return_value = SimpleNamespace(training_loss=0.42)

    def _save_model(output_dir: str, _internal_call: bool = False) -> None:
        # Mirrors `transformers.Trainer.save_model(output_dir, _internal_call)`
        # so both positional and `output_dir=` keyword calls work like the
        # real trainer; `_internal_call` is accepted only for signature parity.
        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)
        (out / "adapter_config.json").write_text("{}")
        (out / "adapter_model.safetensors").write_bytes(b"\x00" * 64)

    dpo.save_model.side_effect = _save_model
    return dpo
| 76 |
|
| 77 |
|
| 78 |
def _seed_prior_sft(  # type: ignore[no-untyped-def]
    store, dlm_id: str = "01DPOTEST", adapter_version: int = 1
) -> None:
    """Prime the store with a plausible post-SFT state.

    `allocate_next_version` picks the next vNNNN by scanning on-disk
    dirs — not by reading the manifest — so we materialize placeholder
    dirs v0001..v{adapter_version}. The manifest entry records the same
    version to keep the schema side consistent.

    Args:
        store: store paths object as returned by `for_dlm(...)`.
        dlm_id: id written into the seeded manifest.
        adapter_version: newest adapter version to pretend already exists.
            Defaults to 1 (just SFT's v0001).

    Raises:
        ValueError: if `adapter_version` is less than 1.
    """
    if adapter_version < 1:
        raise ValueError(f"adapter_version must be >= 1, got {adapter_version}")

    store.ensure_layout()
    save_manifest(
        store.manifest,
        Manifest(
            dlm_id=dlm_id,
            base_model="smollm2-135m",
            adapter_version=adapter_version,
        ),
    )
    for n in range(1, adapter_version + 1):
        version_dir = store.adapter_version(n)
        version_dir.mkdir(parents=True, exist_ok=True)
        (version_dir / "adapter_config.json").write_text("{}")
| 93 |
|
| 94 |
|
| 95 |
class TestRunHappyPath:
    """`run()` against a store primed with a single SFT adapter (v0001)."""

    @staticmethod
    def _run_dpo(tmp_path: Path) -> tuple[Any, Any]:
        """Seed a post-SFT store under `tmp_path` and run the DPO phase.

        Uses the mocked trainer factory with `reference_adapter_version=1`.
        Returns `(store, result)` so tests can inspect either side.
        """
        store = for_dlm("01DPOTEST", home=tmp_path)
        _seed_prior_sft(store)
        result = run(
            store,
            _parsed_with_preferences(),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            reference_adapter_version=1,
            trainer_factory=_mock_factory,
        )
        return store, result

    def test_commits_next_adapter_version(self, tmp_path: Path) -> None:
        _, result = self._run_dpo(tmp_path)

        # DPO writes the next adapter version on top of SFT's v0001.
        assert result.adapter_version == 2
        assert result.adapter_path.name == "v0002"
        assert (result.adapter_path / "adapter_config.json").exists()
        assert (result.adapter_path / STATE_FILENAME).exists()
        assert (result.adapter_path / STATE_SHA_FILENAME).exists()

    def test_manifest_gets_new_training_run_entry(self, tmp_path: Path) -> None:
        from dlm.store.manifest import load_manifest

        store, _ = self._run_dpo(tmp_path)

        manifest = load_manifest(store.manifest)
        assert manifest.adapter_version == 2
        assert len(manifest.training_runs) == 1
        assert manifest.training_runs[0].adapter_version == 2

    def test_result_carries_training_loss_from_mock(self, tmp_path: Path) -> None:
        _, result = self._run_dpo(tmp_path)

        # Float comparison via approx in case run() post-processes the loss.
        assert result.final_train_loss == pytest.approx(0.42)
        # DPO phase doesn't wire eval — val metrics stay None.
        assert result.final_val_loss is None
        assert result.final_val_perplexity is None
        assert result.early_stopped is False

    def test_seed_defaults_to_training_config(self, tmp_path: Path) -> None:
        _, result = self._run_dpo(tmp_path)
        assert result.seed == 42  # matches TrainingConfig(seed=42)
| 171 |
|
| 172 |
|
| 173 |
class TestRunSteps:
    """What `run()` forwards to the trainer factory, and lock persistence."""

    @staticmethod
    def _capturing_factory() -> tuple[dict[str, Any], Any]:
        """Return `(captured, factory)`.

        `factory` records every keyword argument it receives into `captured`
        and then delegates to `_mock_factory`.
        """
        captured: dict[str, Any] = {}

        def _factory(**kwargs: Any) -> MagicMock:
            captured.update(kwargs)
            return _mock_factory(**kwargs)

        return captured, _factory

    def test_factory_receives_reference_adapter_version(self, tmp_path: Path) -> None:
        """The factory call should see the reference_adapter_version
        we passed into `run()`, and the new adapter should land right after
        the highest existing version."""
        captured, factory = self._capturing_factory()

        store = for_dlm("01DPOTEST3", home=tmp_path)
        store.ensure_layout()
        save_manifest(
            store.manifest,
            Manifest(dlm_id="01DPOTEST3", base_model="smollm2-135m", adapter_version=3),
        )
        # Seed adapter version dirs v0001..v0003 so allocate_next picks v0004.
        for n in (1, 2, 3):
            vn = store.adapter_version(n)
            vn.mkdir(parents=True, exist_ok=True)
            (vn / "adapter_config.json").write_text("{}")

        result = run(
            store,
            _parsed_with_preferences(),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            reference_adapter_version=3,
            trainer_factory=factory,
        )
        assert captured["reference_adapter_version"] == 3
        # Pin the allocation the seeding comment above relies on.
        assert result.adapter_version == 4

    def test_factory_receives_include_auto_mined(self, tmp_path: Path) -> None:
        captured, factory = self._capturing_factory()

        store = for_dlm("01DPOTEST4", home=tmp_path)
        _seed_prior_sft(store, dlm_id="01DPOTEST4")

        run(
            store,
            _parsed_with_preferences(),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            reference_adapter_version=1,
            include_auto_mined=False,
            trainer_factory=factory,
        )
        assert captured["include_auto_mined"] is False

    def test_writes_lock_when_decision_requests_it(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        store = for_dlm("01DPOTEST5", home=tmp_path)
        _seed_prior_sft(store, dlm_id="01DPOTEST5")
        # A source_path gives run() somewhere to anchor the lock.
        parsed = replace(_parsed_with_preferences(), source_path=tmp_path / "doc.dlm")
        persist_lock = MagicMock()

        # Force a "write the lock" decision; accept positional or keyword
        # arguments so the stub doesn't dictate how run() calls it.
        monkeypatch.setattr(
            dpo_phase,
            "_validate_or_abort_lock",
            lambda *_args, **_kwargs: SimpleNamespace(should_write_lock=True),
        )
        monkeypatch.setattr(dpo_phase, "_persist_lock", persist_lock)

        run(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            reference_adapter_version=1,
            trainer_factory=_mock_factory,
        )

        persist_lock.assert_called_once()