1 """Resolve a user-supplied base-model spec to a `BaseModelSpec`.
2
3 Spec grammar:
4
5 - `<key>` — registry lookup (e.g., `qwen2.5-1.5b`). `UnknownBaseModelError`
6 if not present.
7 - `hf:<org>/<name>` — escape hatch. Fetches config.json + tokenizer
8 metadata from HF, synthesizes a `BaseModelSpec`, runs the probe suite,
9 and raises `ProbeFailedError` if any hard probe fails.
10
11 Gated models (`requires_acceptance=True`) raise `GatedModelError` unless
12 the caller has already accepted the license (signalled via
13 `accept_license=True`). The CLI uses this to persist acceptance; tests
14 pass `accept_license=True` directly to exercise the downstream path.
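
Example (module path and the `hf:` id are illustrative; `qwen2.5-1.5b` is
the example key above):

    from dlm.base_models.resolver import resolve

    spec = resolve("qwen2.5-1.5b")                            # registry lookup
    spec = resolve("hf:org/new-model", accept_license=True)   # escape hatch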
15 """
16
17 from __future__ import annotations
18
19 import logging
20 from typing import Final, Literal
21
22 from dlm.base_models.errors import (
23 GatedModelError,
24 ProbeFailedError,
25 ProbeResult,
26 UnknownBaseModelError,
27 )
28 from dlm.base_models.registry import BASE_MODELS, known_keys
29 from dlm.base_models.schema import BaseModelSpec
30
31 TemplateDialect = Literal[
32 "chatml",
33 "qwen3thinking",
34 "gemma2",
35 "smollm3",
36 "olmo2",
37 "llama3",
38 "phi3",
39 "phi4mini",
40 "mistral",
41 ]
42
43 _LOG = logging.getLogger(__name__)
44
45 _HF_PREFIX: Final = "hf:"
46
47
48 def resolve(
49 spec: str,
50 *,
51 accept_license: bool = False,
52 skip_export_probes: bool = False,
53 ) -> BaseModelSpec:
54 """Return the `BaseModelSpec` for `spec`.
55
56 Registry lookup first; `hf:`-prefix falls through to `resolve_hf()`.
57 Gating is enforced here regardless of path. `skip_export_probes`
58 only applies to the `hf:` path — registry entries are curated and
59 always pass all probes by construction.
60 """
61 if spec.startswith(_HF_PREFIX):
62 return resolve_hf(
63 spec[len(_HF_PREFIX) :],
64 accept_license=accept_license,
65 skip_export_probes=skip_export_probes,
66 )
67
68 entry = BASE_MODELS.get(spec)
69 if entry is None:
70 raise UnknownBaseModelError(spec, known_keys())
71
72 _enforce_gate(entry, accept_license=accept_license)
73 return entry
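
# Gating is orthogonal to the lookup path; e.g. (hypothetical gated key):
#
#     resolve("gemma-2-2b")                       # raises GatedModelError
#     resolve("gemma-2-2b", accept_license=True)  # returns the spec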


def _env_skip_export_probes() -> bool:
    """Read `DLM_SKIP_EXPORT_PROBES` — set by power users whose base isn't
    yet in the vendored llama.cpp but who only need training + HF inference.

    Checked on every `hf:` resolution so `dlm train/prompt/export` inherit
    the decision the user made at `dlm init --skip-export-probes` time
    without persisting extra state on the per-store manifest.
    """
    import os

    return os.environ.get("DLM_SKIP_EXPORT_PROBES", "").strip().lower() in (
        "1",
        "true",
        "yes",
    )
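
# Illustrative shell usage; the env var is honoured by any command that
# resolves an `hf:` base:
#
#   DLM_SKIP_EXPORT_PROBES=1 dlm train ...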


def resolve_hf(
    hf_id: str,
    *,
    accept_license: bool = False,
    skip_export_probes: bool = False,
) -> BaseModelSpec:
    """Synthesize a `BaseModelSpec` for an arbitrary HF model id.

    Runs the probe suite; raises `ProbeFailedError` with a full report
    if any hard probe fails. This is the gate that prevents users from
    pinning a model our export pipeline can't actually convert.

    `skip_export_probes=True` drops the llama.cpp / GGUF-conversion
    probes so brand-new architectures (not yet in the vendored
    llama.cpp) can still train + HF-infer. Users opting in forfeit
    `dlm export` until the vendored copy catches up.
    """
    # Deferred import: probes pull transformers, which is expensive.
    from dlm.base_models import probes

    spec = _synthesize_spec(hf_id)
    _enforce_gate(spec, accept_license=accept_license)

    skip = skip_export_probes or _env_skip_export_probes()
    report = probes.run_all(spec, skip_export_probes=skip)
    if not report.passed:
        raise ProbeFailedError(spec.hf_id, list(report.results))
    return spec
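
# Sketch of the escape hatch (hypothetical model id). Synthesized specs are
# conservatively gated, so acceptance is required up front; a failed hard
# probe raises ProbeFailedError carrying the full report:
#
#     spec = resolve_hf(
#         "org/brand-new-arch", accept_license=True, skip_export_probes=True
#     )
#     spec.key  # -> "hf:org/brand-new-arch"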


# --- internals ---------------------------------------------------------------


def _enforce_gate(spec: BaseModelSpec, *, accept_license: bool) -> None:
    if spec.requires_acceptance and not accept_license:
        raise GatedModelError(spec.hf_id, spec.license_url)


def _synthesize_spec(hf_id: str) -> BaseModelSpec:
    """Build a minimal `BaseModelSpec` for an arbitrary HF id.

    Pulls config + tokenizer_config metadata from the Hub so probes have
    real data to work against. The synthesized spec is shaped to pass
    `BaseModelSpec` validation; users who want tighter defaults should
    add the model to the curated registry instead.
    """
    if "/" not in hf_id or hf_id.startswith("/") or hf_id.endswith("/"):
        raise UnknownBaseModelError(f"hf:{hf_id}", known_keys())

    try:
        from huggingface_hub import HfApi
        from huggingface_hub.errors import (
            EntryNotFoundError,
            GatedRepoError,
            RepositoryNotFoundError,
        )
        from transformers import AutoConfig
    except ImportError as exc:  # pragma: no cover — dev env always has these
        raise RuntimeError(
            "hf: escape hatch requires huggingface_hub + transformers; install dev deps"
        ) from exc

    api = HfApi()
    try:
        info = api.model_info(hf_id)
    except GatedRepoError as exc:
        raise GatedModelError(hf_id, license_url=None) from exc
    except RepositoryNotFoundError as exc:
        raise UnknownBaseModelError(f"hf:{hf_id}", known_keys()) from exc

    revision = info.sha
    if not revision or len(revision) != 40:
        raise RuntimeError(f"HF returned non-40-char SHA for {hf_id}: {revision!r}")

    try:
        config = AutoConfig.from_pretrained(hf_id, revision=revision)
    except GatedRepoError as exc:
        raise GatedModelError(hf_id, license_url=None) from exc
    except EntryNotFoundError as exc:
        raise UnknownBaseModelError(f"hf:{hf_id}", known_keys()) from exc

    architectures = getattr(config, "architectures", None) or ()
    if not architectures:
        # Build a single synthetic failure so the caller has something
        # to show — we can't construct a BaseModelSpec without arch.
        raise ProbeFailedError(
            hf_id,
            [
                ProbeResult(
                    name="architecture",
                    passed=False,
                    detail="config.json has no `architectures` entry",
                )
            ],
        )

    architecture = architectures[0]
    params = getattr(config, "num_parameters", None) or _estimate_params(config)
    context_length = (
        getattr(config, "max_position_embeddings", None)
        or getattr(config, "n_positions", None)
        or 4096
    )

    gguf_arch = _infer_gguf_arch(architecture)
    template = _infer_template(hf_id, architecture)

    # `hf:` models are advisory — we can't audit their license from here
    # alone; mark them conservatively as requiring acceptance + not
    # redistributable. Users who know better add the model to the registry.
    return BaseModelSpec(
        key=f"hf:{hf_id}",
        hf_id=hf_id,
        revision=revision,
        architecture=architecture,
        params=params,
        target_modules=_default_target_modules(gguf_arch),
        template=template,
        gguf_arch=gguf_arch,
        tokenizer_pre="default",
        license_spdx="Unknown",
        license_url=None,
        requires_acceptance=True,
        redistributable=False,
        size_gb_fp16=max(0.1, params * 2 / (1024**3)),
        context_length=context_length,
        recommended_seq_len=min(context_length, 2048),
    )


def _estimate_params(config: object) -> int:
    """Rough param count from hidden_size / num_hidden_layers / vocab_size."""
    hidden: int = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None) or 2048
    layers: int = (
        getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer", None) or 24
    )
    vocab: int = getattr(config, "vocab_size", None) or 32_000
    # 12 * h^2 * L is a textbook approximation of transformer params; add embeddings.
    return int(12 * hidden**2 * layers + 2 * hidden * vocab)
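
# Worked example (illustrative): a Llama-2-7B-like config with hidden=4096,
# layers=32, vocab=32_000 gives 12 * 4096**2 * 32 ≈ 6.44e9 plus
# 2 * 4096 * 32_000 ≈ 2.6e8, i.e. roughly 6.7B, close to the published 6.74B.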


def _infer_gguf_arch(architecture: str) -> str:
    mapping = {
        "LlamaForCausalLM": "llama",
        "SmolLM3ForCausalLM": "llama",
        "Olmo2ForCausalLM": "olmo2",
        "Qwen2ForCausalLM": "qwen2",
        "Qwen3ForCausalLM": "qwen3",
        "MistralForCausalLM": "llama",
        "MixtralForCausalLM": "llama",
        "Phi3ForCausalLM": "phi3",
        "GemmaForCausalLM": "gemma",
        "Gemma2ForCausalLM": "gemma2",
    }
    return mapping.get(architecture, architecture.lower().replace("forcausallm", ""))
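
# E.g. "Qwen3ForCausalLM" maps to "qwen3" via the table, while an unmapped
# (hypothetical) "FooBarForCausalLM" falls back to "foobar" and is left for
# the probe suite to accept or reject.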


def _infer_template(hf_id: str, architecture: str) -> TemplateDialect:
    """Best-effort template dialect picker for `hf:` synthesis."""
    lower = hf_id.lower()
    if "gemma-2" in lower or architecture.startswith("Gemma2"):
        return "gemma2"
    if "smollm3" in lower or architecture.startswith("SmolLM3"):
        return "smollm3"
    if "olmo-2" in lower or architecture.startswith("Olmo2"):
        return "olmo2"
    if "llama-3" in lower or "llama3" in lower:
        return "llama3"
    if "phi-4-mini-reasoning" in lower:
        return "phi4mini"
    if architecture.startswith("Phi"):
        return "phi3"
    if architecture.startswith(("Mistral", "Mixtral")):
        return "mistral"
    return "chatml"
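
# E.g. _infer_template("meta-llama/Llama-3.2-1B-Instruct", "LlamaForCausalLM")
# returns "llama3" via the id substring; anything unrecognized falls back to
# "chatml". ("qwen3thinking" is never inferred here and can only come from
# curated registry entries.)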


def _default_target_modules(gguf_arch: str) -> list[str]:
    if gguf_arch == "phi3":
        return ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
    return ["q_proj", "k_proj", "v_proj", "o_proj"]
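
# phi3-family checkpoints fuse the attention QKV and MLP gate/up projections
# into single matrices, hence the distinct module names above; other
# architectures get the standard per-projection attention targets.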