Python · 14628 bytes Raw Blame History
1 """Phase-orchestrator dispatcher tests.
2
3 Uses mock SFT/DPO runners so no HF/TRL imports happen at test time.
4 The heavy DPO path is covered by the slow integration suite.
5 """
6
7 from __future__ import annotations
8
9 import logging
10 from dataclasses import dataclass
11 from pathlib import Path
12 from typing import Any
13 from unittest.mock import MagicMock
14
15 import pytest
16
17 import dlm.train.preference.phase_orchestrator as phase_orchestrator
18 from dlm.doc.schema import PreferenceConfig
19 from dlm.doc.sections import Section, SectionType
20 from dlm.train.preference.errors import (
21 NoPreferenceContentError,
22 PriorAdapterRequiredError,
23 )
24 from dlm.train.preference.phase_orchestrator import (
25 PhaseResult,
26 has_preference_content,
27 has_sft_content,
28 run_phases,
29 )
30
31 # ---- helpers ---------------------------------------------------------------
32
33
def _prose(body: str = "some prose content") -> Section:
    """Build a PROSE section whose content is *body*, anchored at line 1."""
    section = Section(type=SectionType.PROSE, content=body, start_line=1)
    return section
36
37
def _instruction() -> Section:
    """Build an INSTRUCTION section holding a single Q/A exchange."""
    qa_text = "### Q\nhi\n### A\nhello\n"
    return Section(type=SectionType.INSTRUCTION, content=qa_text, start_line=1)
44
45
def _pref() -> Section:
    """Build a PREFERENCE section with one prompt/chosen/rejected triple."""
    triple_text = "### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n"
    return Section(type=SectionType.PREFERENCE, content=triple_text, start_line=1)
52
53
def _mined_pref() -> Section:
    """Build a PREFERENCE section flagged as auto-mined.

    Same prompt/chosen/rejected body as `_pref`, but with `auto_mined=True`
    and the accompanying judge name, judge scores, mining timestamp and
    mining run id populated.
    """
    triple_text = "### Prompt\nq\n### Chosen\nc\n### Rejected\nr\n"
    return Section(
        type=SectionType.PREFERENCE,
        content=triple_text,
        start_line=1,
        # Mining provenance fields.
        auto_mined=True,
        judge_name="sway:preference_judge",
        judge_score_chosen=0.9,
        judge_score_rejected=0.1,
        mined_at="2026-04-23T20:00:00Z",
        mined_run_id=7,
    )
66
67
@dataclass
class _FakeTraining:
    """Minimal stand-in for the `training` node of a document's frontmatter."""

    # Preference-phase configuration exposed as `.preference`.
    preference: PreferenceConfig
71
72
@dataclass
class _FakeFrontmatter:
    """Minimal stand-in for parsed frontmatter; only carries `training`."""

    # Fake training node holding the preference config.
    training: _FakeTraining
76
77
@dataclass
class _FakeParsed:
    """Minimal stand-in for a parsed document: sections plus frontmatter."""

    # Document sections, stored as an immutable tuple.
    sections: tuple[Section, ...]
    # Fake frontmatter wrapping the training/preference config.
    frontmatter: _FakeFrontmatter
82
83
84 @dataclass
85 class _FakeRunResult:
86 adapter_version: int
87
88
def _parsed(
    sections: list[Section],
    *,
    dpo_enabled: bool | None = None,
) -> Any:
    """Assemble a fake ParsedDlm around *sections*.

    With `dpo_enabled=None` the `PreferenceConfig` is built without an
    `enabled` argument, so `resolve_preference_enabled` treats the field
    as "user didn't specify" and auto-enables when preference content is
    present.

    With `dpo_enabled=True/False` the value is passed explicitly, which
    simulates a user writing `training.preference.enabled: true/false`
    in their frontmatter.
    """
    if dpo_enabled is None:
        preference_cfg = PreferenceConfig()
    else:
        preference_cfg = PreferenceConfig(enabled=dpo_enabled)

    training = _FakeTraining(preference=preference_cfg)
    frontmatter = _FakeFrontmatter(training=training)
    return _FakeParsed(sections=tuple(sections), frontmatter=frontmatter)
109
110
111 # ---- content detection ----------------------------------------------------
112
113
class TestHasSftContent:
    """`has_sft_content` is true only for non-blank prose or instruction sections."""

    def test_prose_with_body_counts(self) -> None:
        sections = [_prose("hello")]
        assert has_sft_content(sections) is True

    def test_empty_prose_does_not_count(self) -> None:
        # Whitespace-only prose must not count as SFT content.
        sections = [_prose(" \n ")]
        assert has_sft_content(sections) is False

    def test_instruction_counts(self) -> None:
        sections = [_instruction()]
        assert has_sft_content(sections) is True

    def test_preference_only_has_no_sft(self) -> None:
        sections = [_pref()]
        assert has_sft_content(sections) is False

    def test_empty_sections(self) -> None:
        no_sections: list[Section] = []
        assert has_sft_content(no_sections) is False
129
130
class TestHasPreferenceContent:
    """`has_preference_content` is true only when a preference section exists."""

    def test_preference_present(self) -> None:
        sections = [_pref()]
        assert has_preference_content(sections) is True

    def test_no_preference_section(self) -> None:
        sections = [_prose(), _instruction()]
        assert has_preference_content(sections) is False

    def test_empty_sections(self) -> None:
        no_sections: list[Section] = []
        assert has_preference_content(no_sections) is False
140
141
142 # ---- dispatcher ----------------------------------------------------------
143
144
class TestDispatcherSftOnly:
    """`run_phases(phase="sft")`: the DPO runner is never invoked."""

    def test_sft_phase_runs_sft_only_even_with_preference(self) -> None:
        """Preference content with `enabled=True` still yields one SFT result."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock(return_value=_FakeRunResult(adapter_version=2))
        doc = _parsed([_prose(), _pref()], dpo_enabled=True)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="sft",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        assert len(outcome) == 1
        assert outcome[0].phase == "sft"
        sft_mock.assert_called_once()
        dpo_mock.assert_not_called()

    def test_sft_phase_forwards_strict_metrics_to_sft_runner(self) -> None:
        """`strict_metrics=True` reaches the SFT runner as a keyword argument."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        doc = _parsed([_prose()])

        run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="sft",
            strict_metrics=True,
            sft_runner=sft_mock,
            dpo_runner=MagicMock(),
        )

        forwarded = sft_mock.call_args.kwargs
        assert forwarded["strict_metrics"] is True

    def test_sft_phase_forwards_world_size_to_sft_runner(self) -> None:
        """`world_size=4` reaches the SFT runner as a keyword argument."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        doc = _parsed([_prose()])

        run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="sft",
            world_size=4,
            sft_runner=sft_mock,
            dpo_runner=MagicMock(),
        )

        forwarded = sft_mock.call_args.kwargs
        assert forwarded["world_size"] == 4

    def test_sft_phase_skips_when_no_sft_content(self) -> None:
        """A preference-only document produces no results and calls no runner."""
        sft_mock = MagicMock()
        dpo_mock = MagicMock()
        doc = _parsed([_pref()])

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="sft",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        assert outcome == []
        sft_mock.assert_not_called()
        dpo_mock.assert_not_called()
212
213
class TestDispatcherAllPhase:
    """`run_phases(phase="all")`: SFT first, then the preference phase if eligible."""

    def test_runs_sft_then_dpo_when_both_enabled(self) -> None:
        """Both runners fire in order, and DPO receives the SFT adapter version."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=3))
        dpo_mock = MagicMock(return_value=_FakeRunResult(adapter_version=4))
        doc = _parsed([_prose(), _pref()], dpo_enabled=True)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft", "preference"]
        sft_mock.assert_called_once()
        dpo_mock.assert_called_once()
        # DPO is told which adapter version to use as reference.
        forwarded = dpo_mock.call_args.kwargs
        assert forwarded["reference_adapter_version"] == 3

    def test_forwards_include_auto_mined_to_preference_runner(self) -> None:
        """`include_auto_mined=False` is passed through to the DPO runner."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=3))
        dpo_mock = MagicMock(return_value=_FakeRunResult(adapter_version=4))
        doc = _parsed([_prose(), _pref()], dpo_enabled=True)

        run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            include_auto_mined=False,
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        forwarded = dpo_mock.call_args.kwargs
        assert forwarded["include_auto_mined"] is False

    def test_skips_dpo_when_disabled(self) -> None:
        """An explicit `enabled=False` leaves only the SFT phase."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock()
        doc = _parsed([_prose(), _pref()], dpo_enabled=False)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft"]
        dpo_mock.assert_not_called()

    def test_all_phase_auto_warns_when_enabled_but_no_preference_content(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """Enabled but without preference sections: SFT only, plus a warning."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock()
        doc = _parsed([_prose()], dpo_enabled=True)

        with caplog.at_level(logging.WARNING):
            outcome = run_phases(
                store=MagicMock(),
                parsed=doc,
                spec=MagicMock(),
                plan=MagicMock(),
                phase="all",
                sft_runner=sft_mock,
                dpo_runner=dpo_mock,
            )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft"]
        dpo_mock.assert_not_called()
        assert any("no ::preference::" in record.message for record in caplog.records)

    def test_no_mined_treats_mined_only_doc_as_no_preference(
        self,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        """With `include_auto_mined=False`, a mined-only doc counts as having none."""
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock()
        doc = _parsed([_prose(), _mined_pref()], dpo_enabled=True)

        with caplog.at_level(logging.WARNING):
            outcome = run_phases(
                store=MagicMock(),
                parsed=doc,
                spec=MagicMock(),
                plan=MagicMock(),
                phase="all",
                include_auto_mined=False,
                sft_runner=sft_mock,
                dpo_runner=dpo_mock,
            )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft"]
        dpo_mock.assert_not_called()
        assert any("no ::preference::" in record.message for record in caplog.records)
306
307
class TestDispatcherDpoOnly:
    """`run_phases(phase="preference")`: DPO alone, referencing a prior adapter."""

    @staticmethod
    def _store_with_manifest(tmp_path: Path, *, adapter_version: int) -> MagicMock:
        """Return a mock store whose `manifest` path points at a saved manifest.

        The manifest is written through the real `save_manifest` path so
        `_resolve_reference_adapter_version` has a file to read and doesn't
        raise. Only `adapter_version` varies between tests; the id and base
        model are fixed placeholders.
        """
        # Kept function-local, as in the tests this helper replaces.
        from dlm.store.manifest import Manifest, save_manifest

        store = MagicMock()
        store.manifest = tmp_path / "manifest.json"
        save_manifest(
            store.manifest,
            Manifest(
                dlm_id="01HZ4X7TGZM3J1A2B3C4D5E6F7",
                base_model="smollm2-135m",
                adapter_version=adapter_version,
            ),
        )
        return store

    def test_dpo_only_reads_adapter_from_manifest(self, tmp_path: Path) -> None:
        """The manifest's adapter version is forwarded as the DPO reference."""
        store = self._store_with_manifest(tmp_path, adapter_version=4)
        sft = MagicMock()
        dpo = MagicMock(return_value=_FakeRunResult(adapter_version=5))

        results = run_phases(
            store=store,
            parsed=_parsed([_pref()]),
            spec=MagicMock(),
            plan=MagicMock(),
            phase="preference",
            sft_runner=sft,
            dpo_runner=dpo,
        )

        assert [r.phase for r in results] == ["preference"]
        sft.assert_not_called()
        _, dpo_kwargs = dpo.call_args
        assert dpo_kwargs["reference_adapter_version"] == 4

    def test_dpo_only_raises_on_missing_preference(self, tmp_path: Path) -> None:
        """A document without preference sections raises NoPreferenceContentError."""
        store = self._store_with_manifest(tmp_path, adapter_version=2)

        with pytest.raises(NoPreferenceContentError):
            run_phases(
                store=store,
                parsed=_parsed([_prose()]),
                spec=MagicMock(),
                plan=MagicMock(),
                phase="preference",
                sft_runner=MagicMock(),
                dpo_runner=MagicMock(),
            )

    def test_dpo_only_raises_without_prior_adapter(self, tmp_path: Path) -> None:
        """Adapter version 0 (no SFT ever run) raises PriorAdapterRequiredError."""
        store = self._store_with_manifest(tmp_path, adapter_version=0)

        with pytest.raises(PriorAdapterRequiredError):
            run_phases(
                store=store,
                parsed=_parsed([_pref()]),
                spec=MagicMock(),
                plan=MagicMock(),
                phase="preference",
                sft_runner=MagicMock(),
                dpo_runner=MagicMock(),
            )
388
389
class TestPhaseResult:
    """`PhaseResult` instances must reject attribute assignment."""

    def test_phase_result_is_frozen(self) -> None:
        from dataclasses import FrozenInstanceError

        frozen = PhaseResult(phase="sft", result=_FakeRunResult(adapter_version=1))
        with pytest.raises(FrozenInstanceError):
            frozen.phase = "dpo"  # type: ignore[misc]
397
398
class TestMethodRunner:
    """`_method_runner` returns whatever the method registry resolves."""

    def test_method_runner_uses_registry_resolver(self, monkeypatch: pytest.MonkeyPatch) -> None:
        sentinel_runner = MagicMock()

        def _fake_resolve(method: str) -> Any:
            # Only "orpo" maps to the sentinel; every other name resolves to None.
            return sentinel_runner if method == "orpo" else None

        monkeypatch.setattr("dlm.train.preference.method_registry.resolve", _fake_resolve)

        assert phase_orchestrator._method_runner("orpo") is sentinel_runner
407
408
class TestAutoEnableIntegration:
    """Auto-enable under `--phase all`.

    When the user didn't set `enabled` and preference content is present,
    DPO runs; an explicit `False` or missing content keeps it off.
    """

    def test_unset_enabled_with_preferences_auto_runs_dpo(self) -> None:
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock(return_value=_FakeRunResult(adapter_version=2))
        doc = _parsed([_prose(), _pref()], dpo_enabled=None)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft", "preference"]

    def test_explicit_false_blocks_auto_enable(self) -> None:
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock()
        doc = _parsed([_prose(), _pref()], dpo_enabled=False)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft"]
        dpo_mock.assert_not_called()

    def test_unset_enabled_with_no_preferences_stays_off(self) -> None:
        sft_mock = MagicMock(return_value=_FakeRunResult(adapter_version=1))
        dpo_mock = MagicMock()
        doc = _parsed([_prose()], dpo_enabled=None)

        outcome = run_phases(
            store=MagicMock(),
            parsed=doc,
            spec=MagicMock(),
            plan=MagicMock(),
            phase="all",
            sft_runner=sft_mock,
            dpo_runner=dpo_mock,
        )

        phases = [entry.phase for entry in outcome]
        assert phases == ["sft"]
        dpo_mock.assert_not_called()
455 dpo.assert_not_called()