# test_resolve_hf.py
1 """Happy-path coverage for `resolve_hf()` — the `hf:org/name` escape hatch.
2
3 Covers: successful synthesis, helper fns (`_estimate_params`,
4 `_infer_gguf_arch`, `_infer_template`, `_default_target_modules`), and
5 `resolve_hf()`'s probe-report failure path.
6 """
7
8 from __future__ import annotations
9
10 from types import SimpleNamespace
11 from unittest.mock import Mock, patch
12
13 import pytest
14
15 from dlm.base_models import BaseModelSpec, GatedModelError, resolve, resolve_hf
16 from dlm.base_models.errors import ProbeFailedError, ProbeReport, ProbeResult
17 from dlm.base_models.resolver import (
18 _default_target_modules,
19 _estimate_params,
20 _infer_gguf_arch,
21 _infer_template,
22 )
23
24 # --- helper fns --------------------------------------------------------------
25
26
class TestEstimateParams:
    """Sanity checks for the parameter-count heuristic."""

    def test_uses_num_parameters_when_present(self) -> None:
        # `num_parameters` is consumed by the caller, not `_estimate_params`;
        # the helper ignores it and falls back to the hidden/layer heuristic.
        cfg = SimpleNamespace(num_parameters=1_500_000_000)
        assert _estimate_params(cfg) > 0

    def test_hidden_plus_layers_default_fallbacks(self) -> None:
        # An attribute-less config exercises the built-in defaults
        # (hidden=2048, layers=24, vocab=32000).
        assert _estimate_params(SimpleNamespace()) > 0

    def test_honors_overrides(self) -> None:
        # A much larger geometry must produce a strictly larger estimate.
        large_cfg = SimpleNamespace(
            hidden_size=4096, num_hidden_layers=32, vocab_size=128_000
        )
        small_cfg = SimpleNamespace(
            hidden_size=1024, num_hidden_layers=12, vocab_size=32_000
        )
        assert _estimate_params(large_cfg) > _estimate_params(small_cfg)
52
53
class TestInferGgufArch:
    """HF architecture class name -> GGUF architecture tag mapping."""

    # (HF `architectures[0]` value, expected GGUF arch tag)
    KNOWN_CASES = [
        ("LlamaForCausalLM", "llama"),
        ("SmolLM3ForCausalLM", "llama"),
        ("Olmo2ForCausalLM", "olmo2"),
        ("Qwen2ForCausalLM", "qwen2"),
        ("Qwen3ForCausalLM", "qwen3"),
        ("MistralForCausalLM", "llama"),
        ("MixtralForCausalLM", "llama"),
        ("Phi3ForCausalLM", "phi3"),
        ("GemmaForCausalLM", "gemma"),
        ("Gemma2ForCausalLM", "gemma2"),
    ]

    @pytest.mark.parametrize(("architecture", "expected"), KNOWN_CASES)
    def test_known_architectures_map_correctly(self, architecture: str, expected: str) -> None:
        assert _infer_gguf_arch(architecture) == expected

    def test_unknown_arch_falls_back_to_lowercase_stripped(self) -> None:
        # Unrecognized classes degrade to lowercase with `forcausallm` removed.
        assert _infer_gguf_arch("SomeNewForCausalLM") == "somenew"
76
77
class TestInferTemplate:
    """Chat-template inference from (repo id, architecture) pairs."""

    # (hf repo id, HF architecture class, expected template name)
    TEMPLATE_CASES = [
        ("google/gemma-2-2b-it", "Gemma2ForCausalLM", "gemma2"),
        ("HuggingFaceTB/SmolLM3-3B", "SmolLM3ForCausalLM", "smollm3"),
        ("allenai/OLMo-2-1124-7B-Instruct", "Olmo2ForCausalLM", "olmo2"),
        ("meta-llama/Llama-3.2-1B-Instruct", "LlamaForCausalLM", "llama3"),
        ("meta-llama/llama3-base", "LlamaForCausalLM", "llama3"),
        ("microsoft/Phi-4-mini-reasoning", "Phi3ForCausalLM", "phi4mini"),
        ("microsoft/Phi-3.5-mini-instruct", "Phi3ForCausalLM", "phi3"),
        ("mistralai/Mistral-7B-Instruct", "MistralForCausalLM", "mistral"),
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", "MixtralForCausalLM", "mistral"),
        ("Qwen/Qwen2.5-1.5B-Instruct", "Qwen2ForCausalLM", "chatml"),
    ]

    @pytest.mark.parametrize(("hf_id", "architecture", "expected"), TEMPLATE_CASES)
    def test_template_inference(self, hf_id: str, architecture: str, expected: str) -> None:
        assert _infer_template(hf_id, architecture) == expected
96
97
class TestDefaultTargetModules:
    """Default LoRA target-module lists per GGUF architecture."""

    def test_phi3_uses_fused_qkv(self) -> None:
        # phi3 fuses q/k/v into one projection and gate/up into another.
        assert _default_target_modules("phi3") == [
            "qkv_proj",
            "o_proj",
            "gate_up_proj",
            "down_proj",
        ]

    # Parametrized rather than a for-loop: a failure names the offending
    # arch and the remaining architectures still run instead of aborting.
    @pytest.mark.parametrize("arch", ["llama", "olmo2", "qwen2", "qwen3", "gemma2"])
    def test_other_archs_use_split_qkv(self, arch: str) -> None:
        assert _default_target_modules(arch) == ["q_proj", "k_proj", "v_proj", "o_proj"]
110
111
112 # --- resolve_hf end-to-end ---------------------------------------------------
113
114
class TestResolveHfHappyPath:
    """End-to-end `resolve_hf()` synthesis with all HF network calls mocked."""

    def _mock_config(self, **overrides: object) -> SimpleNamespace:
        """Build a minimal Qwen2-like AutoConfig stand-in; kwargs override fields."""
        defaults = {
            "architectures": ["Qwen2ForCausalLM"],
            "hidden_size": 1536,
            "num_hidden_layers": 28,
            "vocab_size": 151_936,
            "max_position_embeddings": 32_768,
        }
        defaults.update(overrides)
        return SimpleNamespace(**defaults)

    def _make_report(self, hf_id: str, *, chat_template_ok: bool = True) -> ProbeReport:
        """Build the canonical four-probe report.

        Deduplicates the three near-identical `ProbeReport` literals the tests
        previously carried inline; flip `chat_template_ok` to simulate the one
        failing probe used by the error-path test.
        """
        chat_detail = "ok" if chat_template_ok else "missing"
        return ProbeReport(
            hf_id=hf_id,
            results=(
                ProbeResult(name="architecture", passed=True, detail="ok"),
                ProbeResult(name="chat_template", passed=chat_template_ok, detail=chat_detail),
                ProbeResult(name="gguf_arch", passed=True, detail="ok", skipped=True),
                ProbeResult(name="pretokenizer_label", passed=True, detail="ok", skipped=True),
            ),
        )

    def test_successful_synthesis_returns_spec(self) -> None:
        info = SimpleNamespace(sha="a" * 40, gated=False)
        report = self._make_report("org/custom")
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch("transformers.AutoConfig.from_pretrained", return_value=self._mock_config()),
            patch("dlm.base_models.probes.run_all", return_value=report),
        ):
            api_cls.return_value.model_info.return_value = info
            spec = resolve_hf("org/custom")

        assert isinstance(spec, BaseModelSpec)
        assert spec.key == "hf:org/custom"
        assert spec.hf_id == "org/custom"
        assert spec.revision == "a" * 40
        assert spec.architecture == "Qwen2ForCausalLM"
        assert spec.gguf_arch == "qwen2"
        assert spec.template == "chatml"
        assert spec.redistributable is False
        # hf: synthesis is conservative on license.
        assert spec.license_spdx == "Unknown"

    def test_probe_failure_raises_probe_failed_error(self) -> None:
        info = SimpleNamespace(sha="a" * 40, gated=False)
        report = self._make_report("org/custom", chat_template_ok=False)
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch("transformers.AutoConfig.from_pretrained", return_value=self._mock_config()),
            patch("dlm.base_models.probes.run_all", return_value=report),
        ):
            api_cls.return_value.model_info.return_value = info
            # The error message must name the failing probe.
            with pytest.raises(ProbeFailedError, match="chat_template"):
                resolve_hf("org/custom")

    def test_rejects_non_40_char_sha_from_hf(self) -> None:
        # A malformed revision SHA from the hub is a hard failure.
        info = SimpleNamespace(sha="tooshort", gated=False)
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch("transformers.AutoConfig.from_pretrained", return_value=self._mock_config()),
        ):
            api_cls.return_value.model_info.return_value = info
            with pytest.raises(RuntimeError, match="non-40-char SHA"):
                resolve_hf("org/custom")

    def test_empty_architectures_fails_fast(self) -> None:
        # A config with no `architectures` cannot be synthesized.
        info = SimpleNamespace(sha="a" * 40, gated=False)
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch(
                "transformers.AutoConfig.from_pretrained",
                return_value=self._mock_config(architectures=[]),
            ),
        ):
            api_cls.return_value.model_info.return_value = info
            with pytest.raises(ProbeFailedError, match="architectures"):
                resolve_hf("org/custom")

    def test_gated_repo_during_config_load_surfaces_as_gated_error(self) -> None:
        # GatedRepoError can fire at config-load time, after model_info succeeded.
        from huggingface_hub.errors import GatedRepoError

        info = SimpleNamespace(sha="a" * 40, gated=False)
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch(
                "transformers.AutoConfig.from_pretrained",
                side_effect=GatedRepoError("gated at config load", response=Mock()),
            ),
        ):
            api_cls.return_value.model_info.return_value = info
            with pytest.raises(GatedModelError):
                resolve_hf("org/custom")

    def test_resolve_dispatches_to_hf_escape_on_prefix(self) -> None:
        """Smoke test: `resolve('hf:...')` delegates to `resolve_hf`."""
        info = SimpleNamespace(sha="a" * 40, gated=False)
        report = self._make_report("org/mini")
        with (
            patch("huggingface_hub.HfApi") as api_cls,
            patch("transformers.AutoConfig.from_pretrained", return_value=self._mock_config()),
            patch("dlm.base_models.probes.run_all", return_value=report),
        ):
            api_cls.return_value.model_info.return_value = info
            spec = resolve("hf:org/mini")
        assert spec.key == "hf:org/mini"
235
236
class TestResolveHfConfigLookupErrors:
    """Config-lookup failures should map onto the package's error types."""

    def test_entry_not_found_surfaces_as_unknown_base_model(self) -> None:
        # A repo with no config entry reads as "not a usable base model".
        from huggingface_hub.errors import EntryNotFoundError

        from dlm.base_models import UnknownBaseModelError

        info = SimpleNamespace(sha="a" * 40, gated=False)
        config_patch = patch(
            "transformers.AutoConfig.from_pretrained",
            side_effect=EntryNotFoundError("no config"),
        )
        with patch("huggingface_hub.HfApi") as api_cls, config_patch:
            api_cls.return_value.model_info.return_value = info
            with pytest.raises(UnknownBaseModelError):
                resolve_hf("org/nocfg")