tenseleyflow/loader / 02dd95a

Browse files

Teach Ollama and model switching about capability profiles

Authored by espadonne
SHA
02dd95a64f6a3963eb556bc0d321e231d176899f
Parents
3c5be85
Tree
bb34ce8

3 changed files

Status | File | + | -
M src/loader/llm/ollama.py 44 58
M src/loader/ui/app.py 2 0
M tests/test_capabilities.py 17 1
src/loader/llm/ollama.py — modified
@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator
55
 
66
 import httpx
77
 
8
+from ..runtime.capabilities import CapabilityProfile, resolve_capability_profile
89
 from .base import (
910
     CompletionResponse,
1011
     LLMBackend,
@@ -18,40 +19,6 @@ from .base import (
1819
 class OllamaBackend(LLMBackend):
1920
     """Ollama API backend for local LLM inference."""
2021
 
21
-    # Models known to support native function calling in Ollama
22
-    # Verified working with Ollama's tool calling API
23
-    NATIVE_TOOL_MODELS = {
24
-        "llama3.1", "llama3.2", "llama3.3",
25
-        "qwen2.5", "qwen2",
26
-        "mistral", "mixtral",
27
-        "command-r",
28
-        "granite",
29
-        # Note: deepseek-coder, codestral, starcoder do NOT support tools in Ollama
30
-    }
31
-
32
-    # Models that definitely do NOT support native tools (use ReAct)
33
-    NO_TOOL_MODELS = {
34
-        "llama2", "llama:latest",  # Base llama without version
35
-        "phi", "phi3",
36
-        "gemma", "gemma2",
37
-        "tinyllama",
38
-        "orca",
39
-        "vicuna",
40
-        "wizard",
41
-        "neural-chat",
42
-        "starling",
43
-        "openchat",
44
-        "yi",
45
-        "solar",
46
-        "dolphin",
47
-        # Coding models that don't support Ollama tools
48
-        "codestral",
49
-        "deepseek-coder",
50
-        "starcoder",
51
-        "codegemma",
52
-        "deepseek-r1",  # Reasoning model, no tools
53
-    }
54
-
5522
     def __init__(
5623
         self,
5724
         model: str = "llama3.1:8b",
@@ -72,6 +39,8 @@ class OllamaBackend(LLMBackend):
7239
         self.num_gpu = num_gpu
7340
         self._client = httpx.AsyncClient(timeout=timeout)
7441
         self._supports_native_tools: bool | None = None
42
+        self._model_details_cache: dict[str, Any] | None = None
43
+        self._capability_profile: CapabilityProfile | None = None
7544
 
7645
     def _build_options(self, temperature: float, max_tokens: int) -> dict:
7746
         """Build Ollama options dict with performance settings."""
@@ -136,6 +105,40 @@ class OllamaBackend(LLMBackend):
136105
         except Exception:
137106
             return []
138107
 
108
+    async def describe_model(self) -> dict[str, Any] | None:
109
+        """Fetch and cache Ollama model details for capability resolution."""
110
+
111
+        if self._model_details_cache is not None:
112
+            return self._model_details_cache
113
+
114
+        if not self.model:
115
+            return None
116
+
117
+        try:
118
+            response = await self._client.post(
119
+                f"{self.base_url}/api/show",
120
+                json={"name": self.model},
121
+            )
122
+            response.raise_for_status()
123
+            self._model_details_cache = response.json()
124
+        except Exception:
125
+            self._model_details_cache = None
126
+
127
+        return self._model_details_cache
128
+
129
+    def capability_profile(self) -> CapabilityProfile:
130
+        """Return the resolved capability profile for the current model."""
131
+
132
+        if (
133
+            self._capability_profile is None
134
+            or self._capability_profile.model_name != self.model
135
+        ):
136
+            self._capability_profile = resolve_capability_profile(
137
+                self.model,
138
+                model_details=self._model_details_cache,
139
+            )
140
+        return self._capability_profile
141
+
139142
     def supports_native_tools(self) -> bool:
140143
         """Check if current model supports native function calling.
141144
 
@@ -145,36 +148,19 @@ class OllamaBackend(LLMBackend):
145148
         if self.force_react:
146149
             return False
147150
 
151
+        if self._capability_profile is not None and self._capability_profile.model_name != self.model:
152
+            self._capability_profile = None
153
+            self._supports_native_tools = None
154
+
148155
         if self._supports_native_tools is not None:
149156
             return self._supports_native_tools
150157
 
151
-        model_lower = self.model.lower()
152
-
153
-        # First check if it's explicitly a NO_TOOL model
154
-        for no_tool_model in self.NO_TOOL_MODELS:
155
-            if no_tool_model in model_lower:
156
-                self._supports_native_tools = False
157
-                return False
158
-
159
-        # Check if model name contains any known native tool model
160
-        for native_model in self.NATIVE_TOOL_MODELS:
161
-            if native_model in model_lower:
162
-                self._supports_native_tools = True
163
-                return True
164
-
165
-        # Default to False for unknown models (safer - uses ReAct)
166
-        self._supports_native_tools = False
167
-        return False
158
+        self._supports_native_tools = self.capability_profile().supports_native_tools
159
+        return self._supports_native_tools
168160
 
169161
     def _format_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
170162
         """Format messages for Ollama API."""
171
-        formatted = []
172
-        for msg in messages:
173
-            formatted.append({
174
-                "role": msg.role.value,
175
-                "content": msg.content,
176
-            })
177
-        return formatted
163
+        return [message.to_dict() for message in messages]
178164
 
179165
     def _format_tools(self, tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
180166
         """Format tools for Ollama API."""
src/loader/ui/app.py — modified
@@ -270,6 +270,8 @@ class LoaderApp(App):
270270
         if hasattr(self.agent.backend, "model"):
271271
             old_model = self.agent.backend.model
272272
             self.agent.backend.model = model_name
273
+            if hasattr(self.agent, "refresh_capability_profile"):
274
+                self.agent.refresh_capability_profile()
273275
             self.model_name = model_name
274276
             # Update status line
275277
             self.query_one(StatusLine).model = model_name
tests/test_capabilities.py — modified
@@ -1,6 +1,10 @@
11
 """Tests for runtime capability profile resolution."""
22
 
3
-from loader.runtime.capabilities import CapabilityProfile, resolve_capability_profile
3
+from loader.runtime.capabilities import (
4
+    CapabilityProfile,
5
+    resolve_backend_capability_profile,
6
+    resolve_capability_profile,
7
+)
48
 
59
 
610
 def test_explicit_override_wins() -> None:
@@ -43,3 +47,15 @@ def test_unknown_models_default_to_safe_react_profile() -> None:
4347
     assert not resolved.supports_native_tools
4448
     assert resolved.preferred_tool_call_format == "json_tag"
4549
     assert "defaulting to safe" in resolved.notes[0].lower()
50
+
51
+
52
+def test_backend_capability_profile_prefers_explicit_backend_surface() -> None:
53
+    class DummyBackend:
54
+        def supports_native_tools(self) -> bool:
55
+            return True
56
+
57
+    resolved = resolve_backend_capability_profile(DummyBackend())
58
+
59
+    assert resolved.supports_native_tools
60
+    assert resolved.preferred_tool_call_format == "native"
61
+    assert "backend capability surface" in resolved.notes[0].lower()