tenseleyflow/loader / 53d9e59

Browse files

Refresh runtime capabilities before each turn

Authored by espadonne
SHA
53d9e5951ef2972a8cf0f8cb103c99085c797bad
Parents
ea386ab
Tree
348f885

4 changed files

Status  File  +  -
M src/loader/agent/loop.py 55 114
M src/loader/cli/main.py 16 9
M src/loader/llm/ollama.py 32 11
M src/loader/runtime/conversation.py 19 0
src/loader/agent/loop.py (modified)
@@ -1,54 +1,47 @@
11
 """The main agent loop."""
22
 
33
 import asyncio
4
+import contextlib
5
+from collections.abc import AsyncIterator, Awaitable, Callable
46
 from dataclasses import dataclass
57
 from pathlib import Path
6
-from typing import AsyncIterator, Awaitable, Callable
78
 
8
-from ..llm.base import LLMBackend, Message, Role, ToolCall
9
-from ..tools.base import ToolRegistry, create_default_registry, ConfirmationRequired
109
 from ..context.project import ProjectContext, detect_project
10
+from ..llm.base import LLMBackend, Message, Role, ToolCall
1111
 from ..runtime.capabilities import resolve_backend_capability_profile
1212
 from ..runtime.conversation import ConversationRuntime
1313
 from ..runtime.events import AgentEvent, TurnSummary
1414
 from ..runtime.session import ConversationSession
15
+from ..tools.base import ToolRegistry, create_default_registry
16
+from .planner import (
17
+    PLANNING_PROMPT,
18
+    SHOULD_PLAN_PROMPT,
19
+    Plan,
20
+    format_step_prompt,
21
+    parse_plan,
22
+    should_plan,
23
+)
1524
 from .prompts import build_system_prompt
16
-from .parsing import parse_tool_calls, format_tool_result
17
-from .planner import Plan, parse_plan, should_plan, format_step_prompt, PLANNING_PROMPT, SHOULD_PLAN_PROMPT
18
-from .recovery import RecoveryContext, format_recovery_prompt, format_failure_message
1925
 from .reasoning import (
20
-    TaskDecomposition,
21
-    Subtask,
22
-    SelfCritique,
23
-    ConfidenceAssessment,
24
-    ActionVerification,
25
-    ConfidenceLevel,
26
-    TaskCompletionCheck,
27
-    RollbackPlan,
28
-    RollbackAction,
29
-    RollbackType,
26
+    CONFIDENCE_PROMPT,
3027
     DECOMPOSITION_PROMPT,
3128
     SELF_CRITIQUE_PROMPT,
32
-    CONFIDENCE_PROMPT,
3329
     VERIFICATION_PROMPT,
34
-    COMPLETION_CHECK_PROMPT,
30
+    ActionVerification,
31
+    ConfidenceAssessment,
32
+    ConfidenceLevel,
33
+    SelfCritique,
34
+    TaskDecomposition,
35
+    estimate_confidence_quick,
36
+    is_conversational,
37
+    parse_confidence,
3538
     parse_decomposition,
3639
     parse_self_critique,
37
-    parse_confidence,
3840
     parse_verification,
39
-    parse_completion_check,
40
-    should_decompose,
41
-    should_self_critique,
42
-    estimate_confidence_quick,
4341
     quick_verify,
44
-    detect_premature_completion,
45
-    get_continuation_prompt,
46
-    is_destructive_tool,
47
-    create_rollback_plan_for_action,
48
-    is_conversational,
49
-    estimate_complexity,
50
-    get_token_budget,
42
+    should_decompose,
5143
 )
44
+from .recovery import RecoveryContext
5245
 from .safeguards import RuntimeSafeguards
5346
 
5447
 
@@ -209,8 +202,11 @@ class Agent:
209202
 
210203
     def refresh_capability_profile(self) -> None:
211204
         """Refresh the runtime capability profile from the current backend."""
212
-
213
-        self.capability_profile = resolve_backend_capability_profile(self.backend)
205
+        previous_profile = self.capability_profile
206
+        refreshed_profile = resolve_backend_capability_profile(self.backend)
207
+        self.capability_profile = refreshed_profile
208
+        if refreshed_profile != previous_profile:
209
+            self._system_message = None
214210
         self._use_react = None
215211
 
216212
     def _get_few_shot_examples(self) -> list[Message]:
@@ -552,7 +548,7 @@ class Agent:
552548
 
553549
                     # Run the step
554550
                     step_prompt = format_step_prompt(plan, step)
555
-                    step_response = await self._run_inner(
551
+                    await self._run_inner(
556552
                         step_prompt, emit, on_confirmation,
557553
                         original_task=self._current_task,
558554
                     )
@@ -596,91 +592,36 @@ class Agent:
596592
         self,
597593
         user_message: str,
598594
     ) -> AsyncIterator[AgentEvent]:
599
-        """Run the agent with streaming output."""
600
-        # Add user message
601
-        self.messages.append(Message(role=Role.USER, content=user_message))
595
+        """Run the agent with streaming output from the primary runtime path."""
602596
 
603
-        iterations = 0
604
-        tools = None if self.use_react else self.registry.get_schemas()
597
+        queue: asyncio.Queue[AgentEvent | BaseException | None] = asyncio.Queue()
605598
 
606
-        while iterations < self.config.max_iterations:
607
-            iterations += 1
599
+        async def on_event(event: AgentEvent) -> None:
600
+            await queue.put(event)
608601
 
609
-            yield AgentEvent(type="thinking")
610
-
611
-            # Stream the response
612
-            full_content = ""
613
-            tool_calls: list[ToolCall] = []
614
-
615
-            async for chunk in self.backend.stream(
616
-                messages=self._build_messages(),
617
-                tools=tools,
618
-                temperature=self.config.temperature,
619
-                max_tokens=self.config.max_tokens,
620
-            ):
621
-                if chunk.content:
622
-                    full_content += chunk.content
623
-                    yield AgentEvent(type="response", content=chunk.content)
624
-
625
-                if chunk.tool_calls:
626
-                    tool_calls = chunk.tool_calls
627
-
628
-            # In ReAct mode, parse tool calls from text
629
-            if self.use_react:
630
-                parsed = parse_tool_calls(full_content)
631
-                tool_calls = parsed.tool_calls
602
+        async def run_agent() -> None:
603
+            try:
604
+                await self.run(user_message, on_event=on_event)
605
+            except BaseException as exc:  # pragma: no cover - propagated below
606
+                await queue.put(exc)
607
+            finally:
608
+                await queue.put(None)
632609
 
633
-                if parsed.is_final_answer and not tool_calls:
634
-                    self.messages.append(Message(
635
-                        role=Role.ASSISTANT,
636
-                        content=full_content,
637
-                    ))
610
+        task = asyncio.create_task(run_agent())
611
+        try:
612
+            while True:
613
+                item = await queue.get()
614
+                if item is None:
638615
                     break
639
-
640
-            # If there are tool calls, execute them
641
-            if tool_calls:
642
-                self.messages.append(Message(
643
-                    role=Role.ASSISTANT,
644
-                    content=full_content,
645
-                    tool_calls=tool_calls,
646
-                ))
647
-
648
-                for tool_call in tool_calls:
649
-                    yield AgentEvent(
650
-                        type="tool_call",
651
-                        tool_name=tool_call.name,
652
-                        tool_args=tool_call.arguments,
653
-                    )
654
-
655
-                    result = await self.registry.execute(
656
-                        tool_call.name,
657
-                        **tool_call.arguments,
658
-                    )
659
-
660
-                    yield AgentEvent(
661
-                        type="tool_result",
662
-                        content=result.output,
663
-                        tool_name=tool_call.name,
664
-                    )
665
-
666
-                    result_text = format_tool_result(
667
-                        tool_call.name,
668
-                        result.output,
669
-                        result.is_error,
670
-                    )
671
-                    self.messages.append(Message(
672
-                        role=Role.TOOL,
673
-                        content=result_text,
674
-                    ))
675
-
676
-                continue
677
-
678
-            # No tool calls - done
679
-            self.messages.append(Message(
680
-                role=Role.ASSISTANT,
681
-                content=full_content,
682
-            ))
683
-            break
616
+                if isinstance(item, BaseException):
617
+                    raise item
618
+                yield item
619
+            await task
620
+        finally:
621
+            if not task.done():
622
+                task.cancel()
623
+                with contextlib.suppress(asyncio.CancelledError):
624
+                    await task
684625
 
685626
     def _contains_unexecuted_code(self, content: str) -> bool:
686627
         """Detect if response contains code blocks that should be tool calls.
@@ -781,9 +722,9 @@ class Agent:
781722
         instead of using the proper tool calling API. This method tries to
782723
         parse and recover them.
783724
         """
784
-        import re
785725
         import json
786726
         import os
727
+        import re
787728
 
788729
         tool_calls = []
789730
         tool_names = ["write", "read", "edit", "bash", "glob", "grep"]
src/loader/cli/main.py (modified)
@@ -2,6 +2,7 @@
22
 
33
 import asyncio
44
 import re
5
+
56
 import click
67
 import httpx
78
 from rich.console import Console
@@ -28,11 +29,12 @@ async def select_model_interactive() -> str | None:
2829
     Returns:
2930
         Selected model name, or None if cancelled/no models.
3031
     """
31
-    from ..llm.ollama import OllamaBackend
32
-    from ..config import get_last_model
3332
     from prompt_toolkit import PromptSession
3433
     from prompt_toolkit.completion import WordCompleter
3534
 
35
+    from ..config import get_last_model
36
+    from ..llm.ollama import OllamaBackend
37
+
3638
     # Create a temporary client to list models
3739
     backend = OllamaBackend(model="")
3840
     models = await backend.list_models()
@@ -194,10 +196,10 @@ async def _main(
194196
     reason: bool,
195197
     prompt: str | None,
196198
 ) -> None:
197
-    from ..llm.ollama import OllamaBackend
198199
     from ..agent.loop import Agent, AgentConfig, ReasoningConfig
200
+    from ..config import get_default_model, get_last_model, set_last_model
201
+    from ..llm.ollama import OllamaBackend
199202
     from ..tools.base import create_default_registry
200
-    from ..config import get_default_model, set_last_model, get_last_model
201203
 
202204
     # Handle model selection
203205
     if select_model:
@@ -222,9 +224,6 @@ async def _main(
222224
         timeout=timeout,
223225
     )
224226
 
225
-    # Determine actual mode based on model capabilities (not just CLI flag)
226
-    mode_str = "ReAct" if react or not llm.supports_native_tools() else "Native"
227
-
228227
     # Check health
229228
     if not await llm.health_check():
230229
         console.print("[red]Error: Cannot connect to Ollama. Is it running?[/red]")
@@ -234,6 +233,11 @@ async def _main(
234233
         console.print("\nTry [cyan]loader --select-model[/cyan] to choose from available models.")
235234
         return
236235
 
236
+    await llm.describe_model()
237
+
238
+    # Determine actual mode based on resolved model capabilities (not just CLI flag)
239
+    mode_str = "ReAct" if react or not llm.supports_native_tools() else "Native"
240
+
237241
     # Save this model as the new default
238242
     set_last_model(model)
239243
 
@@ -333,9 +337,10 @@ def _format_tool_args(args: dict | None) -> str:
333337
 
334338
 async def run_once(agent, prompt: str, skip_confirmation: bool = False) -> None:
335339
     """Run a single prompt."""
336
-    from ..tools.base import ConfirmationRequired
337340
     import time
338341
 
342
+    from ..tools.base import ConfirmationRequired
343
+
339344
     thinking_start = None
340345
     streamed_response = False
341346
 
@@ -404,10 +409,12 @@ async def run_once(agent, prompt: str, skip_confirmation: bool = False) -> None:
404409
 
405410
 async def run_interactive(agent, skip_confirmation: bool = False) -> None:
406411
     """Run interactive chat loop."""
412
+    import os
413
+
407414
     from prompt_toolkit import PromptSession
408415
     from prompt_toolkit.history import FileHistory
416
+
409417
     from ..tools.base import ConfirmationRequired
410
-    import os
411418
 
412419
     history_file = os.path.expanduser("~/.loader_history")
413420
     session = PromptSession(history=FileHistory(history_file))
src/loader/llm/ollama.py (modified)
@@ -1,7 +1,8 @@
11
 """Ollama backend implementation."""
22
 
33
 import json
4
-from typing import Any, AsyncIterator
4
+from collections.abc import AsyncIterator
5
+from typing import Any
56
 
67
 import httpx
78
 
@@ -10,7 +11,6 @@ from .base import (
1011
     CompletionResponse,
1112
     LLMBackend,
1213
     Message,
13
-    Role,
1414
     StreamChunk,
1515
     ToolCall,
1616
 )
@@ -40,8 +40,26 @@ class OllamaBackend(LLMBackend):
4040
         self._client = httpx.AsyncClient(timeout=timeout)
4141
         self._supports_native_tools: bool | None = None
4242
         self._model_details_cache: dict[str, Any] | None = None
43
+        self._model_details_loaded_for: str | None = None
4344
         self._capability_profile: CapabilityProfile | None = None
4445
 
46
+    def _invalidate_model_caches_if_needed(self) -> None:
47
+        """Clear cached capability state when the active model changes."""
48
+
49
+        if (
50
+            self._capability_profile is not None
51
+            and self._capability_profile.model_name != self.model
52
+        ):
53
+            self._capability_profile = None
54
+            self._supports_native_tools = None
55
+
56
+        if (
57
+            self._model_details_loaded_for is not None
58
+            and self._model_details_loaded_for != self.model
59
+        ):
60
+            self._model_details_cache = None
61
+            self._model_details_loaded_for = None
62
+
4563
     def _build_options(self, temperature: float, max_tokens: int) -> dict:
4664
         """Build Ollama options dict with performance settings."""
4765
         return {
@@ -108,7 +126,9 @@ class OllamaBackend(LLMBackend):
108126
     async def describe_model(self) -> dict[str, Any] | None:
109127
         """Fetch and cache Ollama model details for capability resolution."""
110128
 
111
-        if self._model_details_cache is not None:
129
+        self._invalidate_model_caches_if_needed()
130
+
131
+        if self._model_details_loaded_for == self.model:
112132
             return self._model_details_cache
113133
 
114134
         if not self.model:
@@ -121,18 +141,18 @@ class OllamaBackend(LLMBackend):
121141
             )
122142
             response.raise_for_status()
123143
             self._model_details_cache = response.json()
144
+            self._model_details_loaded_for = self.model
124145
         except Exception:
125146
             self._model_details_cache = None
147
+            self._model_details_loaded_for = self.model
126148
 
127149
         return self._model_details_cache
128150
 
129151
     def capability_profile(self) -> CapabilityProfile:
130152
         """Return the resolved capability profile for the current model."""
131153
 
132
-        if (
133
-            self._capability_profile is None
134
-            or self._capability_profile.model_name != self.model
135
-        ):
154
+        self._invalidate_model_caches_if_needed()
155
+        if self._capability_profile is None:
136156
             self._capability_profile = resolve_capability_profile(
137157
                 self.model,
138158
                 model_details=self._model_details_cache,
@@ -148,9 +168,7 @@ class OllamaBackend(LLMBackend):
148168
         if self.force_react:
149169
             return False
150170
 
151
-        if self._capability_profile is not None and self._capability_profile.model_name != self.model:
152
-            self._capability_profile = None
153
-            self._supports_native_tools = None
171
+        self._invalidate_model_caches_if_needed()
154172
 
155173
         if self._supports_native_tools is not None:
156174
             return self._supports_native_tools
@@ -228,6 +246,8 @@ class OllamaBackend(LLMBackend):
228246
         max_tokens: int = 4096,
229247
     ) -> CompletionResponse:
230248
         """Generate a completion using Ollama."""
249
+        await self.describe_model()
250
+
231251
         payload: dict[str, Any] = {
232252
             "model": self.model,
233253
             "messages": self._format_messages(messages),
@@ -305,6 +325,8 @@ class OllamaBackend(LLMBackend):
305325
         max_tokens: int = 4096,
306326
     ) -> AsyncIterator[StreamChunk]:
307327
         """Stream a completion from Ollama."""
328
+        await self.describe_model()
329
+
308330
         payload: dict[str, Any] = {
309331
             "model": self.model,
310332
             "messages": self._format_messages(messages),
@@ -365,7 +387,6 @@ class OllamaBackend(LLMBackend):
365387
 
366388
     async def _stream_response(self, response) -> AsyncIterator[StreamChunk]:
367389
         """Internal helper to stream response chunks."""
368
-        import re
369390
 
370391
         full_content = ""
371392
         display_content = ""  # Content to show (filtered)
src/loader/runtime/conversation.py (modified)
@@ -56,6 +56,8 @@ class ConversationRuntime:
5656
     ) -> TurnSummary:
5757
         """Run one task turn and return a structured summary."""
5858
 
59
+        await self._prepare_runtime_capabilities()
60
+
5961
         iterations = 0
6062
         final_response = ""
6163
         actions_taken: list[str] = []
@@ -690,6 +692,23 @@ class ConversationRuntime:
690692
         for key, value in update.items():
691693
             target[key] = target.get(key, 0) + value
692694
 
695
+    async def _prepare_runtime_capabilities(self) -> None:
696
+        describe_model = getattr(self.agent.backend, "describe_model", None)
697
+        if callable(describe_model):
698
+            await describe_model()
699
+
700
+        previous_profile = self.agent.capability_profile
701
+        self.agent.refresh_capability_profile()
702
+        if self.agent.capability_profile != previous_profile:
703
+            self.tracer.record(
704
+                "runtime.capabilities_refreshed",
705
+                model_name=self.agent.capability_profile.model_name,
706
+                supports_native_tools=self.agent.capability_profile.supports_native_tools,
707
+                preferred_tool_call_format=(
708
+                    self.agent.capability_profile.preferred_tool_call_format
709
+                ),
710
+            )
711
+
693712
     @staticmethod
694713
     def _emit_confirmation(emit: EventSink):
695714
         async def _emit(tool_name: str, message: str, details: str) -> None: