1 """`BaseModelSpec` — curated metadata for a single pretrained base model.
2
3 Every field is strict: `extra="forbid"`, frozen, and validated on
4 instantiation. Values pack everything the rest of the project needs to
5 know about a base without re-fetching HF metadata at every decision
6 point:
7
8 - `revision`: 40-char commit SHA. Enforced non-None so retrains under the
9 same spec pin at exactly the same weights.
10 - `target_modules`: per-architecture LoRA target list (see findings §8;
11 `"all-linear"` is avoided because it bloats small models).
12 - `template`: the chat-template dialect used by the Go-template
13 registry for Modelfile generation.
14 - `gguf_arch` / `tokenizer_pre`: identifiers the llama.cpp converter
15 matches against; export preflight uses them.
16 - `reasoning_tuned` / `context_length_effective`: additive registry
17 hints for prompt defaults and realistic doctor estimates. The
18 effective length defaults to the nominal context window when unset.
19 - `refresh_check_hf_gating` / `provenance_url` /
20 `provenance_match_text`: live-registry refresh hints for entries
21 whose fetch mirror and first-party provenance page are not the same
22 system.
23 - License / gating: separate fields for SPDX, acceptance gating, and
24 re-distribution — each consumed by a different policy gate (license
25 acceptance, pack `--include-base`, share-protocol refusal).
26 """
27
28 from __future__ import annotations
29
30 import re
31 from typing import Final, Literal
32
33 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
34
# Full-length lowercase git commit SHA (40 hex chars); consumed by
# `BaseModelSpec._validate_revision` so specs always pin exact weights.
_SHA_RE: Final[re.Pattern[str]] = re.compile(r"^[0-9a-f]{40}$")
# Default sampling temperature for `dlm prompt` on ordinary instruct bases.
DEFAULT_PROMPT_TEMPERATURE: Final[float] = 0.7
# Cooler default applied when `reasoning_tuned=True` (see
# `BaseModelSpec.suggested_prompt_temperature`).
DEFAULT_REASONING_PROMPT_TEMPERATURE: Final[float] = 0.6
38
39
class VlPreprocessorPlan(BaseModel):
    """Pinned vision-preprocessing parameters for a single base.

    Recorded at registry-build time so `dlm export` and the VL cache
    key do not drift between reruns. HF's `AutoProcessor` remains the
    source of truth at runtime; this block only captures the *expected*
    shape for preflight checks and cache keying.

    `target_size` is `(height, width)` in pixels. `resize_policy`
    defaults to `"fixed"` — the only policy the current launch registry
    ships. `image_token` is the textual placeholder written into the
    prompt, which the processor then expands into `num_image_tokens`
    copies.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    target_size: tuple[int, int] = Field(..., description="(height, width) in pixels")
    resize_policy: Literal["fixed", "dynamic"] = "fixed"
    image_token: str = Field(..., min_length=1, description="Placeholder token string")
    num_image_tokens: int = Field(..., gt=0, description="Tokens consumed per image")

    @field_validator("target_size")
    @classmethod
    def _validate_target_size(cls, value: tuple[int, int]) -> tuple[int, int]:
        """Reject any dimension that is not strictly positive."""
        if min(value) <= 0:
            raise ValueError(f"target_size must be positive, got {value!r}")
        return value
69
70
class AudioPreprocessorPlan(BaseModel):
    """Per-base audio-preprocessing parameters.

    Mirrors `VlPreprocessorPlan` — pinned at registry-build time so
    the audio cache key stays stable. Current releases refuse audio at
    non-target `sample_rate`; resampling lands as a follow-up.

    `sample_rate` is the model's training rate in Hz (Qwen2-Audio:
    16000). `max_length_seconds` caps the per-clip duration the
    processor sees; longer clips are truncated (the processor's
    built-in policy). `audio_token` is the textual placeholder that
    expands into the model's fixed audio-token window.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Hz; mismatched input audio is refused rather than resampled.
    sample_rate: int = Field(..., gt=0, description="Hz — refuse on mismatch")
    # Per-clip duration cap in seconds; longer clips are truncated upstream.
    max_length_seconds: float = Field(..., gt=0.0)
    # Textual placeholder inserted into prompts before processor expansion.
    audio_token: str = Field(..., min_length=1, description="Placeholder token string")
    # Fixed number of tokens the placeholder expands into per clip.
    num_audio_tokens: int = Field(..., gt=0, description="Tokens reserved per clip")
91
92
class BaseModelSpec(BaseModel):
    """Curated registry metadata for one base model.

    Frozen with `extra="forbid"`: a spec is an immutable record
    validated once at construction. Cross-field invariants (modality vs.
    preprocessing plans, effective-context bounds, provenance-probe
    completeness) are enforced by the `model_validator`s below.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # --- Identity + pinning ---
    key: str = Field(..., min_length=1, description="Registry slug (e.g. `qwen2.5-1.5b`).")
    hf_id: str = Field(
        ..., min_length=1, description="HuggingFace id, e.g. `Qwen/Qwen2.5-1.5B-Instruct`."
    )
    # Must be a full 40-char lowercase hex SHA (see `_validate_revision`)
    # so retrains under the same spec pin at exactly the same weights.
    revision: str = Field(..., description="40-char commit SHA; never a branch.")
    architecture: str = Field(..., description="transformers `config.architectures[0]` value.")
    params: int = Field(..., gt=0, description="Parameter count; drives hardware doctor.")
    # Per-architecture LoRA target list; non-empty, and no empty strings
    # (see `_validate_target_modules`).
    target_modules: list[str] = Field(..., min_length=1)
    # Chat-template dialect used for Modelfile generation; a closed set so
    # an unknown dialect fails validation instead of generating silently.
    template: Literal[
        "chatml",
        "qwen3thinking",
        "gemma2",
        "smollm3",
        "olmo2",
        "llama3",
        "phi3",
        "phi4mini",
        "mistral",
        "paligemma",
        "qwen2-audio",
        "qwen2-vl",
        "internvl2",
    ]
    # Identifiers the llama.cpp converter matches against; export
    # preflight uses them (see module docstring).
    gguf_arch: str = Field(..., min_length=1, description="Name llama.cpp's converter uses.")
    tokenizer_pre: str = Field(..., min_length=1, description="Pre-tokenizer label.")

    # License + acceptance. Separate fields because each feeds a
    # different policy gate: license acceptance, pack `--include-base`,
    # and share-protocol refusal.
    license_spdx: str = Field(..., min_length=1)
    license_url: str | None = None
    requires_acceptance: bool = False
    redistributable: bool = Field(
        ...,
        description="True iff the license allows bundling the base inside a .dlm.pack.",
    )
    # trust_remote_code: `True` for bases whose HF class lives in the
    # model's own repo (custom `modeling_*.py` files) rather than in
    # transformers. Picking such a base as `base_model:` in a .dlm is
    # the user's informed acknowledgment — the registry entry carries
    # a docstring caveat, vl-memory.md + the cookbook flag it, and the
    # loader only passes `trust_remote_code=True` to HF when this is
    # `True` on the spec. Defaults to False so non-custom bases can
    # never accidentally opt into remote code.
    trust_remote_code: bool = False

    # --- Size + context hints ---
    size_gb_fp16: float = Field(..., gt=0)
    context_length: int = Field(..., gt=0)
    # Optional practical ceiling; must not exceed `context_length`
    # (see `_effective_context_length_is_bounded`).
    context_length_effective: int | None = Field(None, gt=0)
    recommended_seq_len: int = Field(..., gt=0)
    reasoning_tuned: bool = False
    # Live-registry refresh hints: disabling the HF gating check requires
    # a complete first-party provenance probe, i.e. both `provenance_url`
    # and `provenance_match_text` (see `_provenance_probe_is_complete`).
    refresh_check_hf_gating: bool = True
    provenance_url: str | None = None
    provenance_match_text: str | None = None

    # Modality + multi-modal preprocessing (schema v10 + v11, plus the
    # additive `text-moe` discriminator).
    # Text-family bases leave `modality in {"text", "text-moe"}`
    # with both plans None;
    # `modality="vision-language"` requires a `vl_preprocessor_plan`
    # and rejects an audio plan; `modality="audio-language"` requires
    # an `audio_preprocessor_plan` and rejects a vl plan. Every
    # invariant is enforced below at validate time.
    modality: Literal["text", "text-moe", "vision-language", "audio-language"] = "text"
    vl_preprocessor_plan: VlPreprocessorPlan | None = None
    audio_preprocessor_plan: AudioPreprocessorPlan | None = None

    @model_validator(mode="after")
    def _modality_matches_plan(self) -> BaseModelSpec:
        """Require exactly the preprocessing plan the modality calls for.

        Raises `ValueError` when a required plan is missing or a plan
        from the other modality family is present.
        """
        if self.modality == "vision-language":
            if self.vl_preprocessor_plan is None:
                raise ValueError(
                    f"base {self.key!r}: modality='vision-language' requires "
                    "a vl_preprocessor_plan (pinned image size + token shape)"
                )
            if self.audio_preprocessor_plan is not None:
                raise ValueError(
                    f"base {self.key!r}: audio_preprocessor_plan is invalid "
                    "on a vision-language base"
                )
        elif self.modality == "audio-language":
            if self.audio_preprocessor_plan is None:
                raise ValueError(
                    f"base {self.key!r}: modality='audio-language' requires "
                    "an audio_preprocessor_plan (pinned sample_rate + token shape)"
                )
            if self.vl_preprocessor_plan is not None:
                raise ValueError(
                    f"base {self.key!r}: vl_preprocessor_plan is invalid on an audio-language base"
                )
        else:  # "text" or "text-moe"
            if self.vl_preprocessor_plan is not None:
                raise ValueError(
                    f"base {self.key!r}: vl_preprocessor_plan only valid with "
                    "modality='vision-language'"
                )
            if self.audio_preprocessor_plan is not None:
                raise ValueError(
                    f"base {self.key!r}: audio_preprocessor_plan only valid "
                    "with modality='audio-language'"
                )
        return self

    @model_validator(mode="after")
    def _effective_context_length_is_bounded(self) -> BaseModelSpec:
        """Reject an effective-context hint larger than the nominal window."""
        if (
            self.context_length_effective is not None
            and self.context_length_effective > self.context_length
        ):
            raise ValueError(
                f"base {self.key!r}: context_length_effective={self.context_length_effective} "
                f"cannot exceed context_length={self.context_length}"
            )
        return self

    @model_validator(mode="after")
    def _provenance_probe_is_complete(self) -> BaseModelSpec:
        """Require url + match text together; mandate both when HF gating checks are off."""
        url_set = self.provenance_url is not None
        text_set = self.provenance_match_text is not None
        if url_set != text_set:
            raise ValueError(
                f"base {self.key!r}: provenance_url and provenance_match_text must be set together"
            )
        if not self.refresh_check_hf_gating and not url_set:
            raise ValueError(
                f"base {self.key!r}: refresh_check_hf_gating=False requires a "
                "first-party provenance_url + provenance_match_text"
            )
        return self

    @property
    def suggested_prompt_temperature(self) -> float:
        """Default sampling temperature for `dlm prompt`.

        Most instruct bases keep the long-standing 0.7 default.
        Reasoning-tuned bases run slightly cooler by default so the
        chain-of-thought control tokens they were tuned around stay
        stable when the user omits `--temp`.
        """
        if self.reasoning_tuned:
            return DEFAULT_REASONING_PROMPT_TEMPERATURE
        return DEFAULT_PROMPT_TEMPERATURE

    @property
    def effective_context_length(self) -> int:
        """Context window `dlm doctor` should estimate against.

        Registry rows can pin a lower practical ceiling than the
        model's advertised nominal context length. When no such hint is
        set, the nominal context window remains the source of truth.
        """
        # `context_length_effective` validates gt=0, so `or` only falls
        # through on the None (unset) case, never on a zero value.
        return self.context_length_effective or self.context_length

    @field_validator("revision")
    @classmethod
    def _validate_revision(cls, value: str) -> str:
        """Pin to a full 40-char lowercase hex commit SHA (never a branch)."""
        if not _SHA_RE.fullmatch(value):
            raise ValueError(f"revision must be a 40-char lowercase hex SHA, got {value!r}")
        return value

    @field_validator("hf_id")
    @classmethod
    def _validate_hf_id(cls, value: str) -> str:
        """Require exactly one `/` separating a non-empty org and name."""
        if "/" not in value or value.startswith("/") or value.endswith("/"):
            raise ValueError(f"hf_id must be 'org/name', got {value!r}")
        org, _, name = value.partition("/")
        if not org or not name or "/" in name:
            raise ValueError(f"hf_id must be 'org/name' (single `/`), got {value!r}")
        return value

    @field_validator("target_modules")
    @classmethod
    def _validate_target_modules(cls, value: list[str]) -> list[str]:
        """Reject empty strings inside the LoRA target list."""
        if any(not m for m in value):
            raise ValueError("target_modules must not contain empty strings")
        return value