1 """`run_all` orchestration over named adapters."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest

import dlm.train.gate.orchestrator as gate_orchestrator
import dlm.train.multi_adapter.trainer as multi_adapter_trainer
from dlm.base_models import BASE_MODELS
from dlm.doc.parser import ParsedDlm
from dlm.doc.schema import (
    AdapterConfig,
    DlmFrontmatter,
    GateConfig,
    TrainingConfig,
)
from dlm.doc.sections import Section, SectionType
from dlm.store.manifest import Manifest, save_manifest
from dlm.store.paths import for_dlm
from dlm.train.multi_adapter.trainer import run_all


def _plan() -> SimpleNamespace:
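    """Minimal stand-in for a training/resource plan used throughout these tests."""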
    return SimpleNamespace(
        precision="bf16",
        attn_implementation="sdpa",
        use_qlora=False,
        quant_compute_dtype=None,
        micro_batch_size=1,
        grad_accum=1,
        effective_batch_size=1,
        gradient_checkpointing=False,
        est_peak_vram_gb=1.0,
        est_step_seconds=0.1,
        reason="test",
        to_dict=lambda: {"precision": "bf16"},
    )


def _mock_trainer_factory(**_: Any) -> MagicMock:
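    """Return a MagicMock SFT trainer whose save_model writes fake adapter files."""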
    sft = MagicMock()
    sft.state = SimpleNamespace(global_step=5, epoch=1.0, best_metric=0.9)
    sft.optimizer = SimpleNamespace(state_dict=lambda: {"lr": 1e-4})
    sft.lr_scheduler = SimpleNamespace(state_dict=lambda: {"step": 5})
    sft.scaler = None
    sft.control = SimpleNamespace(should_training_stop=False)
    sft.train.return_value = SimpleNamespace(training_loss=1.0)

    def _save_model(path: str) -> None:
        p = Path(path)
        p.mkdir(parents=True, exist_ok=True)
        (p / "adapter_config.json").write_text("{}")
        (p / "adapter_model.safetensors").write_bytes(b"\x00" * 32)

    sft.save_model.side_effect = _save_model
    return sft


def _multi_adapter_parsed(dlm_id: str, *, gate_enabled: bool = False) -> ParsedDlm:
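    """Doc with two named adapters ("knowledge", "tone") plus shared prose."""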
    return ParsedDlm(
        frontmatter=DlmFrontmatter(
            dlm_id=dlm_id,
            base_model="smollm2-135m",
            training=TrainingConfig(
                seed=42,
                gate=GateConfig(enabled=gate_enabled),
                adapters={
                    "knowledge": AdapterConfig(),
                    "tone": AdapterConfig(lora_r=4),
                },
            ),
        ),
        sections=(
            Section(type=SectionType.PROSE, content="Shared domain prose."),
            Section(
                type=SectionType.INSTRUCTION,
                content="### Q\nfacts?\n### A\nfacts.",
                adapter="knowledge",
            ),
            Section(
                type=SectionType.INSTRUCTION,
                content="### Q\ntone?\n### A\ncrisp.",
                adapter="tone",
            ),
        ),
    )


def _single_adapter_parsed(dlm_id: str) -> ParsedDlm:
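    """Doc with no named adapters, exercising the single-adapter path."""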
    return ParsedDlm(
        frontmatter=DlmFrontmatter(
            dlm_id=dlm_id,
            base_model="smollm2-135m",
            training=TrainingConfig(seed=42),
        ),
        sections=(Section(type=SectionType.PROSE, content="Single-adapter prose."),),
    )


def _seed_store(tmp_path: Path, dlm_id: str) -> Any:
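    """Create the store layout under tmp_path and seed a minimal manifest."""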
    store = for_dlm(dlm_id, home=tmp_path)
    store.ensure_layout()
    save_manifest(store.manifest, Manifest(dlm_id=dlm_id, base_model="smollm2-135m"))
    return store


class TestSingleAdapterPassthrough:
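    """Docs with zero or one named adapter still yield exactly one result."""
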
    def test_single_adapter_doc_yields_one_result(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FA"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _single_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1
        # Flat layout: version dir lives under adapter/versions/, not a named subdir.
        assert store.adapter_version(1).is_dir()

    def test_one_named_adapter_still_passes_through(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FZ"
        store = _seed_store(tmp_path, dlm_id)
        parsed = ParsedDlm(
            frontmatter=DlmFrontmatter(
                dlm_id=dlm_id,
                base_model="smollm2-135m",
                training=TrainingConfig(
                    seed=42,
                    adapters={"knowledge": AdapterConfig()},
                ),
            ),
            sections=(
                Section(type=SectionType.PROSE, content="Shared domain prose."),
                Section(
                    type=SectionType.INSTRUCTION,
                    content="### Q\nfacts?\n### A\nfacts.",
                    adapter="knowledge",
                ),
            ),
        )
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1

    def test_gate_enabled_with_one_named_adapter_still_returns_one_result(
        self, tmp_path: Path
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FY"
        store = _seed_store(tmp_path, dlm_id)
        parsed = ParsedDlm(
            frontmatter=DlmFrontmatter(
                dlm_id=dlm_id,
                base_model="smollm2-135m",
                training=TrainingConfig(
                    seed=42,
                    adapters={"knowledge": AdapterConfig()},
                    gate=GateConfig(enabled=False),
                ),
            ),
            sections=(
                Section(type=SectionType.PROSE, content="Shared domain prose."),
                Section(
                    type=SectionType.INSTRUCTION,
                    content="### Q\nfacts?\n### A\nfacts.",
                    adapter="knowledge",
                ),
            ),
        )
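        # Force-enable the gate after construction; object.__setattr__ bypasses normal
        # attribute assignment on the (presumably immutable) config.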
        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1


class TestMultiAdapterOrchestration:
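    """Each declared adapter gets its own run, version, pointer, and manifest entry."""
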
    def test_trains_each_declared_adapter(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 2

    def test_each_adapter_gets_own_version_dir(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert store.adapter_version_for("knowledge", 1).is_dir()
        assert store.adapter_version_for("tone", 1).is_dir()

    def test_each_adapter_gets_own_current_pointer(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert store.resolve_current_adapter_for("knowledge") == (
            store.adapter_version_for("knowledge", 1).resolve()
        )
        assert store.resolve_current_adapter_for("tone") == (
            store.adapter_version_for("tone", 1).resolve()
        )

    def test_manifest_gets_one_run_per_adapter(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        assert len(manifest.training_runs) == 2

    def test_declaration_order_preserved(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        # Knowledge is declared first; its run_id should be lower.
        assert results[0].run_id < results[1].run_id

    def test_manifest_adapter_versions_populated(self, tmp_path: Path) -> None:
272 """Audit-07 M1: multi-adapter runs bump per-adapter version dict,
273 not the flat `adapter_version`."""
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        assert manifest.adapter_versions == {"knowledge": 1, "tone": 1}
        # Flat field stays at 0 (untouched) for multi-adapter stores.
        assert manifest.adapter_version == 0

    def test_training_run_summaries_carry_adapter_name(self, tmp_path: Path) -> None:
292 """Audit-07 M1: each TrainingRunSummary is tagged with the name."""
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        names = [r.adapter_name for r in manifest.training_runs]
        assert sorted(names, key=str) == ["knowledge", "tone"]


class TestGatePass:
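    """The post-SFT gate pass runs when enabled; failures are logged, not raised."""
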
    def test_enabled_gate_runs_post_sft_pass(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GC"
        store = _seed_store(tmp_path, dlm_id)
        parsed = _multi_adapter_parsed(dlm_id, gate_enabled=True)
        seen: dict[str, object] = {}

        def _fake_run_post_sft_gate(
            store_arg: object,
            parsed_arg: object,
            *,
            run_id: int,
            recorder: object,
            embed: object,
            input_dim: int,
            seed: int | None,
        ) -> None:
            seen.update(
                {
                    "store": store_arg,
                    "parsed": parsed_arg,
                    "run_id": run_id,
                    "recorder": recorder,
                    "embed": embed,
                    "input_dim": input_dim,
                    "seed": seed,
                }
            )

        def _embed(prompt: str) -> str:
            return prompt.upper()

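        # Swap the real gate pass for a recorder so the forwarded arguments can be checked.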
        monkeypatch.setattr(gate_orchestrator, "run_post_sft_gate", _fake_run_post_sft_gate)
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=lambda: (_embed, 7),
        )
        assert seen["store"] == store
        assert isinstance(seen["recorder"], object)
        assert seen["parsed"] == parsed
        assert seen["run_id"] == results[-1].run_id
        assert callable(seen["embed"])
        assert seen["input_dim"] == 7
        assert seen["seed"] is None

    def test_gate_embedder_failure_is_logged(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GD"
        store = _seed_store(tmp_path, dlm_id)
        caplog.set_level("WARNING")

        def _boom_factory() -> tuple[object, int]:
            raise RuntimeError("boom")

        run_all(
            store,
            _multi_adapter_parsed(dlm_id, gate_enabled=True),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=_boom_factory,
        )
        assert "gate: embedder setup failed" in caplog.text

    def test_gate_training_failure_is_logged(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GE"
        store = _seed_store(tmp_path, dlm_id)
        caplog.set_level("WARNING")

        def _raising_gate(*args: object, **kwargs: object) -> None:
            raise RuntimeError("gate boom")

        def _embed(prompt: str) -> str:
            return prompt

        monkeypatch.setattr(gate_orchestrator, "run_post_sft_gate", _raising_gate)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id, gate_enabled=True),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=lambda: (_embed, 4),
        )
        assert "gate: post-SFT pass failed" in caplog.text


class TestResolveGateEmbedder:
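    """Prefer an explicit factory; otherwise fall back to the default embedder."""
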
    def test_factory_path_is_used(self) -> None:
        def _embed(prompt: str) -> str:
            return prompt

        resolved, input_dim = multi_adapter_trainer._resolve_gate_embedder(
            BASE_MODELS["smollm2-135m"],
            _plan(),
            lambda: (_embed, 9),
        )
        assert resolved is _embed
        assert input_dim == 9

    def test_default_embedder_path_is_used_when_factory_missing(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        def _embed(prompt: str) -> str:
            return prompt

        def _fake_default(spec: object, plan: object) -> tuple[object, int]:
            return _embed, 11

        monkeypatch.setattr(multi_adapter_trainer, "_default_embedder", _fake_default)
        resolved, input_dim = multi_adapter_trainer._resolve_gate_embedder(
            BASE_MODELS["smollm2-135m"],
            _plan(),
            None,
        )
        assert resolved is _embed
        assert input_dim == 11