Python · 5790 bytes Raw Blame History
1 """Code search tools."""
2
3 import asyncio
4 import re
5 from pathlib import Path
6 from typing import Any
7
8 from ..runtime.permissions import PermissionMode
9 from .base import Tool, ToolResult
10 from .fs_safety import detect_binary_file, resolve_workspace_path
11
12
class GrepTool(Tool):
    """Search for regex patterns in files.

    Read-only tool: compiles ``pattern`` with :mod:`re`, scans a single file
    or every eligible file under a directory, and reports matching lines with
    their file path and 1-based line number, optionally with context lines.
    """

    required_permission = PermissionMode.READ_ONLY

    # Directory names skipped during a directory search. They are matched only
    # against path components *below* the search root, so explicitly searching
    # inside one of them (or under a hidden ancestor) still works.
    _EXCLUDED_DIRS = frozenset({"node_modules", "__pycache__", ".git", "venv", ".venv"})

    def __init__(self, workspace_root: Path | str | None = None) -> None:
        self.workspace_root = self._normalize_root(workspace_root)

    @staticmethod
    def _normalize_root(workspace_root: Path | str | None) -> Path | None:
        """Return *workspace_root* with ``~`` expanded and resolved, or ``None`` if falsy."""
        if not workspace_root:
            return None
        return Path(workspace_root).expanduser().resolve()

    @property
    def name(self) -> str:
        return "grep"

    def set_workspace_root(self, workspace_root: Path | str | None) -> None:
        """Replace the workspace root, normalizing it exactly like ``__init__``.

        Note that ``execute`` does not consult the root: searches are
        intentionally not confined to the workspace.
        """
        self.workspace_root = self._normalize_root(workspace_root)

    @property
    def description(self) -> str:
        return "Search for a regex pattern in files. Returns matching lines with file paths and line numbers."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the arguments accepted by ``execute``."""
        return {
            "type": "object",
            "properties": {
                "pattern": {
                    "type": "string",
                    "description": "Regex pattern to search for",
                },
                "path": {
                    "type": "string",
                    "description": "File or directory to search in (default: current directory)",
                    "default": ".",
                },
                "include": {
                    "type": "string",
                    "description": "Glob pattern for files to include (e.g., '*.py')",
                },
                "context": {
                    "type": "integer",
                    "description": "Number of context lines before and after match",
                    "default": 0,
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results to return",
                    "default": 50,
                },
            },
            "required": ["pattern"],
        }

    async def execute(
        self,
        pattern: str,
        path: str = ".",
        include: str | None = None,
        context: int = 0,
        max_results: int = 50,
        **kwargs: Any,
    ) -> ToolResult:
        """Search *path* for lines matching *pattern*.

        Args:
            pattern: Regular expression (Python :mod:`re` syntax), applied to
                each line independently.
            path: File or directory to search; defaults to ``"."``.
            include: Optional glob selecting files when *path* is a directory.
                A pattern without ``/`` (e.g. ``"*.py"``) matches at any depth.
            context: Lines to show before and after each match; negative
                values behave like ``0``.
            max_results: Stop after this many matching lines.
            **kwargs: Unused (presumably accepted so extra tool arguments do
                not raise).

        Returns:
            A ``ToolResult`` with the formatted matches and metadata, a
            "no matches" message, or an error result for a bad path, regex,
            or include pattern.
        """
        try:
            # Grep is read-only, so the workspace boundary is deliberately
            # not enforced here.
            base_path = resolve_workspace_path(path, workspace_root=None)
        except FileNotFoundError:
            return ToolResult(f"Path not found: {path}", is_error=True)
        except Exception as exc:
            return ToolResult(f"Error resolving search path: {exc}", is_error=True)

        if not base_path.exists():
            return ToolResult(f"Path not found: {path}", is_error=True)

        try:
            regex = re.compile(pattern)
        except re.error as e:
            return ToolResult(f"Invalid regex pattern: {e}", is_error=True)

        context = max(0, context)

        try:
            # Traversal, stat calls and binary sniffing are blocking
            # filesystem work; keep them off the event loop.
            files = await asyncio.to_thread(self._collect_files, base_path, include)
        except (ValueError, NotImplementedError) as exc:
            # pathlib raises these for glob patterns it cannot handle
            # (e.g. absolute patterns).
            return ToolResult(f"Invalid include pattern: {exc}", is_error=True)

        results: list[str] = []
        total_matches = 0

        for file_path in files:
            if total_matches >= max_results:
                break

            try:
                # Explicit UTF-8 so decoding does not depend on the host
                # locale; undecodable bytes are dropped rather than raising.
                content = await asyncio.to_thread(
                    file_path.read_text, encoding="utf-8", errors="ignore"
                )
            except OSError:
                # Skip files that vanished or cannot be read.
                continue

            lines = content.splitlines()
            for lineno, line in enumerate(lines, 1):
                if total_matches >= max_results:
                    break
                if regex.search(line):
                    total_matches += 1
                    results.append(
                        f"{file_path}:\n" + self._format_match(lines, lineno, context)
                    )

        if not results:
            return ToolResult(f"No matches found for pattern: {pattern}")

        output = "\n\n".join(results)
        if total_matches >= max_results:
            output += f"\n\n... (showing first {max_results} matches)"

        return ToolResult(
            output,
            metadata={
                "path": str(base_path),
                "num_files": len(files),
                "num_matches": total_matches,
                "max_results": max_results,
            },
        )

    @staticmethod
    def _recursive_glob(include: str | None) -> str:
        """Turn *include* into a glob pattern relative to the search root.

        ``None``/empty selects every file. A bare pattern such as ``"*.py"``
        is prefixed with ``"**/"`` so it matches at any depth (similar to
        ``grep --include``); patterns that contain ``/`` or start with
        ``**`` are used verbatim.
        """
        if not include:
            return "**/*"
        if "/" in include or include.startswith("**"):
            return include
        return f"**/{include}"

    @classmethod
    def _collect_files(cls, base_path: Path, include: str | None) -> list[Path]:
        """Return the files to search under *base_path*, sorted by path.

        If *base_path* is a file it is returned on its own. For a directory,
        glob matches are dropped when any component below *base_path* is
        hidden (starts with ``.``) or listed in ``_EXCLUDED_DIRS``, when the
        match is not a regular file, or when ``detect_binary_file`` flags it
        or raises ``OSError``.
        """
        if base_path.is_file():
            return [base_path]

        files: list[Path] = []
        for candidate in base_path.glob(cls._recursive_glob(include)):
            # Inspect only components below the root so a root that itself
            # lives under a hidden/excluded directory is not skipped wholesale.
            try:
                rel_parts = candidate.relative_to(base_path).parts
            except ValueError:
                rel_parts = candidate.parts
            if any(
                part.startswith(".") or part in cls._EXCLUDED_DIRS for part in rel_parts
            ):
                continue
            try:
                if not candidate.is_file() or detect_binary_file(candidate):
                    continue
            except OSError:
                continue
            files.append(candidate)

        # Glob order is filesystem-dependent; sort for reproducible output.
        files.sort()
        return files

    @staticmethod
    def _format_match(lines: list[str], lineno: int, context: int) -> str:
        """Render 1-based line *lineno* of *lines* with up to *context* lines around it.

        The matching line is prefixed ``"> "`` and context lines ``" "``,
        each followed by the line number and ``": "``.
        """
        start = max(1, lineno - context)
        end = min(len(lines), lineno + context)
        rendered = []
        for n in range(start, end + 1):
            prefix = ">" if n == lineno else ""
            rendered.append(f"{prefix} {n}: {lines[n - 1]}")
        return "\n".join(rendered)
171 )