Refresh runtime capabilities before each turn
- SHA: 53d9e5951ef2972a8cf0f8cb103c99085c797bad
- Parent: ea386ab
- Tree: 348f885

| Status | File | + | - |
|---|---|---|---|
| M | src/loader/agent/loop.py | 55 | 114 |
| M | src/loader/cli/main.py | 16 | 9 |
| M | src/loader/llm/ollama.py | 32 | 11 |
| M | src/loader/runtime/conversation.py | 19 | 0 |
**src/loader/agent/loop.py** (modified)

```diff
@@ -1,54 +1,47 @@
 """The main agent loop."""
 
 import asyncio
+import contextlib
+from collections.abc import AsyncIterator, Awaitable, Callable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import AsyncIterator, Awaitable, Callable
 
-from ..llm.base import LLMBackend, Message, Role, ToolCall
-from ..tools.base import ToolRegistry, create_default_registry, ConfirmationRequired
 from ..context.project import ProjectContext, detect_project
+from ..llm.base import LLMBackend, Message, Role, ToolCall
 from ..runtime.capabilities import resolve_backend_capability_profile
 from ..runtime.conversation import ConversationRuntime
 from ..runtime.events import AgentEvent, TurnSummary
 from ..runtime.session import ConversationSession
+from ..tools.base import ToolRegistry, create_default_registry
+from .planner import (
+    PLANNING_PROMPT,
+    SHOULD_PLAN_PROMPT,
+    Plan,
+    format_step_prompt,
+    parse_plan,
+    should_plan,
+)
 from .prompts import build_system_prompt
-from .parsing import parse_tool_calls, format_tool_result
-from .planner import Plan, parse_plan, should_plan, format_step_prompt, PLANNING_PROMPT, SHOULD_PLAN_PROMPT
-from .recovery import RecoveryContext, format_recovery_prompt, format_failure_message
 from .reasoning import (
-    TaskDecomposition,
-    Subtask,
-    SelfCritique,
-    ConfidenceAssessment,
-    ActionVerification,
-    ConfidenceLevel,
-    TaskCompletionCheck,
-    RollbackPlan,
-    RollbackAction,
-    RollbackType,
+    CONFIDENCE_PROMPT,
     DECOMPOSITION_PROMPT,
     SELF_CRITIQUE_PROMPT,
-    CONFIDENCE_PROMPT,
     VERIFICATION_PROMPT,
-    COMPLETION_CHECK_PROMPT,
+    ActionVerification,
+    ConfidenceAssessment,
+    ConfidenceLevel,
+    SelfCritique,
+    TaskDecomposition,
+    estimate_confidence_quick,
+    is_conversational,
+    parse_confidence,
     parse_decomposition,
     parse_self_critique,
-    parse_confidence,
     parse_verification,
-    parse_completion_check,
-    should_decompose,
-    should_self_critique,
-    estimate_confidence_quick,
     quick_verify,
-    detect_premature_completion,
-    get_continuation_prompt,
-    is_destructive_tool,
-    create_rollback_plan_for_action,
-    is_conversational,
-    estimate_complexity,
-    get_token_budget,
+    should_decompose,
 )
+from .recovery import RecoveryContext
 from .safeguards import RuntimeSafeguards
 
 
@@ -209,8 +202,11 @@ class Agent:
 
     def refresh_capability_profile(self) -> None:
         """Refresh the runtime capability profile from the current backend."""
-
-        self.capability_profile = resolve_backend_capability_profile(self.backend)
+        previous_profile = self.capability_profile
+        refreshed_profile = resolve_backend_capability_profile(self.backend)
+        self.capability_profile = refreshed_profile
+        if refreshed_profile != previous_profile:
+            self._system_message = None
         self._use_react = None
 
     def _get_few_shot_examples(self) -> list[Message]:
@@ -552,7 +548,7 @@ class Agent:
 
             # Run the step
             step_prompt = format_step_prompt(plan, step)
-            step_response = await self._run_inner(
+            await self._run_inner(
                 step_prompt, emit, on_confirmation,
                 original_task=self._current_task,
             )
@@ -596,91 +592,36 @@ class Agent:
         self,
         user_message: str,
     ) -> AsyncIterator[AgentEvent]:
-        """Run the agent with streaming output."""
-        # Add user message
-        self.messages.append(Message(role=Role.USER, content=user_message))
+        """Run the agent with streaming output from the primary runtime path."""
 
-        iterations = 0
-        tools = None if self.use_react else self.registry.get_schemas()
+        queue: asyncio.Queue[AgentEvent | BaseException | None] = asyncio.Queue()
 
-        while iterations < self.config.max_iterations:
-            iterations += 1
+        async def on_event(event: AgentEvent) -> None:
+            await queue.put(event)
 
-            yield AgentEvent(type="thinking")
-
-            # Stream the response
-            full_content = ""
-            tool_calls: list[ToolCall] = []
-
-            async for chunk in self.backend.stream(
-                messages=self._build_messages(),
-                tools=tools,
-                temperature=self.config.temperature,
-                max_tokens=self.config.max_tokens,
-            ):
-                if chunk.content:
-                    full_content += chunk.content
-                    yield AgentEvent(type="response", content=chunk.content)
-
-                if chunk.tool_calls:
-                    tool_calls = chunk.tool_calls
-
-            # In ReAct mode, parse tool calls from text
-            if self.use_react:
-                parsed = parse_tool_calls(full_content)
-                tool_calls = parsed.tool_calls
+        async def run_agent() -> None:
+            try:
+                await self.run(user_message, on_event=on_event)
+            except BaseException as exc:  # pragma: no cover - propagated below
+                await queue.put(exc)
+            finally:
+                await queue.put(None)
 
-                if parsed.is_final_answer and not tool_calls:
-                    self.messages.append(Message(
-                        role=Role.ASSISTANT,
-                        content=full_content,
-                    ))
+        task = asyncio.create_task(run_agent())
+        try:
+            while True:
+                item = await queue.get()
+                if item is None:
                     break
-
-            # If there are tool calls, execute them
-            if tool_calls:
-                self.messages.append(Message(
-                    role=Role.ASSISTANT,
-                    content=full_content,
-                    tool_calls=tool_calls,
-                ))
-
-                for tool_call in tool_calls:
-                    yield AgentEvent(
-                        type="tool_call",
-                        tool_name=tool_call.name,
-                        tool_args=tool_call.arguments,
-                    )
-
-                    result = await self.registry.execute(
-                        tool_call.name,
-                        **tool_call.arguments,
-                    )
-
-                    yield AgentEvent(
-                        type="tool_result",
-                        content=result.output,
-                        tool_name=tool_call.name,
-                    )
-
-                    result_text = format_tool_result(
-                        tool_call.name,
-                        result.output,
-                        result.is_error,
-                    )
-                    self.messages.append(Message(
-                        role=Role.TOOL,
-                        content=result_text,
-                    ))
-
-                continue
-
-            # No tool calls - done
-            self.messages.append(Message(
-                role=Role.ASSISTANT,
-                content=full_content,
-            ))
-            break
+                if isinstance(item, BaseException):
+                    raise item
+                yield item
+            await task
+        finally:
+            if not task.done():
+                task.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await task
 
     def _contains_unexecuted_code(self, content: str) -> bool:
         """Detect if response contains code blocks that should be tool calls.
@@ -781,9 +722,9 @@ class Agent:
         instead of using the proper tool calling API. This method tries to
        parse and recover them.
        """
-        import re
        import json
        import os
+        import re

        tool_calls = []
        tool_names = ["write", "read", "edit", "bash", "glob", "grep"]
```
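The rewritten `run_stream` above bridges the callback-driven `run()` path onto an async iterator with an `asyncio.Queue`, so streaming consumers and the primary runtime share one code path. Below is a minimal self-contained sketch of that producer/consumer bridge; the `bridge` and `produce` names and the plain-string events are illustrative stand-ins, not the project's types:

```python
import asyncio
import contextlib
from collections.abc import AsyncIterator, Awaitable, Callable


async def bridge(
    run: Callable[[Callable[[str], Awaitable[None]]], Awaitable[None]],
) -> AsyncIterator[str]:
    """Expose a callback-based producer as an async iterator."""
    # The queue carries events, a raised exception, or None as an end sentinel.
    queue: asyncio.Queue[str | BaseException | None] = asyncio.Queue()

    async def on_event(event: str) -> None:
        await queue.put(event)

    async def runner() -> None:
        try:
            await run(on_event)
        except BaseException as exc:
            await queue.put(exc)  # surface the failure to the consumer
        finally:
            await queue.put(None)  # always unblock the consumer

    task = asyncio.create_task(runner())
    try:
        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, BaseException):
                raise item
            yield item
        await task
    finally:
        # If the consumer abandons the iterator early, cancel the producer.
        if not task.done():
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task


async def main() -> None:
    async def produce(emit: Callable[[str], Awaitable[None]]) -> None:
        for word in ("thinking", "response", "done"):
            await emit(word)

    async for event in bridge(produce):
        print(event)


asyncio.run(main())
```

Pushing `None` from the `finally` clause guarantees the consumer is unblocked even when the producer fails, and re-raising the queued exception preserves the failure for the caller instead of silently ending the stream.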
**src/loader/cli/main.py** (modified)

```diff
@@ -2,6 +2,7 @@
 
 import asyncio
 import re
+
 import click
 import httpx
 from rich.console import Console
@@ -28,11 +29,12 @@ async def select_model_interactive() -> str | None:
     Returns:
         Selected model name, or None if cancelled/no models.
     """
-    from ..llm.ollama import OllamaBackend
-    from ..config import get_last_model
     from prompt_toolkit import PromptSession
     from prompt_toolkit.completion import WordCompleter
 
+    from ..config import get_last_model
+    from ..llm.ollama import OllamaBackend
+
     # Create a temporary client to list models
     backend = OllamaBackend(model="")
     models = await backend.list_models()
@@ -194,10 +196,10 @@ async def _main(
     reason: bool,
     prompt: str | None,
 ) -> None:
-    from ..llm.ollama import OllamaBackend
     from ..agent.loop import Agent, AgentConfig, ReasoningConfig
+    from ..config import get_default_model, get_last_model, set_last_model
+    from ..llm.ollama import OllamaBackend
     from ..tools.base import create_default_registry
-    from ..config import get_default_model, set_last_model, get_last_model
 
     # Handle model selection
     if select_model:
@@ -222,9 +224,6 @@ async def _main(
         timeout=timeout,
     )
 
-    # Determine actual mode based on model capabilities (not just CLI flag)
-    mode_str = "ReAct" if react or not llm.supports_native_tools() else "Native"
-
     # Check health
     if not await llm.health_check():
         console.print("[red]Error: Cannot connect to Ollama. Is it running?[/red]")
@@ -234,6 +233,11 @@ async def _main(
         console.print("\nTry [cyan]loader --select-model[/cyan] to choose from available models.")
         return
 
+    await llm.describe_model()
+
+    # Determine actual mode based on resolved model capabilities (not just CLI flag)
+    mode_str = "ReAct" if react or not llm.supports_native_tools() else "Native"
+
     # Save this model as the new default
     set_last_model(model)
@@ -333,9 +337,10 @@ def _format_tool_args(args: dict | None) -> str:
 
 async def run_once(agent, prompt: str, skip_confirmation: bool = False) -> None:
     """Run a single prompt."""
-    from ..tools.base import ConfirmationRequired
     import time
 
+    from ..tools.base import ConfirmationRequired
+
     thinking_start = None
     streamed_response = False
@@ -404,10 +409,12 @@ async def run_once(agent, prompt: str, skip_confirmation: bool = False) -> None:
 
 async def run_interactive(agent, skip_confirmation: bool = False) -> None:
     """Run interactive chat loop."""
+    import os
+
     from prompt_toolkit import PromptSession
     from prompt_toolkit.history import FileHistory
+
     from ..tools.base import ConfirmationRequired
-    import os
 
     history_file = os.path.expanduser("~/.loader_history")
     session = PromptSession(history=FileHistory(history_file))
```
**src/loader/llm/ollama.py** (modified)

```diff
@@ -1,7 +1,8 @@
 """Ollama backend implementation."""
 
 import json
-from typing import Any, AsyncIterator
+from collections.abc import AsyncIterator
+from typing import Any
 
 import httpx
 
@@ -10,7 +11,6 @@ from .base import (
     CompletionResponse,
     LLMBackend,
     Message,
-    Role,
     StreamChunk,
     ToolCall,
 )
@@ -40,8 +40,26 @@ class OllamaBackend(LLMBackend):
         self._client = httpx.AsyncClient(timeout=timeout)
         self._supports_native_tools: bool | None = None
         self._model_details_cache: dict[str, Any] | None = None
+        self._model_details_loaded_for: str | None = None
         self._capability_profile: CapabilityProfile | None = None
 
+    def _invalidate_model_caches_if_needed(self) -> None:
+        """Clear cached capability state when the active model changes."""
+
+        if (
+            self._capability_profile is not None
+            and self._capability_profile.model_name != self.model
+        ):
+            self._capability_profile = None
+            self._supports_native_tools = None
+
+        if (
+            self._model_details_loaded_for is not None
+            and self._model_details_loaded_for != self.model
+        ):
+            self._model_details_cache = None
+            self._model_details_loaded_for = None
+
     def _build_options(self, temperature: float, max_tokens: int) -> dict:
         """Build Ollama options dict with performance settings."""
         return {
@@ -108,7 +126,9 @@ class OllamaBackend(LLMBackend):
     async def describe_model(self) -> dict[str, Any] | None:
         """Fetch and cache Ollama model details for capability resolution."""
 
-        if self._model_details_cache is not None:
+        self._invalidate_model_caches_if_needed()
+
+        if self._model_details_loaded_for == self.model:
             return self._model_details_cache
 
         if not self.model:
@@ -121,18 +141,18 @@ class OllamaBackend(LLMBackend):
             )
             response.raise_for_status()
             self._model_details_cache = response.json()
+            self._model_details_loaded_for = self.model
         except Exception:
             self._model_details_cache = None
+            self._model_details_loaded_for = self.model
 
         return self._model_details_cache
 
     def capability_profile(self) -> CapabilityProfile:
         """Return the resolved capability profile for the current model."""
 
-        if (
-            self._capability_profile is None
-            or self._capability_profile.model_name != self.model
-        ):
+        self._invalidate_model_caches_if_needed()
+        if self._capability_profile is None:
             self._capability_profile = resolve_capability_profile(
                 self.model,
                 model_details=self._model_details_cache,
@@ -148,9 +168,7 @@ class OllamaBackend(LLMBackend):
         if self.force_react:
             return False
 
-        if self._capability_profile is not None and self._capability_profile.model_name != self.model:
-            self._capability_profile = None
-            self._supports_native_tools = None
+        self._invalidate_model_caches_if_needed()
 
         if self._supports_native_tools is not None:
             return self._supports_native_tools
@@ -228,6 +246,8 @@ class OllamaBackend(LLMBackend):
         max_tokens: int = 4096,
     ) -> CompletionResponse:
         """Generate a completion using Ollama."""
+        await self.describe_model()
+
         payload: dict[str, Any] = {
             "model": self.model,
             "messages": self._format_messages(messages),
@@ -305,6 +325,8 @@ class OllamaBackend(LLMBackend):
         max_tokens: int = 4096,
     ) -> AsyncIterator[StreamChunk]:
         """Stream a completion from Ollama."""
+        await self.describe_model()
+
         payload: dict[str, Any] = {
             "model": self.model,
             "messages": self._format_messages(messages),
@@ -365,7 +387,6 @@
 
     async def _stream_response(self, response) -> AsyncIterator[StreamChunk]:
         """Internal helper to stream response chunks."""
-        import re
 
         full_content = ""
         display_content = ""  # Content to show (filtered)
```
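The `_invalidate_model_caches_if_needed` helper above keys each cache to the model it was populated for, so hot-swapping `self.model` on a live backend can never serve stale capabilities. A sketch of the guarded behavior, using a simplified stand-in for `CapabilityProfile` and mirroring the private attribute names from the diff:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CapabilityProfile:
    # Simplified stand-in for the project's profile type.
    model_name: str
    supports_native_tools: bool


class Backend:
    def __init__(self, model: str) -> None:
        self.model = model
        self._capability_profile: CapabilityProfile | None = None
        self._supports_native_tools: bool | None = None
        self._model_details_cache: dict | None = None
        self._model_details_loaded_for: str | None = None

    def _invalidate_model_caches_if_needed(self) -> None:
        # Capability state is keyed to the profile's model name ...
        if (
            self._capability_profile is not None
            and self._capability_profile.model_name != self.model
        ):
            self._capability_profile = None
            self._supports_native_tools = None
        # ... while the /api/show details cache is keyed separately, so a
        # result cached for one model never masks a fetch for the next one.
        if (
            self._model_details_loaded_for is not None
            and self._model_details_loaded_for != self.model
        ):
            self._model_details_cache = None
            self._model_details_loaded_for = None


backend = Backend("llama3.1")
backend._capability_profile = CapabilityProfile("llama3.1", True)
backend._supports_native_tools = True

backend.model = "qwen2.5-coder"  # hot-swap the active model
backend._invalidate_model_caches_if_needed()
assert backend._capability_profile is None
assert backend._supports_native_tools is None
```

Recording `_model_details_loaded_for` even when the fetch fails means a broken `/api/show` call is remembered per model rather than retried on every request, yet is discarded the moment the model changes.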
**src/loader/runtime/conversation.py** (modified)

```diff
@@ -56,6 +56,8 @@ class ConversationRuntime:
     ) -> TurnSummary:
         """Run one task turn and return a structured summary."""
 
+        await self._prepare_runtime_capabilities()
+
         iterations = 0
         final_response = ""
         actions_taken: list[str] = []
@@ -690,6 +692,23 @@ class ConversationRuntime:
         for key, value in update.items():
             target[key] = target.get(key, 0) + value
 
+    async def _prepare_runtime_capabilities(self) -> None:
+        describe_model = getattr(self.agent.backend, "describe_model", None)
+        if callable(describe_model):
+            await describe_model()
+
+        previous_profile = self.agent.capability_profile
+        self.agent.refresh_capability_profile()
+        if self.agent.capability_profile != previous_profile:
+            self.tracer.record(
+                "runtime.capabilities_refreshed",
+                model_name=self.agent.capability_profile.model_name,
+                supports_native_tools=self.agent.capability_profile.supports_native_tools,
+                preferred_tool_call_format=(
+                    self.agent.capability_profile.preferred_tool_call_format
+                ),
+            )
+
     @staticmethod
     def _emit_confirmation(emit: EventSink):
         async def _emit(tool_name: str, message: str, details: str) -> None:
```
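Taken together, `_prepare_runtime_capabilities` gives every turn a consistent entry point: optionally re-fetch model details, re-resolve the capability profile, and record a trace event only when something actually changed. A rough sketch of that per-turn flow, with hypothetical stubs standing in for the real backend, agent, and tracer:

```python
import asyncio


class StubBackend:
    """Hypothetical backend; describe_model is optional, hence the getattr probe."""

    def __init__(self, model: str) -> None:
        self.model = model

    async def describe_model(self) -> dict | None:
        # Stand-in for the Ollama /api/show fetch.
        return {"model": self.model}


class StubAgent:
    def __init__(self, backend: StubBackend) -> None:
        self.backend = backend
        self.capability_profile = None

    def refresh_capability_profile(self) -> None:
        # Stand-in for resolve_backend_capability_profile(self.backend).
        self.capability_profile = ("profile-for", self.backend.model)


async def prepare_runtime_capabilities(agent: StubAgent) -> None:
    # Probe for describe_model so backends without it still work.
    describe_model = getattr(agent.backend, "describe_model", None)
    if callable(describe_model):
        await describe_model()

    previous = agent.capability_profile
    agent.refresh_capability_profile()
    if agent.capability_profile != previous:
        # Stand-in for tracer.record("runtime.capabilities_refreshed", ...).
        print("capabilities refreshed:", agent.capability_profile)


async def main() -> None:
    agent = StubAgent(StubBackend("llama3.1"))
    await prepare_runtime_capabilities(agent)  # first turn: refresh fires
    await prepare_runtime_capabilities(agent)  # unchanged model: nothing printed
    agent.backend.model = "qwen2.5-coder"
    await prepare_runtime_capabilities(agent)  # model swap: refresh fires again


asyncio.run(main())
```

Comparing the profile before and after the refresh keeps the trace log quiet on the common unchanged-model path while still surfacing every genuine capability change at the start of a turn.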