| 1 | """Capability profile resolution for runtime behavior.""" |
| 2 | |
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Any, Literal, Protocol
| 7 | |
# Tool-call formats a model can be driven with; non-"native" formats are
# extracted from prompted output (see the registry notes below).
ToolCallFormat = Literal["native", "json_tag", "bracket"]
# Verification strictness levels carried by capability profiles.
VerificationStrictness = Literal["lax", "standard", "strict"]
| 10 | |
| 11 | |
class SupportsCapabilityProfile(Protocol):
    """Runtime interface for backends that can describe capabilities.

    Implementers return a fully-populated CapabilityProfile describing the
    model they wrap; `resolve_backend_capability_profile` checks for this
    surface first.
    """

    def capability_profile(self) -> CapabilityProfile: ...
| 16 | |
| 17 | |
class SupportsNativeTools(Protocol):
    """Runtime interface for backends that can explicitly report tool support.

    A weaker surface than SupportsCapabilityProfile: only answers whether
    the backend supports native tool calling.
    """

    def supports_native_tools(self) -> bool: ...
| 22 | |
| 23 | |
@dataclass(frozen=True)
class CapabilityProfile:
    """Resolved model/runtime capability profile.

    Frozen so resolved profiles can be shared safely.  ``notes`` is still a
    mutable list, so producers in this module copy it instead of sharing
    one instance across profiles.
    """

    model_name: str  # model name as supplied by the caller/backend
    supports_native_tools: bool  # backend/model can use a native tool-calling API
    supports_streaming: bool  # whether streaming responses are supported
    context_window: int  # context window size (default 8192 via _profile)
    preferred_tool_call_format: ToolCallFormat  # how tool calls should be requested
    verification_strictness: VerificationStrictness  # "lax" | "standard" | "strict"
    notes: list[str] = field(default_factory=list)  # free-form resolution notes
| 35 | |
| 36 | |
def _profile(
    model_name: str,
    *,
    supports_native_tools: bool,
    supports_streaming: bool = True,
    context_window: int = 8192,
    preferred_tool_call_format: ToolCallFormat = "native",
    verification_strictness: VerificationStrictness = "standard",
    notes: list[str] | None = None,
) -> CapabilityProfile:
    """Build a CapabilityProfile with registry-friendly defaults.

    The notes sequence is always copied so no two profiles ever share a
    mutable list.
    """

    resolved_notes = [] if notes is None else list(notes)
    return CapabilityProfile(
        model_name,
        supports_native_tools,
        supports_streaming,
        context_window,
        preferred_tool_call_format,
        verification_strictness,
        resolved_notes,
    )
| 56 | |
| 57 | |
# Compact registry rows: (name, native_tools, tool_format, strictness, notes).
# Expanded into full profiles below; context window and streaming keep the
# _profile() defaults (8192 / True).
_KNOWN_PROFILE_SPECS: tuple[
    tuple[str, bool, ToolCallFormat, VerificationStrictness, list[str]], ...
] = (
    ("llama3.1", True, "native", "standard", []),
    ("llama3.2", True, "native", "standard", []),
    ("llama3.3", True, "native", "standard", []),
    ("qwen2.5", True, "native", "strict", []),
    ("qwen2.5-coder", True, "native", "strict", []),
    (
        "devstral",
        True,
        "native",
        "standard",
        ["Agentic coding model; well-suited to loader's tool loop."],
    ),
    ("gpt-oss", True, "native", "standard", []),
    ("mistral", True, "native", "standard", []),
    ("mixtral", True, "native", "standard", []),
    (
        "codestral",
        False,
        "json_tag",
        "strict",
        ["Use ReAct-style prompting for tool calls."],
    ),
    (
        "deepseek-coder",
        False,
        "json_tag",
        "strict",
        ["Better with extracted JSON-tag tool calls than native tools."],
    ),
    (
        "deepseek-r1",
        False,
        "json_tag",
        "strict",
        ["Reasoning-oriented family; filter think blocks and use ReAct."],
    ),
    ("phi3", False, "bracket", "lax", []),
    ("gemma2", False, "bracket", "lax", []),
)

# Exact-name registry consulted before any family heuristic.
KNOWN_CAPABILITY_PROFILES: dict[str, CapabilityProfile] = {
    name: _profile(
        name,
        supports_native_tools=native,
        preferred_tool_call_format=tool_format,
        verification_strictness=strictness,
        notes=notes,
    )
    for name, native, tool_format, strictness, notes in _KNOWN_PROFILE_SPECS
}
| 148 | |
# Family prefixes matched (startswith) against tokens from _family_tokens.
# Families known to handle native tool-calling APIs:
NATIVE_TOOL_FAMILIES = {
    "llama3",
    "qwen2",
    "qwen2.5",
    "qwen3",
    "mistral",
    "mixtral",
    "command-r",
    "granite",
    "devstral",
    "gemma4",
    "gpt-oss",
}
# Families known NOT to support native tool calling:
NO_TOOL_FAMILIES = {
    "llama2", "phi", "phi3", "gemma", "gemma2", "tinyllama", "codestral",
    "deepseek-coder", "deepseek-r1", "starcoder", "codegemma",
}
| 166 | |
| 167 | |
| 168 | def _family_tokens(model_name: str, model_details: dict[str, Any] | None) -> set[str]: |
| 169 | """Collect lowercase family tokens from model name and model details.""" |
| 170 | |
| 171 | name = model_name.lower() |
| 172 | tokens = {name} |
| 173 | # Strip :tag so "devstral:24b" also produces "devstral" |
| 174 | if ":" in name: |
| 175 | tokens.add(name.split(":")[0]) |
| 176 | if model_details: |
| 177 | details = model_details.get("details", model_details) |
| 178 | families = details.get("families", []) |
| 179 | family = details.get("family") |
| 180 | if isinstance(families, list): |
| 181 | tokens.update(str(token).lower() for token in families) |
| 182 | if family: |
| 183 | tokens.add(str(family).lower()) |
| 184 | return tokens |
| 185 | |
| 186 | |
| 187 | def _any_prefix_match(tokens: set[str], family_set: set[str]) -> bool: |
| 188 | """Check if any family entry is a prefix of any token.""" |
| 189 | for token in tokens: |
| 190 | for family in family_set: |
| 191 | if token.startswith(family): |
| 192 | return True |
| 193 | return False |
| 194 | |
| 195 | |
| 196 | def _coerce_positive_int(value: Any) -> int | None: |
| 197 | """Return one positive integer when the input looks numeric.""" |
| 198 | |
| 199 | try: |
| 200 | number = int(value) |
| 201 | except (TypeError, ValueError): |
| 202 | return None |
| 203 | if number <= 0: |
| 204 | return None |
| 205 | return number |
| 206 | |
| 207 | |
def _infer_context_window(model_details: dict[str, Any] | None) -> int | None:
    """Infer one model context window from Ollama model metadata.

    Considers ``details.context_length`` plus every ``model_info`` key
    ending in ".context_length", returning the largest positive value
    found, or None when the metadata yields nothing usable.
    """

    if not isinstance(model_details, dict):
        return None

    candidates: list[int] = []

    details = model_details.get("details")
    if isinstance(details, dict):
        top_level = _coerce_positive_int(details.get("context_length"))
        if top_level is not None:
            candidates.append(top_level)

    model_info = model_details.get("model_info")
    if isinstance(model_info, dict):
        # Keys are namespaced, e.g. "<arch>.context_length".
        candidates.extend(
            parsed
            for key, raw in model_info.items()
            if str(key).endswith(".context_length")
            and (parsed := _coerce_positive_int(raw)) is not None
        )

    return max(candidates, default=None)
| 231 | |
| 232 | |
def resolve_capability_profile(
    model_name: str,
    *,
    override: CapabilityProfile | None = None,
    model_details: dict[str, Any] | None = None,
) -> CapabilityProfile:
    """Resolve the capability profile for a model.

    Resolution order:
      1. explicit override
      2. exact-name match in the built-in registry
      3. heuristic fallback using model details / family tokens

    Args:
        model_name: Raw model name, possibly carrying an Ollama ":tag".
        override: Caller-supplied profile that short-circuits resolution;
            its context window is still refreshed from live metadata when
            one can be inferred.
        model_details: Optional Ollama metadata used to infer the context
            window and family tokens.

    Returns:
        A CapabilityProfile whose notes list is never shared with the
        registry or the override.
    """

    inferred_context_window = _infer_context_window(model_details)

    if override is not None:
        if inferred_context_window is None:
            return override
        # Keep the caller's choices but trust live metadata for the window;
        # copy notes so frozen profiles never share a mutable list.
        return replace(
            override,
            context_window=inferred_context_window,
            notes=list(override.notes),
        )

    normalized = model_name.lower().strip()
    # Try full name first, then without :tag (e.g. "deepseek-r1:14b" -> "deepseek-r1")
    for key in (normalized, normalized.split(":")[0]):
        known = KNOWN_CAPABILITY_PROFILES.get(key)
        if known is not None:
            # Re-stamp with the caller's original (un-normalized) name.
            return replace(
                known,
                model_name=model_name,
                context_window=inferred_context_window or known.context_window,
                notes=list(known.notes),
            )

    tokens = _family_tokens(normalized, model_details)

    if _any_prefix_match(tokens, NATIVE_TOOL_FAMILIES):
        return _profile(
            model_name,
            supports_native_tools=True,
            context_window=inferred_context_window or 8192,
            preferred_tool_call_format="native",
            verification_strictness="standard",
            notes=["Resolved from model family heuristic."],
        )

    # No native-tool family recognized: both remaining cases fall back to
    # ReAct-style JSON-tag tool use and differ only in the recorded note.
    if _any_prefix_match(tokens, NO_TOOL_FAMILIES):
        note = "Resolved from conservative no-native-tools heuristic."
    else:
        note = "Unknown model family; defaulting to safe ReAct-style tool use."
    return _profile(
        model_name,
        supports_native_tools=False,
        context_window=inferred_context_window or 8192,
        preferred_tool_call_format="json_tag",
        verification_strictness="standard",
        notes=[note],
    )
| 307 | |
| 308 | |
def resolve_backend_capability_profile(backend: Any) -> CapabilityProfile:
    """Resolve capabilities from the backend first, then fall back to model heuristics.

    Preference order: a full ``capability_profile()`` result, then an
    explicit ``supports_native_tools()`` answer, then name-based resolution.
    """

    # 1) Backend exposes a complete profile (SupportsCapabilityProfile).
    profile_factory = getattr(backend, "capability_profile", None)
    if callable(profile_factory):
        candidate = profile_factory()
        if isinstance(candidate, CapabilityProfile):
            return candidate

    model_name = getattr(backend, "model", backend.__class__.__name__)

    # 2) Backend only reports tool support (SupportsNativeTools).
    tools_probe = getattr(backend, "supports_native_tools", None)
    if not callable(tools_probe):
        # 3) Nothing explicit: fall back to name/family heuristics.
        return resolve_capability_profile(model_name)

    has_native_tools = bool(tools_probe())
    tool_format: ToolCallFormat = "native" if has_native_tools else "json_tag"
    return _profile(
        model_name,
        supports_native_tools=has_native_tools,
        preferred_tool_call_format=tool_format,
        verification_strictness="standard",
        notes=["Resolved from backend capability surface."],
    )