tenseleyflow/loader / 00aec36


Probe tool support at runtime temperature with 3/3 pass requirement to catch unreliable models

Authored by espadonne
SHA: 00aec36944fbdd20eaccd9f87832bc0b25cdc5e9
Parents: c9162b0
Tree: 52d2e2b

2 changed files

Status  File                       +   -
M       src/loader/cli/main.py     8   3
M       src/loader/llm/ollama.py  53  37
src/loader/cli/main.py (modified)
@@ -338,10 +338,15 @@ async def _main(

     await llm.describe_model()

-    # Probe the model's actual tool calling behavior (not just family heuristics)
+    # Probe the model's actual tool calling behavior at runtime temperature.
+    # Runs 3 rounds — model must pass all 3 to be considered native.
+    # This catches models like devstral that work at temp=0 but are
+    # unreliable at the actual runtime temperature.
     if not react and hasattr(llm, "probe_native_tool_support"):
-        console.print("[dim]Probing tool support...[/dim]", end="")
-        native = await llm.probe_native_tool_support()
+        from .agent.loop import AgentConfig as _ProbeCfg
+        probe_temp = _ProbeCfg.temperature
+        console.print(f"[dim]Probing tool support (temp={probe_temp}, 3 rounds)...[/dim]", end="")
+        native = await llm.probe_native_tool_support(temperature=probe_temp)
         console.print(f" [dim]{'native' if native else 'react'}[/dim]")
     mode_str = "ReAct" if react or not llm.supports_native_tools() else "Native"

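For context, the new keyword argument makes the probe easy to exercise outside the CLI. A minimal sketch, assuming OllamaBackend can be constructed from just a model name and imported from the installed package (the constructor is not shown in this diff, so that part is illustrative):

import asyncio
from loader.llm.ollama import OllamaBackend

async def check(model: str) -> None:
    # Hypothetical construction; the real constructor may take more arguments.
    llm = OllamaBackend(model=model)
    # Keyword arguments match the signature introduced in this commit.
    native = await llm.probe_native_tool_support(
        temperature=0.3, rounds=3, required_passes=3,
    )
    print(f"{model}: {'native' if native else 'react'}")

asyncio.run(check("devstral"))
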
src/loader/llm/ollama.py (modified)
@@ -202,13 +202,18 @@ class OllamaBackend(LLMBackend):
         self._supports_native_tools = self.capability_profile().supports_native_tools
         return self._supports_native_tools

-    async def probe_native_tool_support(self) -> bool:
-        """Send a minimal tool call to the model and check if it responds with
-        native ``tool_calls`` rather than text.  Caches the result so subsequent
-        calls are free.
-
-        This replaces relying solely on family-name heuristics — it tests the
-        model's *actual* behavior.
+    async def probe_native_tool_support(
+        self,
+        temperature: float = 0.3,
+        rounds: int = 3,
+        required_passes: int = 3,
+    ) -> bool:
+        """Probe whether the model *reliably* produces native tool calls.
+
+        Runs ``rounds`` probe calls at the actual runtime ``temperature``.
+        The model must produce ``tool_calls`` in at least ``required_passes``
+        of those rounds to be considered native.  This catches models like
+        devstral that pass at temp=0 but are unreliable at higher temps.
         """
         if self.force_react:
             self._supports_native_tools = False
@@ -226,36 +231,47 @@ class OllamaBackend(LLMBackend):
                 },
             },
         }]
-        payload = {
-            "model": self.model,
-            "messages": [{"role": "user", "content": "Call the probe tool with value OK"}],
-            "tools": probe_tool,
-            "stream": False,
-            "options": {"temperature": 0, "num_predict": 64, "num_ctx": 2048},
-        }
-        try:
-            response = await self._client.post(
-                f"{self.base_url}/api/chat", json=payload,
-            )
-            if response.status_code == 400:
-                error = response.json().get("error", "")
-                if "does not support tools" in error:
-                    self._supports_native_tools = False
-                    return False
-            response.raise_for_status()
-            data = response.json()
-            message = data.get("message", {})
-            has_tool_calls = bool(message.get("tool_calls"))
-            self._supports_native_tools = has_tool_calls
-            self._debug_log(
-                f"probe_native_tool_support: {has_tool_calls} "
-                f"(content_len={len(message.get('content', ''))})"
-            )
-            return has_tool_calls
-        except Exception:
-            # On any failure, fall back to heuristic
-            self._supports_native_tools = self.capability_profile().supports_native_tools
-            return self._supports_native_tools
+
+        passes = 0
+        for i in range(rounds):
+            payload = {
+                "model": self.model,
+                "messages": [
+                    {"role": "user", "content": "Call the probe tool with value OK"},
+                ],
+                "tools": probe_tool,
+                "stream": False,
+                "options": {
+                    "temperature": temperature,
+                    "num_predict": 64,
+                    "num_ctx": 2048,
+                },
+            }
+            try:
+                response = await self._client.post(
+                    f"{self.base_url}/api/chat", json=payload,
+                )
+                if response.status_code == 400:
+                    error = response.json().get("error", "")
+                    if "does not support tools" in error:
+                        self._supports_native_tools = False
+                        self._debug_log("probe: model rejected tools (400)")
+                        return False
+                response.raise_for_status()
+                data = response.json()
+                message = data.get("message", {})
+                if message.get("tool_calls"):
+                    passes += 1
+            except Exception:
+                pass  # treat failures as non-passes
+
+        native = passes >= required_passes
+        self._supports_native_tools = native
+        self._debug_log(
+            f"probe_native_tool_support: {passes}/{rounds} passed "
+            f"(need {required_passes}) → {'native' if native else 'react'}"
+        )
+        return native

     def _format_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
         """Format messages for Ollama API.