"""`run_all` orchestration over named adapters."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest

import dlm.train.gate.orchestrator as gate_orchestrator
import dlm.train.multi_adapter.trainer as multi_adapter_trainer
from dlm.base_models import BASE_MODELS
from dlm.doc.parser import ParsedDlm
from dlm.doc.schema import (
    AdapterConfig,
    DlmFrontmatter,
    GateConfig,
    TrainingConfig,
)
from dlm.doc.sections import Section, SectionType
from dlm.store.manifest import Manifest, save_manifest
from dlm.store.paths import for_dlm
from dlm.train.multi_adapter.trainer import run_all


def _plan() -> SimpleNamespace:
    return SimpleNamespace(
        precision="bf16",
        attn_implementation="sdpa",
        use_qlora=False,
        quant_compute_dtype=None,
        micro_batch_size=1,
        grad_accum=1,
        effective_batch_size=1,
        gradient_checkpointing=False,
        est_peak_vram_gb=1.0,
        est_step_seconds=0.1,
        reason="test",
        to_dict=lambda: {"precision": "bf16"},
    )


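# Stand-in for the SFT trainer: exposes the state/optimizer/scheduler/control
# attributes the orchestrator expects, and save_model writes a minimal adapter
# artifact so run_all's store bookkeeping has real files to work with.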
def _mock_trainer_factory(**_: Any) -> MagicMock:
    sft = MagicMock()
    sft.state = SimpleNamespace(global_step=5, epoch=1.0, best_metric=0.9)
    sft.optimizer = SimpleNamespace(state_dict=lambda: {"lr": 1e-4})
    sft.lr_scheduler = SimpleNamespace(state_dict=lambda: {"step": 5})
    sft.scaler = None
    sft.control = SimpleNamespace(should_training_stop=False)
    sft.train.return_value = SimpleNamespace(training_loss=1.0)

    def _save_model(path: str) -> None:
        p = Path(path)
        p.mkdir(parents=True, exist_ok=True)
        (p / "adapter_config.json").write_text("{}")
        (p / "adapter_model.safetensors").write_bytes(b"\x00" * 32)

    sft.save_model.side_effect = _save_model
    return sft


def _multi_adapter_parsed(dlm_id: str, *, gate_enabled: bool = False) -> ParsedDlm:
    return ParsedDlm(
        frontmatter=DlmFrontmatter(
            dlm_id=dlm_id,
            base_model="smollm2-135m",
            training=TrainingConfig(
                seed=42,
                gate=GateConfig(enabled=gate_enabled),
                adapters={
                    "knowledge": AdapterConfig(),
                    "tone": AdapterConfig(lora_r=4),
                },
            ),
        ),
        sections=(
            Section(type=SectionType.PROSE, content="Shared domain prose."),
            Section(
                type=SectionType.INSTRUCTION,
                content="### Q\nfacts?\n### A\nfacts.",
                adapter="knowledge",
            ),
            Section(
                type=SectionType.INSTRUCTION,
                content="### Q\ntone?\n### A\ncrisp.",
                adapter="tone",
            ),
        ),
    )


def _single_adapter_parsed(dlm_id: str) -> ParsedDlm:
    return ParsedDlm(
        frontmatter=DlmFrontmatter(
            dlm_id=dlm_id,
            base_model="smollm2-135m",
            training=TrainingConfig(seed=42),
        ),
        sections=(Section(type=SectionType.PROSE, content="Single-adapter prose."),),
    )


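# Create the on-disk store layout and seed it with a minimal manifest.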
def _seed_store(tmp_path: Path, dlm_id: str) -> Any:
    store = for_dlm(dlm_id, home=tmp_path)
    store.ensure_layout()
    save_manifest(store.manifest, Manifest(dlm_id=dlm_id, base_model="smollm2-135m"))
    return store


class TestSingleAdapterPassthrough:
    def test_single_adapter_doc_yields_one_result(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FA"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _single_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1
        # Flat layout: version dir lives under adapter/versions/, not a named subdir.
        assert store.adapter_version(1).is_dir()

    def test_one_named_adapter_still_passthroughs(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FZ"
        store = _seed_store(tmp_path, dlm_id)
        parsed = ParsedDlm(
            frontmatter=DlmFrontmatter(
                dlm_id=dlm_id,
                base_model="smollm2-135m",
                training=TrainingConfig(
                    seed=42,
                    adapters={"knowledge": AdapterConfig()},
                ),
            ),
            sections=(
                Section(type=SectionType.PROSE, content="Shared domain prose."),
                Section(
                    type=SectionType.INSTRUCTION,
                    content="### Q\nfacts?\n### A\nfacts.",
                    adapter="knowledge",
                ),
            ),
        )
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1

    def test_gate_enabled_with_one_named_adapter_still_returns_one_result(
        self, tmp_path: Path
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FY"
        store = _seed_store(tmp_path, dlm_id)
        parsed = ParsedDlm(
            frontmatter=DlmFrontmatter(
                dlm_id=dlm_id,
                base_model="smollm2-135m",
                training=TrainingConfig(
                    seed=42,
                    adapters={"knowledge": AdapterConfig()},
                    gate=GateConfig(enabled=False),
                ),
            ),
            sections=(
                Section(type=SectionType.PROSE, content="Shared domain prose."),
                Section(
                    type=SectionType.INSTRUCTION,
                    content="### Q\nfacts?\n### A\nfacts.",
                    adapter="knowledge",
                ),
            ),
        )
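        # Flip the gate to enabled after construction (object.__setattr__ works even if the config is frozen).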
        object.__setattr__(parsed.frontmatter.training.gate, "enabled", True)
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 1


class TestMultiAdapterOrchestration:
    def test_trains_each_declared_adapter(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert len(results) == 2

    def test_each_adapter_gets_own_version_dir(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert store.adapter_version_for("knowledge", 1).is_dir()
        assert store.adapter_version_for("tone", 1).is_dir()

    def test_each_adapter_gets_own_current_pointer(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        assert store.resolve_current_adapter_for("knowledge") == (
            store.adapter_version_for("knowledge", 1).resolve()
        )
        assert store.resolve_current_adapter_for("tone") == (
            store.adapter_version_for("tone", 1).resolve()
        )

    def test_manifest_gets_one_run_per_adapter(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        assert len(manifest.training_runs) == 2

    def test_declaration_order_preserved(self, tmp_path: Path) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        results = run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        # Knowledge is declared first; its run_id should be lower.
        assert results[0].run_id < results[1].run_id

    def test_manifest_adapter_versions_populated(self, tmp_path: Path) -> None:
        """Audit-07 M1: multi-adapter runs bump the per-adapter version dict,
        not the flat `adapter_version`."""
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        assert manifest.adapter_versions == {"knowledge": 1, "tone": 1}
        # Flat field stays at 0 (untouched) for multi-adapter stores.
        assert manifest.adapter_version == 0

    def test_training_run_summaries_carry_adapter_name(self, tmp_path: Path) -> None:
        """Audit-07 M1: each TrainingRunSummary is tagged with its adapter name."""
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6FB"
        store = _seed_store(tmp_path, dlm_id)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
        )
        from dlm.store.manifest import load_manifest

        manifest = load_manifest(store.manifest)
        names = [r.adapter_name for r in manifest.training_runs]
        assert sorted(names, key=str) == ["knowledge", "tone"]


class TestGatePass:
    def test_enabled_gate_runs_post_sft_pass(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GC"
        store = _seed_store(tmp_path, dlm_id)
        parsed = _multi_adapter_parsed(dlm_id, gate_enabled=True)
        seen: dict[str, object] = {}

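        # The fake gate pass records the keyword arguments run_all forwards to it.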
        def _fake_run_post_sft_gate(
            store_arg: object,
            parsed_arg: object,
            *,
            run_id: int,
            recorder: object,
            embed: object,
            input_dim: int,
            seed: int | None,
        ) -> None:
            seen.update(
                {
                    "store": store_arg,
                    "parsed": parsed_arg,
                    "run_id": run_id,
                    "recorder": recorder,
                    "embed": embed,
                    "input_dim": input_dim,
                    "seed": seed,
                }
            )

        def _embed(prompt: str) -> str:
            return prompt.upper()

        monkeypatch.setattr(gate_orchestrator, "run_post_sft_gate", _fake_run_post_sft_gate)
        results = run_all(
            store,
            parsed,
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=lambda: (_embed, 7),
        )
        assert seen["store"] == store
        assert seen["recorder"] is not None
        assert seen["parsed"] == parsed
        assert seen["run_id"] == results[-1].run_id
        assert callable(seen["embed"])
        assert seen["input_dim"] == 7
        assert seen["seed"] is None

    def test_gate_embedder_failure_is_logged(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GD"
        store = _seed_store(tmp_path, dlm_id)
        caplog.set_level("WARNING")

        def _boom_factory() -> tuple[object, int]:
            raise RuntimeError("boom")

        run_all(
            store,
            _multi_adapter_parsed(dlm_id, gate_enabled=True),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=_boom_factory,
        )
        assert "gate: embedder setup failed" in caplog.text

    def test_gate_training_failure_is_logged(
        self,
        tmp_path: Path,
        caplog: pytest.LogCaptureFixture,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        dlm_id = "01HZ4X7TGZM3J1A2B3C4D5E6GE"
        store = _seed_store(tmp_path, dlm_id)
        caplog.set_level("WARNING")

        def _raising_gate(*args: object, **kwargs: object) -> None:
            raise RuntimeError("gate boom")

        def _embed(prompt: str) -> str:
            return prompt

        monkeypatch.setattr(gate_orchestrator, "run_post_sft_gate", _raising_gate)
        run_all(
            store,
            _multi_adapter_parsed(dlm_id, gate_enabled=True),
            BASE_MODELS["smollm2-135m"],
            _plan(),
            mode="fresh",
            trainer_factory=_mock_trainer_factory,
            gate_embed_factory=lambda: (_embed, 4),
        )
        assert "gate: post-SFT pass failed" in caplog.text


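# _resolve_gate_embedder is exercised directly: an explicit factory wins,
# otherwise the module-level default embedder is used.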
class TestResolveGateEmbedder:
    def test_factory_path_is_used(self) -> None:
        def _embed(prompt: str) -> str:
            return prompt

        resolved, input_dim = multi_adapter_trainer._resolve_gate_embedder(
            BASE_MODELS["smollm2-135m"],
            _plan(),
            lambda: (_embed, 9),
        )
        assert resolved is _embed
        assert input_dim == 9

    def test_default_embedder_path_is_used_when_factory_missing(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        def _embed(prompt: str) -> str:
            return prompt

        def _fake_default(spec: object, plan: object) -> tuple[object, int]:
            return _embed, 11

        monkeypatch.setattr(multi_adapter_trainer, "_default_embedder", _fake_default)
        resolved, input_dim = multi_adapter_trainer._resolve_gate_embedder(
            BASE_MODELS["smollm2-135m"],
            _plan(),
            None,
        )
        assert resolved is _embed
        assert input_dim == 11