| 1 | """Base classes for the tool system.""" |
| 2 | |
| 3 | from abc import ABC, abstractmethod |
| 4 | from dataclasses import dataclass, field |
| 5 | from pathlib import Path |
| 6 | from typing import Any |
| 7 | |
| 8 | from ..runtime.permissions import PermissionMode |
| 9 | |
| 10 | |
class ConfirmationRequired(Exception):
    """Signal that a tool call needs user approval before it may run.

    Args:
        tool_name: Name of the tool that asked for confirmation.
        message: Human-readable prompt; also becomes the exception's ``str()``.
        details: Optional extra context; empty string when not supplied.
    """

    def __init__(self, tool_name: str, message: str, details: str = ""):
        super().__init__(message)
        self.tool_name, self.message, self.details = tool_name, message, details
| 19 | |
| 20 | |
@dataclass
class ToolResult:
    """Outcome of a single tool invocation.

    Attributes:
        output: Text produced by the tool, or a description of the failure.
        is_error: True when ``output`` describes a failure instead of a result.
        metadata: Free-form extra data; every instance gets its own fresh dict.
    """

    output: str
    is_error: bool = field(default=False)
    metadata: dict[str, Any] = field(default_factory=dict)
| 28 | |
| 29 | |
class Tool(ABC):
    """Interface every tool exposed to the model must implement.

    Subclasses must provide ``name``, ``description``, ``parameters`` and the
    async ``execute`` coroutine. The remaining hooks ship with permissive
    defaults that subclasses override when they need different behavior.
    """

    # Permission level needed to invoke this tool. Subclasses that do not
    # override it inherit DANGER_FULL_ACCESS.
    required_permission: PermissionMode = PermissionMode.DANGER_FULL_ACCESS

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier the model uses when calling this tool."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable explanation of what the tool does."""

    @property
    @abstractmethod
    def parameters(self) -> dict[str, Any]:
        """JSON Schema object describing the tool's arguments."""

    @property
    def is_destructive(self) -> bool:
        """Report whether the tool can modify files or run commands.

        The base implementation answers False; tools that change state
        should override this.
        """
        return False

    def set_workspace_root(self, workspace_root: Path | None) -> None:
        """Receive the workspace root; the base implementation ignores it."""

    def get_required_permission(self, **kwargs: Any) -> PermissionMode:
        """Return the permission this particular invocation needs.

        The base implementation ignores ``kwargs`` and returns the class-level
        ``required_permission``.
        """
        return self.required_permission

    def check_confirmation(self, skip_confirmation: bool = False, **kwargs: Any) -> None:
        """Hook for asking the user before running; a no-op by default.

        Args:
            skip_confirmation: When True, an override should not ask.
            **kwargs: The tool arguments, available for building a prompt.

        Raises:
            ConfirmationRequired: From overrides that need user approval.
        """

    @abstractmethod
    async def execute(self, **kwargs: Any) -> ToolResult:
        """Run the tool with the given arguments and report the outcome."""

    def to_schema(self) -> dict[str, Any]:
        """Describe this tool as a name/description/parameters mapping for the LLM."""
        return dict(
            name=self.name,
            description=self.description,
            parameters=self.parameters,
        )
| 93 | |
| 94 | |
class ToolRegistry:
    """Name-indexed collection of tools plus the policy used to run them.

    Args:
        skip_confirmation: Forwarded to each tool's ``check_confirmation`` on
            every ``execute`` call.
        workspace_root: Optional root directory. It is normalized (``~``
            expanded, resolved to an absolute path) and pushed to tools as
            they are registered.
    """

    def __init__(
        self,
        skip_confirmation: bool = False,
        workspace_root: Path | str | None = None,
    ) -> None:
        self._tools: dict[str, Tool] = {}
        self.skip_confirmation = skip_confirmation
        self.workspace_root = self._normalize_workspace_root(workspace_root)

    @staticmethod
    def _normalize_workspace_root(workspace_root: Path | str | None) -> Path | None:
        """Return an expanded, resolved Path, or None for a falsy input."""
        # Falsy inputs (None, "") mean "no workspace". Without this check,
        # Path("") would silently resolve to the current working directory.
        if not workspace_root:
            return None
        return Path(workspace_root).expanduser().resolve()

    def register(self, tool: Tool) -> None:
        """Add ``tool`` under ``tool.name``, replacing any tool with that name.

        If the registry has a workspace root, it is pushed to the tool first.
        """
        if self.workspace_root is not None:
            tool.set_workspace_root(self.workspace_root)
        self._tools[tool.name] = tool

    def get(self, name: str) -> Tool | None:
        """Return the tool registered as ``name``, or None if there is none."""
        return self._tools.get(name)

    def list_tools(self) -> list[Tool]:
        """Return the registered tools in registration order."""
        return list(self._tools.values())

    def get_schemas(self) -> list[dict[str, Any]]:
        """Return ``to_schema()`` for every tool, in registration order."""
        return [tool.to_schema() for tool in self._tools.values()]

    def configure_workspace_root(self, workspace_root: Path | str | None) -> None:
        """Replace the workspace root and push it (even None) to every tool."""
        self.workspace_root = self._normalize_workspace_root(workspace_root)
        for tool in self._tools.values():
            tool.set_workspace_root(self.workspace_root)

    def get_tool_requirements(self) -> dict[str, PermissionMode]:
        """Map each tool name to its class-level ``required_permission``.

        This reports the default only. It does not consult the per-invocation
        ``get_required_permission(**kwargs)``.
        """
        return {
            tool.name: tool.required_permission
            for tool in self._tools.values()
        }

    async def execute(self, name: str, **kwargs: Any) -> ToolResult:
        """Execute a tool by name.

        Args:
            name: Tool name to execute.
            **kwargs: Arguments to pass to the tool. NOTE(review): an argument
                literally called ``name`` collides with this method's own
                parameter, and Python raises TypeError at the call site. An
                argument called ``skip_confirmation`` collides inside
                ``check_confirmation`` and comes back as an error result.

        Returns:
            The tool's ToolResult. Unknown tools and exceptions raised by the
            tool (other than ConfirmationRequired) become ``is_error=True``
            results.

        Raises:
            ConfirmationRequired: If the tool needs confirmation and
                ``skip_confirmation`` is False.
        """
        tool = self.get(name)
        if tool is None:
            return ToolResult(output=f"Unknown tool: {name}", is_error=True)
        try:
            # May raise ConfirmationRequired, which the caller must handle.
            tool.check_confirmation(
                skip_confirmation=self.skip_confirmation,
                **kwargs,
            )
            return await tool.execute(**kwargs)
        except ConfirmationRequired:
            raise  # Confirmation is a control signal, not a tool failure.
        except Exception as e:
            # Some exceptions (e.g. TimeoutError()) stringify to "". Without a
            # fallback, the message would give no hint of what went wrong.
            detail = str(e) or type(e).__name__
            return ToolResult(
                output=f"Tool execution error: {detail}",
                is_error=True,
            )
| 175 | |
| 176 | |
def create_default_registry(
    workspace_root: Path | str | None = None,
    *,
    skip_confirmation: bool = False,
) -> ToolRegistry:
    """Create a registry populated with the built-in tools.

    Args:
        workspace_root: Optional workspace root, forwarded to ToolRegistry.
        skip_confirmation: Forwarded to ToolRegistry. It defaults to False,
            which matches the registry built before this option existed.

    Returns:
        A ToolRegistry containing the file, search, shell and workflow tools.
    """
    # Imported inside the function. Presumably this avoids import cycles
    # between this module and the tool modules; confirm before hoisting.
    from .file_tools import EditTool, GlobTool, ReadTool, WriteTool
    from .search_tools import GrepTool
    from .shell_tools import BashTool
    from .workflow_tools import AskUserQuestionTool, TodoWriteTool

    registry = ToolRegistry(
        skip_confirmation=skip_confirmation,
        workspace_root=workspace_root,
    )
    # Registration order is preserved because get_schemas()/list_tools()
    # report tools in that order.
    default_tool_classes = (
        ReadTool,
        WriteTool,
        EditTool,
        GlobTool,
        BashTool,
        GrepTool,
        TodoWriteTool,
        AskUserQuestionTool,
    )
    for tool_cls in default_tool_classes:
        registry.register(tool_cls())

    return registry