Python · 8800 bytes Raw Blame History
1 """Probe the vendored `convert_hf_to_gguf.py` for VL arch coverage.
2
3 VL GGUF conversion is moving upstream; some VL architectures are
4 fully registered in `convert_hf_to_gguf.py` (LM + vision tower emit
5 together), some are partial (LM-only via the TextModel base, vision
6 tower ships separately via an mmproj class), and some aren't
7 registered at all.
8
9 `probe_gguf_arch(arch_class)` scans the vendored conversion script
10 for `@ModelBase.register(...)` decorators that name the arch, then
11 looks at the decorated class's base(s) to decide the verdict:
12
13 - **SUPPORTED** — at least one register binds to a class that does
14 NOT inherit from `MmprojModel` (i.e. a TextModel or similar that
15 emits the LM + combined VL path cleanly).
16 - **PARTIAL** — the arch appears only on `MmprojModel` subclasses,
17 meaning llama.cpp handles the vision tower separately via an
18 mmproj sidecar but no single-file GGUF covers the full VL model.
19 - **UNSUPPORTED** — the arch isn't registered anywhere; GGUF
20 conversion would fail.
21
22 Callers (the export dispatcher) use this verdict to choose between
23 emitting GGUF, emitting GGUF with a warning banner about vision-
24 tower caveats, or falling back cleanly to the HF-snapshot path.
25
26 The probe is cheap (one read + regex scan) but we memoize on
`(llama_cpp_tag, arch_class)` so a single export run doesn't re-parse
28 the 200+ KB script on every registry lookup.
29 """
30
31 from __future__ import annotations
32
33 import enum
34 import re
35 from dataclasses import dataclass
36 from pathlib import Path
37 from typing import Final
38
39 from dlm.export.vendoring import convert_hf_to_gguf_py, pinned_tag
40
41
42 class SupportLevel(enum.StrEnum):
43 """Verdict for one VL arch class against a vendored llama.cpp tree.
44
45 StrEnum so manifests can store the value as a plain JSON string
46 ("SUPPORTED" / "PARTIAL" / "UNSUPPORTED") without custom encoders.
47 """
48
49 SUPPORTED = "SUPPORTED"
50 PARTIAL = "PARTIAL"
51 UNSUPPORTED = "UNSUPPORTED"
52
53
@dataclass(frozen=True)
class ArchProbeResult:
    """Immutable verdict for a single probed architecture class.

    Attributes:
        arch_class: the HF architecture name that was probed.
        support: SUPPORTED / PARTIAL / UNSUPPORTED classification.
        reason: human-readable explanation the dispatcher surfaces in
            banner and error messages.
        llama_cpp_tag: pinned tag of the vendored build the verdict
            came from, so users bumping the submodule can tell when
            the answer changes.
    """

    arch_class: str
    support: SupportLevel
    reason: str
    llama_cpp_tag: str | None
68
69
70 # Matches a `@ModelBase.register(...)` decorator and captures the arg
71 # list verbatim so we can look for the quoted arch-class name inside.
72 # Multi-line decorators are supported via DOTALL on the arg-list span.
73 _REGISTER_DECORATOR: Final[re.Pattern[str]] = re.compile(
74 r"@ModelBase\.register\((?P<args>[^)]*)\)",
75 re.DOTALL,
76 )
77
78 # Captures the first `class Foo(Bar, Baz):` line after a register
79 # decorator — we read the base-class list to decide SUPPORTED vs
80 # PARTIAL (MmprojModel-only registration → PARTIAL).
81 _CLASS_DEFINITION: Final[re.Pattern[str]] = re.compile(
82 r"^\s*class\s+\w+\((?P<bases>[^)]*)\)\s*:",
83 re.MULTILINE,
84 )
85
86 _MMPROJ_BASE: Final[str] = "MmprojModel"
87
88 _CACHE: dict[tuple[str | None, str], ArchProbeResult] = {}
89
90
def probe_gguf_arch(
    arch_class: str,
    *,
    llama_cpp_root: Path | None = None,
) -> ArchProbeResult:
    """Classify `arch_class` as SUPPORTED / PARTIAL / UNSUPPORTED.

    `llama_cpp_root` points the probe at an alternate vendored tree
    (test fixtures); when omitted, `dlm.export.vendoring` resolves the
    default location via the env var or `vendor/llama.cpp/`.

    Results are memoized per (pinned tag, arch class) so repeated
    registry lookups within one export run re-read the conversion
    script only once.

    Raises `VendoringError` (from `convert_hf_to_gguf_py`) when the
    vendored tree lacks the conversion script entirely — e.g. a
    pre-Sprint-11 layout or an uninitialized submodule.
    """
    tag = pinned_tag(llama_cpp_root)
    key = (tag, arch_class)
    try:
        return _CACHE[key]
    except KeyError:
        pass

    verdict = _probe_from_script(
        arch_class=arch_class,
        script_path=convert_hf_to_gguf_py(llama_cpp_root),
        llama_cpp_tag=tag,
    )
    _CACHE[key] = verdict
    return verdict
121
122
def clear_cache() -> None:
    """Forget every memoized probe verdict (test helper; production code never calls this)."""
    _CACHE.clear()
126
127
def _probe_from_script(
    *,
    arch_class: str,
    script_path: Path,
    llama_cpp_tag: str | None,
) -> ArchProbeResult:
    """Read the conversion script, find registrations, classify the arch.

    Verdict logic: no registrations → UNSUPPORTED; any registration on
    a non-Mmproj class → SUPPORTED; otherwise (Mmproj-only) → PARTIAL.
    """
    source = script_path.read_text(encoding="utf-8")
    matches = _find_arch_bindings(source, arch_class)
    # Precompute the tag label used in every human-readable reason.
    tag_label = llama_cpp_tag or "unknown"

    if not matches:
        # Nothing in the script registers this arch at all.
        return ArchProbeResult(
            arch_class=arch_class,
            support=SupportLevel.UNSUPPORTED,
            reason=(
                f"{arch_class!r} not found in any @ModelBase.register(...) "
                f"decorator — vendored llama.cpp "
                f"(tag={tag_label}) does not know this "
                "architecture. GGUF conversion would fail."
            ),
            llama_cpp_tag=llama_cpp_tag,
        )

    # Any binding to a non-Mmproj class means the LM converts as a
    # standalone GGUF; the vision tower may still ship separately via
    # another registration, but single-file GGUF is viable.
    lm_matches = [m for m in matches if not _is_mmproj(m.bases)]
    if lm_matches:
        lm_names = sorted({m.class_name for m in lm_matches})
        return ArchProbeResult(
            arch_class=arch_class,
            support=SupportLevel.SUPPORTED,
            reason=(
                f"{arch_class!r} registered on "
                f"{', '.join(lm_names)} "
                f"in llama.cpp tag={tag_label}; LM converts "
                "cleanly via convert_hf_to_gguf.py."
            ),
            llama_cpp_tag=llama_cpp_tag,
        )

    # Every binding is Mmproj-only: the vision tower ships as a
    # separate GGUF and the LM side needs a different arch string.
    sidecar_names = sorted({m.class_name for m in matches})
    return ArchProbeResult(
        arch_class=arch_class,
        support=SupportLevel.PARTIAL,
        reason=(
            f"{arch_class!r} registered only on MmprojModel class(es) "
            f"{', '.join(sidecar_names)} in llama.cpp "
            f"tag={tag_label}. The vision tower converts "
            "but no single-file GGUF covers the full VL model."
        ),
        llama_cpp_tag=llama_cpp_tag,
    )
183
184
185 @dataclass(frozen=True)
186 class _ArchBinding:
187 """One `@ModelBase.register(...)` → `class Foo(Bar):` pairing."""
188
189 class_name: str
190 bases: str # raw comma-separated base-class list
191
192
def _find_arch_bindings(text: str, arch_class: str) -> list[_ArchBinding]:
    """Return every class registration whose arg list names `arch_class`."""
    # The arch name must appear as a quoted literal inside the
    # decorator's argument list (either quoting style).
    double_quoted = f'"{arch_class}"'
    single_quoted = f"'{arch_class}'"

    found: list[_ArchBinding] = []
    for decorator in _REGISTER_DECORATOR.finditer(text):
        arg_text = decorator.group("args")
        if double_quoted not in arg_text and single_quoted not in arg_text:
            continue
        class_def = _CLASS_DEFINITION.search(text, decorator.end())
        if class_def is None:
            # Decorator with no following class definition (e.g. end of
            # file) — nothing recognizable to bind to; skip it.
            continue
        name = _extract_class_name(text, class_def.start())
        if name is None:
            continue
        found.append(
            _ArchBinding(
                class_name=name,
                bases=class_def.group("bases"),
            )
        )
    return found
217
218
219 _CLASS_NAME_RE: Final[re.Pattern[str]] = re.compile(r"class\s+(\w+)\s*\(")
220
221
222 def _extract_class_name(text: str, class_def_start: int) -> str | None:
223 """Parse `class Foo(...)` starting at `class_def_start`."""
224 # The MULTILINE match on _CLASS_DEFINITION starts at the first
225 # leading whitespace; re-match from there to capture the name.
226 segment_end = text.find("(", class_def_start)
227 if segment_end == -1:
228 return None
229 segment = text[class_def_start : segment_end + 1]
230 name_match = _CLASS_NAME_RE.search(segment)
231 return name_match.group(1) if name_match else None
232
233
def _is_mmproj(bases: str) -> bool:
    """Report whether the base-class list names MmprojModel.

    Deliberately a plain substring test: importing the vendored script
    to walk the real class hierarchy would pull llama.cpp code into
    dlm's process just to classify. The base list is short and
    "MmprojModel" is distinctive enough not to false-match another
    base, so substring matching is sufficient. Note this only sees
    direct bases on the `class Foo(...)` line, not transitive ones.
    """
    return _MMPROJ_BASE in bases
243 return _MMPROJ_BASE in bases