1 """Session-scoped fixture that trains a tiny-model store once per test session.
2
3 Sprint 14.5 consumers (`tests/integration/train/`, `tests/integration/export/`,
4 `tests/integration/pack/`) all need "a store with a real trained adapter."
5 A session fixture amortizes the ~60-90s train cost across the slow-test run.
6
7 The fixture uses `dlm.train.run()` directly (not Typer CliRunner) so tests
8 can treat it as a clean dependency: any `dlm train` CLI regression belongs
9 in the one-cycle rewrite, not here. The CLI-level round trip is covered
10 by the pack test which exercises `dlm pack` / `dlm unpack` / `dlm prompt`
11 through the Typer stack.
12
13 Skips gracefully when:
14 - `torch` or `transformers` aren't importable (e.g. a CPU-only slim runner)
15 - The tiny-model fixture can't download / isn't cached
16 - `doctor().plan` returns None (no viable training plan on the host)
17 """
18
19 from __future__ import annotations
20
21 import os
22 from collections.abc import Iterator
23 from dataclasses import dataclass
24 from pathlib import Path
25 from typing import TYPE_CHECKING, Final
26
27 import pytest
28
29 if TYPE_CHECKING:
30 from dlm.hardware.capabilities import Capabilities
31 from dlm.hardware.plan import TrainingPlan
32 from dlm.store.paths import StorePath
33
34 # Small enough to keep wall-clock manageable on CPU (smollm2-135m: ~1s/step
35 # on a recent Mac, ~0.5s on a CI ubuntu runner). Exports and resumes can
36 # hash the resulting adapter either way.
37 _TRAIN_MAX_STEPS: Final[int] = 20
38 _TRAIN_SEED: Final[int] = 42
39
40
@dataclass(frozen=True)
class TrainedStoreHandle:
    """Result of one shared training run.

    `store` is a live `StorePath` (DLM_HOME is set to `home` for the
    fixture's lifetime). `dlm_id` is cached so tests don't need to
    re-parse the document.
    """

    # Path to the `.dlm` source document the session trained from.
    doc: Path
    # Temp directory used as DLM_HOME while the fixture is alive.
    home: Path
    # Parsed dlm_id from the document's frontmatter (cached for tests).
    dlm_id: str
    # Live store layout rooted under `home` for this dlm_id.
    store: StorePath
    # The training plan `doctor()` selected for this host.
    plan: TrainingPlan
    # Hardware capabilities detected by `doctor()` on this host.
    capabilities: Capabilities
56
57
@pytest.fixture(scope="session")
def trained_store(tmp_path_factory: pytest.TempPathFactory) -> Iterator[TrainedStoreHandle]:
    """Train smollm2-135m once, yield handle with doc/home/store/dlm_id.

    Requires full model weights, not just the tokenizer. On a dev machine
    with a cold cache the first run takes ~2 minutes (download + 20-step
    train). CI's slow-test job pre-warms the HF cache so subsequent runs
    skip the download.
    """
    # Clear the autouse `_offline_hf_env` so snapshot_download / from_pretrained
    # can pull missing model weights. Restored at teardown so downstream
    # fast-path tests see the offline contract again.
    offline_vars = ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE", "HF_DATASETS_OFFLINE")
    saved_env = {k: os.environ.pop(k, None) for k in offline_vars}
    # BUG FIX: DLM_HOME was previously set below but never restored, leaking
    # this session's temp store into every test that runs afterwards. Record
    # its prior value (None if unset) so teardown can put it back.
    saved_env["DLM_HOME"] = os.environ.get("DLM_HOME")

    try:
        try:
            import torch  # noqa: F401
            import transformers  # noqa: F401
        except ImportError as exc:
            pytest.skip(f"torch/transformers unavailable: {exc}")

        try:
            from tests.fixtures.tiny_model import tiny_model_path

            tiny_model_path()  # force-resolve the snapshot (may download)
        except Exception as exc:
            pytest.skip(f"tiny-model fixture unavailable: {exc}")

        from dlm.base_models import resolve as resolve_base_model
        from dlm.doc.parser import parse_file
        from dlm.hardware import doctor
        from dlm.store.paths import for_dlm
        from dlm.train import run as run_training
        from tests.fixtures.dlm_factory import make_dlm

        home = tmp_path_factory.mktemp("dlm-trained-home")
        os.environ["DLM_HOME"] = str(home)

        doc = home / "smoke.dlm"
        doc.write_text(make_dlm(base_model="smollm2-135m"), encoding="utf-8")

        parsed = parse_file(doc)
        spec = resolve_base_model(parsed.frontmatter.base_model)
        doctor_result = doctor(
            training_config=parsed.frontmatter.training,
            base_params=spec.params,
            seq_len=min(parsed.frontmatter.training.sequence_len, spec.effective_context_length),
        )
        plan = doctor_result.plan
        if plan is None:
            pytest.skip("doctor() returned no viable training plan on this host")
        store = for_dlm(parsed.frontmatter.dlm_id)
        store.ensure_layout()

        # Seed the initial manifest — `dlm init` owns this in the CLI
        # path, but the fixture skips `init` and calls `run_training`
        # directly. Missing manifest → ManifestCorruptError on load.
        from dlm.store.manifest import Manifest, save_manifest

        save_manifest(
            store.manifest,
            Manifest(
                dlm_id=parsed.frontmatter.dlm_id,
                base_model=parsed.frontmatter.base_model,
            ),
        )

        run_training(
            store,
            parsed,
            spec,
            plan,
            mode="fresh",
            seed=_TRAIN_SEED,
            max_steps=_TRAIN_MAX_STEPS,
        )

        yield TrainedStoreHandle(
            doc=doc,
            home=home,
            dlm_id=parsed.frontmatter.dlm_id,
            store=store,
            plan=plan,
            capabilities=doctor_result.capabilities,
        )
    finally:
        # Restore every saved var to its pre-fixture state: re-set those
        # that had values, and *remove* those that were absent (removal
        # matters for DLM_HOME, which this fixture sets itself).
        for key, value in saved_env.items():
            if value is not None:
                os.environ[key] = value
            else:
                os.environ.pop(key, None)