# NOTE: stray file-viewer header ("Python · 11077 bytes Raw Blame History") converted to a comment.
1 """Capability profile resolution for runtime behavior."""
2
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Any, Literal, Protocol
7
# Wire format used when asking the model for tool calls: provider-native
# tool calling, JSON wrapped in a tag, or a bracket-delimited convention.
ToolCallFormat = Literal["native", "json_tag", "bracket"]
# How strictly the runtime should verify model output.
VerificationStrictness = Literal["lax", "standard", "strict"]
10
11
class SupportsCapabilityProfile(Protocol):
    """Runtime interface for backends that can describe capabilities."""

    # Structural hook probed (via getattr) by resolve_backend_capability_profile().
    def capability_profile(self) -> CapabilityProfile: ...
16
17
class SupportsNativeTools(Protocol):
    """Runtime interface for backends that can explicitly report tool support."""

    # Weaker fallback surface than SupportsCapabilityProfile: only reports
    # whether native tool calling works, not a full profile.
    def supports_native_tools(self) -> bool: ...
22
23
@dataclass(frozen=True)
class CapabilityProfile:
    """Resolved model/runtime capability profile."""

    # Model identifier (may include an Ollama-style ":tag" suffix).
    model_name: str
    # Whether the model supports provider-native tool calling.
    supports_native_tools: bool
    # Whether streamed output is supported.
    supports_streaming: bool
    # Context window size — presumably in tokens, matching the
    # "context_length" metadata it is inferred from.
    context_window: int
    # Preferred way to request tool calls from this model.
    preferred_tool_call_format: ToolCallFormat
    # How strictly output should be verified for this model.
    verification_strictness: VerificationStrictness
    # Free-form remarks explaining how/why the profile was chosen.
    notes: list[str] = field(default_factory=list)
35
36
def _profile(
    model_name: str,
    *,
    supports_native_tools: bool,
    supports_streaming: bool = True,
    context_window: int = 8192,
    preferred_tool_call_format: ToolCallFormat = "native",
    verification_strictness: VerificationStrictness = "standard",
    notes: list[str] | None = None,
) -> CapabilityProfile:
    """Build a CapabilityProfile, defaulting the rarely varied fields.

    ``notes`` is copied so callers cannot later mutate the profile's list.
    """
    note_copy: list[str] = [] if notes is None else list(notes)
    return CapabilityProfile(
        model_name,
        supports_native_tools=supports_native_tools,
        supports_streaming=supports_streaming,
        context_window=context_window,
        preferred_tool_call_format=preferred_tool_call_format,
        verification_strictness=verification_strictness,
        notes=note_copy,
    )
56
57
# Built-in registry of capability profiles, keyed by exact (tag-less) model
# name. resolve_capability_profile() consults this after an explicit override
# and before the family-prefix heuristics.
KNOWN_CAPABILITY_PROFILES: dict[str, CapabilityProfile] = {
    # --- Models that support provider-native tool calling. ---
    "llama3.1": _profile(
        "llama3.1",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    "llama3.2": _profile(
        "llama3.2",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    "llama3.3": _profile(
        "llama3.3",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    "qwen2.5": _profile(
        "qwen2.5",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="strict",
    ),
    "qwen2.5-coder": _profile(
        "qwen2.5-coder",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="strict",
    ),
    "devstral": _profile(
        "devstral",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
        notes=["Agentic coding model; well-suited to loader's tool loop."],
    ),
    "gpt-oss": _profile(
        "gpt-oss",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    "mistral": _profile(
        "mistral",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    "mixtral": _profile(
        "mixtral",
        supports_native_tools=True,
        preferred_tool_call_format="native",
        verification_strictness="standard",
    ),
    # --- Models routed through prompted (non-native) tool-call formats. ---
    "codestral": _profile(
        "codestral",
        supports_native_tools=False,
        preferred_tool_call_format="json_tag",
        verification_strictness="strict",
        notes=["Use ReAct-style prompting for tool calls."],
    ),
    "deepseek-coder": _profile(
        "deepseek-coder",
        supports_native_tools=False,
        preferred_tool_call_format="json_tag",
        verification_strictness="strict",
        notes=["Better with extracted JSON-tag tool calls than native tools."],
    ),
    "deepseek-r1": _profile(
        "deepseek-r1",
        supports_native_tools=False,
        preferred_tool_call_format="json_tag",
        verification_strictness="strict",
        notes=["Reasoning-oriented family; filter think blocks and use ReAct."],
    ),
    "phi3": _profile(
        "phi3",
        supports_native_tools=False,
        preferred_tool_call_format="bracket",
        verification_strictness="lax",
    ),
    "gemma2": _profile(
        "gemma2",
        supports_native_tools=False,
        preferred_tool_call_format="bracket",
        verification_strictness="lax",
    ),
}
148
# Family-name prefixes assumed to support native tool calling. Matching is
# prefix-based (see _any_prefix_match), so "llama3" also covers "llama3.1:8b".
NATIVE_TOOL_FAMILIES = {
    "llama3", "qwen2", "qwen2.5", "qwen3", "mistral", "mixtral",
    "command-r", "granite", "devstral", "gemma4", "gpt-oss",
}
# Family-name prefixes assumed NOT to support native tool calling; these fall
# back to prompted (json_tag) tool use in resolve_capability_profile().
# NOTE(review): NATIVE_TOOL_FAMILIES is checked first, so e.g. "gemma4" wins
# over the broader "gemma" prefix listed here.
NO_TOOL_FAMILIES = {
    "llama2",
    "phi",
    "phi3",
    "gemma",
    "gemma2",
    "tinyllama",
    "codestral",
    "deepseek-coder",
    "deepseek-r1",
    "starcoder",
    "codegemma",
}
166
167
168 def _family_tokens(model_name: str, model_details: dict[str, Any] | None) -> set[str]:
169 """Collect lowercase family tokens from model name and model details."""
170
171 name = model_name.lower()
172 tokens = {name}
173 # Strip :tag so "devstral:24b" also produces "devstral"
174 if ":" in name:
175 tokens.add(name.split(":")[0])
176 if model_details:
177 details = model_details.get("details", model_details)
178 families = details.get("families", [])
179 family = details.get("family")
180 if isinstance(families, list):
181 tokens.update(str(token).lower() for token in families)
182 if family:
183 tokens.add(str(family).lower())
184 return tokens
185
186
187 def _any_prefix_match(tokens: set[str], family_set: set[str]) -> bool:
188 """Check if any family entry is a prefix of any token."""
189 for token in tokens:
190 for family in family_set:
191 if token.startswith(family):
192 return True
193 return False
194
195
196 def _coerce_positive_int(value: Any) -> int | None:
197 """Return one positive integer when the input looks numeric."""
198
199 try:
200 number = int(value)
201 except (TypeError, ValueError):
202 return None
203 if number <= 0:
204 return None
205 return number
206
207
def _infer_context_window(model_details: dict[str, Any] | None) -> int | None:
    """Infer one model context window from Ollama model metadata.

    Considers ``details.context_length`` plus every ``model_info`` key that
    ends in ".context_length"; returns the largest positive value found, or
    None when the metadata yields nothing usable.
    """

    if not isinstance(model_details, dict):
        return None

    found: list[int] = []

    details = model_details.get("details")
    if isinstance(details, dict):
        direct = _coerce_positive_int(details.get("context_length"))
        if direct is not None:
            found.append(direct)

    model_info = model_details.get("model_info")
    if isinstance(model_info, dict):
        found.extend(
            ctx
            for key, raw in model_info.items()
            if str(key).endswith(".context_length")
            and (ctx := _coerce_positive_int(raw)) is not None
        )

    return max(found) if found else None
231
232
def resolve_capability_profile(
    model_name: str,
    *,
    override: CapabilityProfile | None = None,
    model_details: dict[str, Any] | None = None,
) -> CapabilityProfile:
    """Resolve the capability profile for a model.

    Resolution order:
    1. explicit override
    2. exact-name match in the built-in registry
    3. heuristic fallback using model details / family tokens

    A context window inferred from ``model_details`` takes precedence over
    the static ``context_window`` on an override or registry entry.
    """

    inferred_context_window = _infer_context_window(model_details)

    if override is not None:
        if inferred_context_window is None:
            return override
        # dataclasses.replace keeps every other field in sync automatically
        # (the previous field-by-field copy would silently drop data if
        # CapabilityProfile ever grew a field); notes is copied explicitly so
        # the caller cannot mutate the override's list through the result.
        return replace(
            override,
            context_window=inferred_context_window,
            notes=list(override.notes),
        )

    normalized = model_name.lower().strip()
    # Try full name first, then without :tag (e.g. "deepseek-r1:14b" -> "deepseek-r1")
    for key in (normalized, normalized.split(":")[0]):
        known = KNOWN_CAPABILITY_PROFILES.get(key)
        if known is not None:
            # Preserve the caller's exact model_name (with tag) and prefer the
            # inferred context window; copy notes to protect the registry entry.
            return replace(
                known,
                model_name=model_name,
                context_window=inferred_context_window or known.context_window,
                notes=list(known.notes),
            )

    tokens = _family_tokens(normalized, model_details)

    if _any_prefix_match(tokens, NATIVE_TOOL_FAMILIES):
        return _profile(
            model_name,
            supports_native_tools=True,
            context_window=inferred_context_window or 8192,
            preferred_tool_call_format="native",
            verification_strictness="standard",
            notes=["Resolved from model family heuristic."],
        )

    if _any_prefix_match(tokens, NO_TOOL_FAMILIES):
        return _profile(
            model_name,
            supports_native_tools=False,
            context_window=inferred_context_window or 8192,
            preferred_tool_call_format="json_tag",
            verification_strictness="standard",
            notes=["Resolved from conservative no-native-tools heuristic."],
        )

    # Unknown family: conservative default — no native tools, prompted JSON.
    return _profile(
        model_name,
        supports_native_tools=False,
        context_window=inferred_context_window or 8192,
        preferred_tool_call_format="json_tag",
        verification_strictness="standard",
        notes=["Unknown model family; defaulting to safe ReAct-style tool use."],
    )
307
308
def resolve_backend_capability_profile(backend: Any) -> CapabilityProfile:
    """Resolve capabilities from the backend first, then fall back to model heuristics."""

    # Best case: the backend can describe its capabilities completely.
    profile_fn = getattr(backend, "capability_profile", None)
    if callable(profile_fn):
        candidate = profile_fn()
        if isinstance(candidate, CapabilityProfile):
            return candidate

    # Fall back to the backend's model name (or its class name if absent).
    model_name = getattr(backend, "model", backend.__class__.__name__)

    # Next best: the backend only reports whether native tools work.
    native_fn = getattr(backend, "supports_native_tools", None)
    if callable(native_fn):
        has_native = bool(native_fn())
        fmt: ToolCallFormat = "native" if has_native else "json_tag"
        return _profile(
            model_name,
            supports_native_tools=has_native,
            preferred_tool_call_format=fmt,
            verification_strictness="standard",
            notes=["Resolved from backend capability surface."],
        )

    # Otherwise resolve purely from the model name / family heuristics.
    return resolve_capability_profile(model_name)