Python · 17823 bytes Raw Blame History
1 """Runtime-owned safeguard services shared by hooks and agent adapters."""
2
3 from __future__ import annotations
4
5 import re
6 from dataclasses import dataclass
7 from pathlib import Path
8
9
class ActionTracker:
    """Tracks completed actions to prevent duplicates and detect loops.

    The tracker keeps, per instance:

    * content signatures of files written and edits/patches applied, keyed by
      normalized path;
    * the whitespace-normalized shell commands that were run and the
      directories created;
    * a bounded window of recent tool names (for tool-call loop detection);
    * a bounded window of normalized text responses (for text loop detection).

    Signatures are built with the built-in ``hash``; they are only compared
    against other signatures produced by the same tracker instance.
    """

    MAX_SEQUENCE_LENGTH = 20
    LOOP_PATTERN_MIN = 2
    LOOP_REPEAT_THRESHOLD = 2
    MAX_RESPONSE_HISTORY = 5

    # Second component of the edit signature under which structured patches
    # are stored. Recording and lookup must both go through
    # ``_make_patch_signature`` so they agree on the format.
    _PATCH_MARKER = "structured_patch"

    def __init__(self) -> None:
        """Create an empty tracker."""
        self._file_writes: dict[str, list[str]] = {}
        self._files_edited: dict[str, list[str]] = {}
        self._commands_run: set[str] = set()
        self._dirs_created: set[str] = set()
        self._action_sequence: list[str] = []
        self._response_history: list[str] = []

    def reset(self) -> None:
        """Forget every recorded action and response."""
        self._file_writes.clear()
        self._files_edited.clear()
        self._commands_run.clear()
        self._dirs_created.clear()
        self._action_sequence.clear()
        self._response_history.clear()

    def _normalize_path(self, path: str) -> str:
        """Return *path* with ``~`` expanded and resolved, as a string.

        Normalization is best effort: if the home directory cannot be
        determined or the path cannot be resolved, the most-normalized form
        obtained so far is returned instead of raising.
        """
        candidate = Path(path)
        try:
            # expanduser() raises RuntimeError for e.g. "~unknown_user/...",
            # so it has to sit inside the guarded section as well.
            candidate = candidate.expanduser()
            return str(candidate.resolve())
        except (OSError, RuntimeError, ValueError):
            return str(candidate)

    @staticmethod
    def _make_edit_signature(old_string: str, new_string: str) -> str:
        """Return the signature identifying an ``old -> new`` replacement."""
        return f"{hash(old_string)}:{hash(new_string)}"

    @staticmethod
    def _make_write_signature(content: str) -> str:
        """Return the signature identifying a file's full content."""
        return str(hash(content))

    @classmethod
    def _make_patch_signature(cls, hunks: object) -> str:
        """Return the signature identifying a structured patch.

        The format equals ``_make_edit_signature(str(hunks), "structured_patch")``
        so that patches recorded through ``record_edit`` with that convention
        are still recognised.
        """
        return cls._make_edit_signature(str(hunks), cls._PATCH_MARKER)

    def would_duplicate_file_create(self, file_path: str, content: str) -> bool:
        """Return True if the same content was already written to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_write_signature(content)
        return sig in self._file_writes.get(norm_path, [])

    def would_duplicate_edit(self, file_path: str, old_string: str, new_string: str) -> bool:
        """Return True if the same replacement was already applied to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_edit_signature(old_string, new_string)
        return sig in self._files_edited.get(norm_path, [])

    def would_duplicate_patch(self, file_path: str, hunks: list[dict]) -> bool:
        """Return True if the same hunks were already applied to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_patch_signature(hunks)
        return sig in self._files_edited.get(norm_path, [])

    def would_duplicate_command(self, command: str) -> bool:
        """Return True if *command* (whitespace-normalized) was already run."""
        norm_cmd = " ".join(command.split())
        return norm_cmd in self._commands_run

    def would_duplicate_mkdir(self, dir_path: str) -> bool:
        """Return True if *dir_path* was already recorded as created."""
        norm_path = self._normalize_path(dir_path)
        return norm_path in self._dirs_created

    def record_file_create(self, file_path: str, content: str) -> None:
        """Remember that *content* was written to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_write_signature(content)
        self._file_writes.setdefault(norm_path, []).append(sig)

    def record_edit(self, file_path: str, old_string: str, new_string: str) -> None:
        """Remember that ``old_string -> new_string`` was applied to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_edit_signature(old_string, new_string)
        self._files_edited.setdefault(norm_path, []).append(sig)

    def record_patch(self, file_path: str, hunks: object) -> None:
        """Remember that the structured patch *hunks* was applied to *file_path*."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_patch_signature(hunks)
        self._files_edited.setdefault(norm_path, []).append(sig)

    def record_command(self, command: str) -> None:
        """Remember a shell command; also note the directory of a leading ``mkdir``."""
        norm_cmd = " ".join(command.split())
        self._commands_run.add(norm_cmd)

        # Best effort: everything after "mkdir [-p]" is taken as one path.
        mkdir_match = re.match(r'mkdir\s+(-p\s+)?(.+)', norm_cmd)
        if mkdir_match:
            dir_path = mkdir_match.group(2).strip().strip('"\'')
            self._dirs_created.add(self._normalize_path(dir_path))

    def record_mkdir(self, dir_path: str) -> None:
        """Remember that *dir_path* was created."""
        self._dirs_created.add(self._normalize_path(dir_path))

    def check_tool_call(self, tool_name: str, arguments: dict) -> tuple[bool, str]:
        """Check whether a tool call would repeat an already recorded action.

        Returns:
            ``(True, message)`` when the call duplicates a recorded write,
            edit or patch; ``(False, "")`` otherwise.
        """
        if tool_name == "write":
            file_path = arguments.get("file_path", "")
            content = arguments.get("content", "")
            if self.would_duplicate_file_create(file_path, content):
                return True, f"Same file content already written: {file_path}"

        elif tool_name == "edit":
            file_path = arguments.get("file_path", "")
            old_string = arguments.get("old_string", "")
            new_string = arguments.get("new_string", "")
            if self.would_duplicate_edit(file_path, old_string, new_string):
                return True, f"Same edit already applied to: {file_path}"

        elif tool_name == "patch":
            file_path = arguments.get("file_path", "")
            hunks = arguments.get("hunks", [])
            if isinstance(hunks, list) and self.would_duplicate_patch(file_path, hunks):
                return True, f"Same patch already applied to: {file_path}"

        # Bash commands intentionally skip exact-command dedupe here.
        # Re-running the same shell probe after a filesystem change is often valid,
        # and higher-level loop detection is a safer backstop than blocking `ls`.
        return False, ""

    def record_tool_call(self, tool_name: str, arguments: dict) -> None:
        """Record an executed tool call for duplicate and loop detection.

        The tool name always enters the bounded action window. Writes, edits
        and patches are recorded only when a file path is present; bash calls
        only when the command is non-empty.
        """
        self._action_sequence.append(tool_name)
        if len(self._action_sequence) > self.MAX_SEQUENCE_LENGTH:
            self._action_sequence.pop(0)

        if tool_name == "write":
            file_path = arguments.get("file_path", "")
            content = arguments.get("content", "")
            if file_path:
                self.record_file_create(file_path, content)

        elif tool_name == "edit":
            file_path = arguments.get("file_path", "")
            old_string = arguments.get("old_string", "")
            new_string = arguments.get("new_string", "")
            if file_path:
                self.record_edit(file_path, old_string, new_string)

        elif tool_name == "patch":
            file_path = arguments.get("file_path", "")
            hunks = arguments.get("hunks", [])
            if file_path:
                # Must use the same signature as would_duplicate_patch().
                self.record_patch(file_path, hunks)

        elif tool_name == "bash":
            command = arguments.get("command", "")
            if command:
                self.record_command(command)

    def detect_loop(self) -> tuple[bool, str]:
        """Detect a repeating tail in the recent tool-name sequence.

        For pattern lengths from ``LOOP_PATTERN_MIN`` up to 5, the last
        ``pattern_len`` tool names are compared with the chunks immediately
        preceding them. ``LOOP_REPEAT_THRESHOLD`` consecutive occurrences
        (including the tail itself) count as a loop.

        Returns:
            ``(True, message)`` when a loop is found, else ``(False, "")``.
        """
        seq = self._action_sequence
        if len(seq) < self.LOOP_PATTERN_MIN * self.LOOP_REPEAT_THRESHOLD:
            return False, ""

        for pattern_len in range(self.LOOP_PATTERN_MIN, min(6, len(seq) // 2 + 1)):
            pattern = seq[-pattern_len:]
            repeats = 1
            # Walk backwards chunk by chunk; stop at the first mismatch.
            for i in range(len(seq) - pattern_len * 2, -1, -pattern_len):
                if seq[i:i + pattern_len] != pattern:
                    break
                repeats += 1

            if repeats >= self.LOOP_REPEAT_THRESHOLD:
                pattern_str = " → ".join(pattern)
                return True, f"Repeating pattern detected ({repeats}x): {pattern_str}"

        return False, ""

    @staticmethod
    def _normalize_response(response: str) -> str:
        """Reduce a response to a comparable form.

        Takes the first 200 characters of the stripped, lower-cased text and
        replaces path-like tokens and digit runs with placeholders.
        """
        normalized = response.strip().lower()[:200]
        normalized = re.sub(r'/[\w/.-]+', '<PATH>', normalized)
        normalized = re.sub(r'\d+', '<NUM>', normalized)
        return normalized

    def record_response(self, response: str) -> None:
        """Append the normalized response to the bounded response history."""
        normalized = self._normalize_response(response)
        self._response_history.append(normalized)
        if len(self._response_history) > self.MAX_RESPONSE_HISTORY:
            self._response_history.pop(0)

    def detect_text_loop(self, response: str) -> tuple[bool, str]:
        """Check whether *response* repeats what is already in the history.

        Three heuristics run in order: exact match of the normalized text,
        recurrence of a known filler phrase, and word-set overlap above 85%
        with at least two of the last three recorded responses.

        Returns:
            ``(True, message)`` when repetition is detected, else ``(False, "")``.
            Always ``(False, "")`` while fewer than two responses are recorded.
        """
        if len(self._response_history) < 2:
            return False, ""

        normalized = self._normalize_response(response)
        exact_matches = sum(1 for r in self._response_history if r == normalized)
        if exact_matches >= 2:
            return True, f"Agent repeated the same response {exact_matches + 1} times"

        repetitive_phrases = [
            "apologies for any confusion",
            "let me proceed",
            "i will now use the",
        ]
        response_lower = response.lower()
        for phrase in repetitive_phrases:
            if phrase in response_lower:
                phrase_count = sum(1 for r in self._response_history if phrase in r)
                if phrase_count >= 2:
                    return True, f"Agent is stuck repeating '{phrase}'"

        current_words = set(normalized.split())
        similarity_matches = 0
        for prev in self._response_history[-3:]:
            prev_words = set(prev.split())
            # Very short responses are skipped; their overlap ratio is too noisy.
            if len(current_words) > 10 and len(prev_words) > 10:
                overlap = len(current_words & prev_words)
                similarity = overlap / max(len(current_words), len(prev_words))
                if similarity > 0.85:
                    similarity_matches += 1

        if similarity_matches >= 2:
            return True, "Agent responses are highly repetitive"

        return False, ""

    def reset_response_history(self) -> None:
        """Clear response history between turns to prevent cross-turn false positives."""
        self._response_history.clear()
220
221
@dataclass
class ValidationResult:
    """Result of pre-action validation."""

    # False when the validator rejects the tool call.
    valid: bool
    # Human-readable explanation; empty when there is nothing to report.
    reason: str = ""
    # Optional hint on how to fix the call.
    suggestion: str = ""
    # One of the labels used in this module: "warning", "error" or "block".
    severity: str = "warning"


class PreActionValidator:
    """Validates tool calls before execution to catch problematic actions.

    ``validate`` dispatches on the tool name and returns a
    ``ValidationResult``. Results with ``valid=False`` reject the call;
    results with ``valid=True`` and a non-empty ``reason`` are warnings.
    """

    # Shell command patterns that are rejected outright (severity "block").
    # Command names are anchored with ``\b`` so they do not match inside
    # other words (e.g. the "rm" in "alarm").
    DANGEROUS_PATTERNS = [
        (r'\brm\s+(-[rf]+\s+)?/', "Dangerous: removing from root directory"),
        (r'\brm\s+-(?:rf|fr)\s+~', "Dangerous: removing home directory"),
        (r'>\s*/dev/sd[a-z]', "Dangerous: writing directly to disk device"),
        (r'\bmkfs\.', "Dangerous: formatting filesystem"),
        (r'\bdd\s+.*of=/dev/', "Dangerous: dd to device"),
        (r'\bchmod\s+-R\s+777\s+/', "Dangerous: making everything world-writable"),
        (r':\(\)\s*\{\s*:\|:\s*&\s*\}\s*;', "Dangerous: fork bomb"),
    ]

    # Shell command patterns that are allowed but reported as warnings.
    SUSPICIOUS_PATTERNS = [
        (r'\brm\s+-(?:rf|fr)\s+', "Warning: recursive force delete"),
        (r'>\s*/etc/', "Warning: overwriting system config"),
        (
            r'\bcurl\s+.*\|\s*(?:sudo\s+)?(?:ba|da|k|z)?sh\b',
            "Warning: piping curl to shell",
        ),
        (
            r'\bwget\s+.*\|\s*(?:sudo\s+)?(?:ba|da|k|z)?sh\b',
            "Warning: piping wget to shell",
        ),
        (r'\beval\s+', "Warning: using eval"),
        (r'\bsudo\s+', "Warning: using sudo"),
    ]

    # Programs that need an interactive terminal, mapped to the reason reported.
    _INTERACTIVE_REASONS = {
        "nano": "nano requires interactive terminal",
        "vi": "vim requires interactive terminal",
        "vim": "vim requires interactive terminal",
        "emacs": "emacs requires interactive terminal",
        "less": "less requires interactive terminal",
        "more": "more requires interactive terminal",
        "top": "top requires interactive terminal",
        "htop": "htop requires interactive terminal",
    }

    # Matches an interactive program only where a command name can appear:
    # at the start of the command or after ; & | ( ` or a newline, optionally
    # behind wrapper commands / VAR=value assignments and a directory prefix.
    # Plain arguments such as "grep more file" or "cat top.txt" do not match.
    _INTERACTIVE_COMMAND = re.compile(
        r'(?:^|[;&|(`\n])\s*'
        r'(?:(?:sudo|env|exec|command|time|nohup)\s+|\w+=\S*\s+)*'
        r'(?:[\w./~-]*/)?'
        r'(nano|vim?|emacs|less|more|top|htop)'
        r'(?=$|[\s;&|)`])'
    )

    # Directory prefixes that the write tool must never target.
    _SENSITIVE_ROOTS = ('/etc/', '/usr/', '/bin/', '/sbin/', '/boot/', '/sys/', '/proc/')

    def validate(self, tool_name: str, arguments: dict) -> ValidationResult:
        """Validate one tool call; unknown tools are always valid."""
        if tool_name == "bash":
            return self._validate_bash(arguments)
        if tool_name == "write":
            return self._validate_write(arguments)
        if tool_name == "edit":
            return self._validate_edit(arguments)
        if tool_name == "patch":
            return self._validate_patch(arguments)
        if tool_name == "read":
            return self._validate_read(arguments)
        if tool_name in ("glob", "grep"):
            return self._validate_search(tool_name, arguments)
        return ValidationResult(valid=True)

    @staticmethod
    def _path_argument(arguments: dict) -> str:
        """Return the ``file_path`` argument as a string ("" when missing/falsy)."""
        raw = arguments.get("file_path", "")
        return str(raw) if raw else ""

    @staticmethod
    def _missing_path_result() -> ValidationResult:
        """Build the shared result for an absent or blank file path."""
        return ValidationResult(
            valid=False,
            reason="Empty file path",
            suggestion="Provide a valid file path",
            severity="error",
        )

    @classmethod
    def _sensitive_root(cls, file_path: str) -> str | None:
        """Return the system directory prefix *file_path* falls under, if any.

        Both the raw path and a lexically normalized form are checked, so
        spellings such as ``/tmp/../etc/hosts`` or ``//etc/hosts`` are caught.
        Symlinks are not followed.
        """
        import posixpath  # local import: only this helper needs it

        normalized = posixpath.normpath(file_path.strip())
        if normalized.startswith("//"):
            # normpath() preserves exactly two leading slashes; collapse them.
            normalized = "/" + normalized.lstrip("/")

        for root in cls._SENSITIVE_ROOTS:
            if file_path.startswith(root) or normalized.startswith(root):
                return root
        return None

    def _validate_bash(self, arguments: dict) -> ValidationResult:
        """Validate a shell command.

        Checks run from most to least severe so that a warning can never
        mask a rejection: empty command, dangerous patterns (block),
        interactive programs (error), then suspicious patterns (warning).
        """
        command = arguments.get("command", "")

        if not command or not command.strip():
            return ValidationResult(
                valid=False,
                reason="Empty command",
                suggestion="Provide a valid command to execute",
                severity="error",
            )

        for pattern, reason in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command):
                return ValidationResult(
                    valid=False,
                    reason=reason,
                    suggestion="This command is too dangerous to execute",
                    severity="block",
                )

        interactive = self._INTERACTIVE_COMMAND.search(command)
        if interactive:
            return ValidationResult(
                valid=False,
                reason=self._INTERACTIVE_REASONS[interactive.group(1)],
                suggestion=(
                    "Use non-interactive alternatives (cat, head, tail for viewing; "
                    "sed for editing)"
                ),
                severity="error",
            )

        for pattern, reason in self.SUSPICIOUS_PATTERNS:
            if re.search(pattern, command):
                return ValidationResult(valid=True, reason=reason, severity="warning")

        return ValidationResult(valid=True)

    def _validate_write(self, arguments: dict) -> ValidationResult:
        """Validate a file write.

        The system-directory check runs before the empty-content warning so
        that truncating a system file is rejected rather than merely warned
        about.
        """
        file_path = self._path_argument(arguments)
        content = arguments.get("content", "")

        if not file_path.strip():
            return self._missing_path_result()

        path_result = self._validate_path(file_path)
        if not path_result.valid:
            return path_result

        sensitive = self._sensitive_root(file_path)
        if sensitive is not None:
            return ValidationResult(
                valid=False,
                reason=f"Cannot write to system directory: {sensitive}",
                suggestion="Write to a user directory instead",
                severity="block",
            )

        if content is None or (isinstance(content, str) and not content.strip()):
            return ValidationResult(
                valid=True,
                reason="Writing empty content to file",
                severity="warning",
            )

        return ValidationResult(valid=True)

    def _validate_edit(self, arguments: dict) -> ValidationResult:
        """Validate a string-replacement edit (path, non-None and differing strings)."""
        file_path = self._path_argument(arguments)
        old_string = arguments.get("old_string", "")
        new_string = arguments.get("new_string", "")

        if not file_path.strip():
            return self._missing_path_result()

        path_result = self._validate_path(file_path)
        if not path_result.valid:
            return path_result

        if old_string is None:
            return ValidationResult(
                valid=False,
                reason="old_string is None",
                suggestion="Provide the text to replace (can be empty string for prepend)",
                severity="error",
            )

        if new_string is None:
            return ValidationResult(
                valid=False,
                reason="new_string is None",
                suggestion="Provide the replacement text (can be empty string for deletion)",
                severity="error",
            )

        if old_string == new_string:
            return ValidationResult(
                valid=False,
                reason="old_string and new_string are identical - no change would occur",
                suggestion="Provide different old and new strings",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_patch(self, arguments: dict) -> ValidationResult:
        """Validate a structured patch (path plus a non-empty list of hunks)."""
        file_path = self._path_argument(arguments)
        hunks = arguments.get("hunks", [])

        if not file_path.strip():
            return self._missing_path_result()

        path_result = self._validate_path(file_path)
        if not path_result.valid:
            return path_result

        if not isinstance(hunks, list) or not hunks:
            return ValidationResult(
                valid=False,
                reason="Patch hunks are missing",
                suggestion="Provide one or more structured patch hunks",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_read(self, arguments: dict) -> ValidationResult:
        """Validate a file read (only the path is checked)."""
        file_path = self._path_argument(arguments)

        if not file_path.strip():
            return self._missing_path_result()

        return self._validate_path(file_path)

    def _validate_search(self, tool_name: str, arguments: dict) -> ValidationResult:
        """Validate a glob/grep call: the pattern must be non-blank."""
        pattern = arguments.get("pattern", "")

        if not pattern or not pattern.strip():
            return ValidationResult(
                valid=False,
                reason=f"Empty {tool_name} pattern",
                suggestion="Provide a valid search pattern",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_path(self, file_path: str) -> ValidationResult:
        """Reject paths containing a null byte or excessive ``..`` traversal."""
        if '\x00' in file_path:
            return ValidationResult(
                valid=False,
                reason="Path contains null byte",
                suggestion="Remove null bytes from path",
                severity="block",
            )

        if '/../../../' in file_path or file_path.count('..') > 5:
            # NOTE(review): this is the only rejection labelled "warning";
            # kept unchanged because callers may branch on the severity.
            return ValidationResult(
                valid=False,
                reason="Excessive path traversal",
                suggestion="Use a direct path instead",
                severity="warning",
            )

        return ValidationResult(valid=True)
463 return ValidationResult(valid=True)