"""`dpo_phase.run()` end-to-end with a mocked DPOTrainer.

Mirrors `test_trainer.py`'s factory-seam pattern: we pass a MagicMock
factory so `run()` exercises preflight → lock → log → commit →
manifest → state-sidecar without importing HF/TRL or torch.
"""
7
8 from __future__ import annotations
9
10 from dataclasses import replace
11 from pathlib import Path
12 from types import SimpleNamespace
13 from typing import Any
14 from unittest.mock import MagicMock
15
16 import pytest
17
18 import dlm.train.preference.dpo_phase as dpo_phase
19 from dlm.base_models import BASE_MODELS
20 from dlm.doc.parser import ParsedDlm
21 from dlm.doc.schema import DlmFrontmatter, PreferenceConfig, TrainingConfig
22 from dlm.doc.sections import Section, SectionType
23 from dlm.store.manifest import Manifest, save_manifest
24 from dlm.store.paths import for_dlm
25 from dlm.train.preference.dpo_phase import run
26 from dlm.train.state_sidecar import STATE_FILENAME, STATE_SHA_FILENAME
27
28
def _parsed_with_preferences() -> ParsedDlm:
    """Build a minimal parsed document carrying one preference section.

    The frontmatter targets the `smollm2-135m` base model with a fixed
    training seed of 42 and the preference config enabled; the document
    body is a single prompt / chosen / rejected section.
    """
    training = TrainingConfig(seed=42, preference=PreferenceConfig(enabled=True))
    frontmatter = DlmFrontmatter(
        dlm_id="01KABCD" + "0" * 19,
        base_model="smollm2-135m",
        training=training,
    )
    preference_section = Section(
        type=SectionType.PREFERENCE,
        content="### Prompt\nq?\n### Chosen\nc.\n### Rejected\nr.\n",
    )
    return ParsedDlm(frontmatter=frontmatter, sections=(preference_section,))
39
40
def _plan() -> SimpleNamespace:
    """Return a stand-in training plan made of plain attributes.

    The namespace exposes fixed values for each plan field plus a
    `to_dict()` callable that returns a small, fresh dict on every call.
    """

    def _to_dict() -> dict:
        # Rebuilt per call so callers never share one mutable dict.
        return {"precision": "bf16", "phase": "dpo"}

    # Insertion order is kept so the namespace's attribute order is stable.
    attributes = {
        "precision": "bf16",
        "attn_implementation": "sdpa",
        "use_qlora": False,
        "quant_compute_dtype": None,
        "micro_batch_size": 1,
        "grad_accum": 1,
        "effective_batch_size": 1,
        "gradient_checkpointing": False,
        "est_peak_vram_gb": 1.0,
        "est_step_seconds": 0.1,
        "reason": "test",
        "to_dict": _to_dict,
    }
    return SimpleNamespace(**attributes)
56
57
def _mock_factory(**_: Any) -> MagicMock:
    """Create a MagicMock that stands in for a DPO trainer.

    Any keyword arguments are accepted and ignored. The returned mock
    carries fixed `state`, `optimizer`, `lr_scheduler`, `scaler` and
    `control` attributes, `train()` returns an object whose
    `training_loss` is 0.42, and `save_model(path)` writes two
    placeholder adapter files into `path`.
    """

    def _write_adapter_files(path: str) -> None:
        # Create the target dir (and parents) and drop placeholder artifacts.
        target = Path(path)
        target.mkdir(parents=True, exist_ok=True)
        (target / "adapter_config.json").write_text("{}")
        (target / "adapter_model.safetensors").write_bytes(b"\x00" * 64)

    trainer = MagicMock()

    # Plain attributes set on the mock.
    trainer.state = SimpleNamespace(global_step=15, epoch=1.0, best_metric=None)
    trainer.optimizer = SimpleNamespace(state_dict=lambda: {"lr": 5e-6})
    trainer.lr_scheduler = SimpleNamespace(state_dict=lambda: {"step": 15})
    trainer.scaler = None
    trainer.control = SimpleNamespace(should_training_stop=False)

    # Behaviour of the two methods configured on the mock.
    trainer.train.return_value = SimpleNamespace(training_loss=0.42)
    trainer.save_model.side_effect = _write_adapter_files
    return trainer
76
77
def _seed_prior_sft(store, dlm_id: str = "01DPOTEST") -> None:  # type: ignore[no-untyped-def]
    """Put the store into a plausible "SFT already ran" state.

    Two things are written: a manifest at adapter version 1, and a
    v0001 adapter directory holding a placeholder config. The directory
    is needed because `allocate_next_version` picks the next vNNNN by
    scanning on-disk dirs rather than by reading the manifest; the
    manifest entry keeps the schema side consistent with it.
    """
    store.ensure_layout()

    manifest_path = store.manifest
    sft_manifest = Manifest(
        dlm_id=dlm_id,
        base_model="smollm2-135m",
        adapter_version=1,
    )
    save_manifest(manifest_path, sft_manifest)

    sft_adapter_dir = store.adapter_version(1)
    sft_adapter_dir.mkdir(parents=True, exist_ok=True)
    config_file = sft_adapter_dir / "adapter_config.json"
    config_file.write_text("{}")
93
94
class TestRunHappyPath:
    """`run()` against a store whose only prior adapter is the seeded v0001."""

    @staticmethod
    def _run_on_seeded_store(home: Path) -> tuple[Any, Any]:
        """Seed a post-SFT store under `home` and run the DPO phase once.

        Returns the `(store, result)` pair so a test can inspect either
        the value returned by `run()` or what was written to the store.
        """
        seeded_store = for_dlm("01DPOTEST", home=home)
        _seed_prior_sft(seeded_store)

        base_spec = BASE_MODELS["smollm2-135m"]
        outcome = run(
            seeded_store,
            _parsed_with_preferences(),
            base_spec,
            _plan(),
            reference_adapter_version=1,
            trainer_factory=_mock_factory,
        )
        return seeded_store, outcome

    def test_commits_next_adapter_version(self, tmp_path: Path) -> None:
        _, outcome = self._run_on_seeded_store(tmp_path)

        # v0001 is taken by the seeded SFT adapter, so DPO lands in v0002.
        assert outcome.adapter_version == 2
        assert outcome.adapter_path.name == "v0002"
        assert (outcome.adapter_path / "adapter_config.json").exists()
        assert (outcome.adapter_path / STATE_FILENAME).exists()
        assert (outcome.adapter_path / STATE_SHA_FILENAME).exists()

    def test_manifest_gets_new_training_run_entry(self, tmp_path: Path) -> None:
        seeded_store, _ = self._run_on_seeded_store(tmp_path)

        from dlm.store.manifest import load_manifest

        # Re-read from disk: the assertion is about what was persisted.
        persisted = load_manifest(seeded_store.manifest)
        assert persisted.adapter_version == 2
        assert len(persisted.training_runs) == 1
        assert persisted.training_runs[0].adapter_version == 2

    def test_result_carries_training_loss_from_mock(self, tmp_path: Path) -> None:
        _, outcome = self._run_on_seeded_store(tmp_path)

        # 0.42 is the `training_loss` the mocked `train()` reports.
        assert outcome.final_train_loss == 0.42
        # The DPO phase doesn't wire eval, so validation metrics stay None.
        assert outcome.final_val_loss is None
        assert outcome.final_val_perplexity is None
        assert outcome.early_stopped is False

    def test_seed_defaults_to_training_config(self, tmp_path: Path) -> None:
        _, outcome = self._run_on_seeded_store(tmp_path)

        # No seed is passed to `run()`; 42 comes from TrainingConfig(seed=42).
        assert outcome.seed == 42
171
172
class TestRunSteps:
    """Checks on what `run()` hands to the factory and to the lock helpers."""

    @staticmethod
    def _recording_factory() -> tuple[dict[str, Any], Any]:
        """Return `(seen, factory)` where `factory` records its kwargs.

        `factory` merges every keyword argument it is called with into
        `seen`, then delegates to `_mock_factory` with the same kwargs.
        """
        seen: dict[str, Any] = {}

        def _factory(**kwargs: Any) -> MagicMock:
            seen.update(kwargs)
            return _mock_factory(**kwargs)

        return seen, _factory

    def test_factory_receives_reference_adapter_version(self, tmp_path: Path) -> None:
        """The `reference_adapter_version` given to `run()` must reach
        the trainer factory unchanged."""
        seen, factory = self._recording_factory()

        store = for_dlm("01DPOTEST3", home=tmp_path)
        store.ensure_layout()
        manifest_at_v3 = Manifest(
            dlm_id="01DPOTEST3",
            base_model="smollm2-135m",
            adapter_version=3,
        )
        save_manifest(store.manifest, manifest_at_v3)

        # Materialize v0001..v0003 on disk so the next allocation is v0004.
        for version in range(1, 4):
            version_dir = store.adapter_version(version)
            version_dir.mkdir(parents=True, exist_ok=True)
            (version_dir / "adapter_config.json").write_text("{}")

        base_spec = BASE_MODELS["smollm2-135m"]
        run(
            store,
            _parsed_with_preferences(),
            base_spec,
            _plan(),
            reference_adapter_version=3,
            trainer_factory=factory,
        )

        assert seen["reference_adapter_version"] == 3

    def test_factory_receives_include_auto_mined(self, tmp_path: Path) -> None:
        seen, factory = self._recording_factory()

        store = for_dlm("01DPOTEST4", home=tmp_path)
        _seed_prior_sft(store, dlm_id="01DPOTEST4")

        base_spec = BASE_MODELS["smollm2-135m"]
        run(
            store,
            _parsed_with_preferences(),
            base_spec,
            _plan(),
            reference_adapter_version=1,
            include_auto_mined=False,
            trainer_factory=factory,
        )

        # `is False` pins the exact value passed, not merely a falsy one.
        assert seen["include_auto_mined"] is False

    def test_writes_lock_when_decision_requests_it(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        store = for_dlm("01DPOTEST5", home=tmp_path)
        _seed_prior_sft(store, dlm_id="01DPOTEST5")
        parsed_with_source = replace(
            _parsed_with_preferences(),
            source_path=tmp_path / "doc.dlm",
        )
        lock_writer = MagicMock()

        def _decision_requesting_lock(**_kwargs: Any) -> SimpleNamespace:
            # Stubbed validation result: always asks for the lock to be written.
            return SimpleNamespace(should_write_lock=True)

        monkeypatch.setattr(dpo_phase, "_validate_or_abort_lock", _decision_requesting_lock)
        monkeypatch.setattr(dpo_phase, "_persist_lock", lock_writer)

        run(
            store,
            parsed_with_source,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            reference_adapter_version=1,
            trainer_factory=_mock_factory,
        )

        lock_writer.assert_called_once()