| 1 | """Runtime-owned safeguard services shared by hooks and agent adapters.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import re |
| 6 | from dataclasses import dataclass |
| 7 | from pathlib import Path |
| 8 | |
| 9 | |
class ActionTracker:
    """Track completed tool actions to flag duplicates and repeating loops.

    All state is held in memory on the instance. Content signatures are
    built with the builtin ``hash``; string hashes are salted per
    interpreter process, so signatures are only meaningful within the
    process that produced them.
    """

    MAX_SEQUENCE_LENGTH = 20  # tool names retained for loop detection
    LOOP_PATTERN_MIN = 2  # shortest tool-name pattern treated as a loop
    LOOP_REPEAT_THRESHOLD = 2  # back-to-back repeats needed to report a loop
    MAX_RESPONSE_HISTORY = 5  # normalized responses retained for text loops

    # Second half of the edit signature used when a patch is stored as an edit.
    _PATCH_MARKER = "structured_patch"

    def __init__(self) -> None:
        """Create a tracker with empty history."""
        # Normalized path -> signatures of content written to that path.
        self._file_writes: dict[str, list[str]] = {}
        # Normalized path -> signatures of edits/patches applied to that path.
        self._files_edited: dict[str, list[str]] = {}
        self._commands_run: set[str] = set()
        self._dirs_created: set[str] = set()
        self._action_sequence: list[str] = []
        self._response_history: list[str] = []

    def reset(self) -> None:
        """Forget every recorded action and response."""
        self._file_writes.clear()
        self._files_edited.clear()
        self._commands_run.clear()
        self._dirs_created.clear()
        self._action_sequence.clear()
        self._response_history.clear()

    def _normalize_path(self, path: str) -> str:
        """Return ``path`` with ``~`` expanded and resolved when possible.

        Normalization is best-effort: if expansion or resolution fails, the
        most-processed form obtained so far is returned instead of raising.
        """
        candidate = Path(path)
        try:
            candidate = candidate.expanduser()
            return str(candidate.resolve())
        except (OSError, RuntimeError, ValueError):
            return str(candidate)

    @staticmethod
    def _normalize_command(command: str) -> str:
        """Collapse every run of whitespace in ``command`` to one space."""
        return " ".join(command.split())

    @staticmethod
    def _make_edit_signature(old_string: str, new_string: str) -> str:
        """Return the signature identifying an ``old -> new`` replacement."""
        return f"{hash(old_string)}:{hash(new_string)}"

    @staticmethod
    def _make_write_signature(content: str) -> str:
        """Return the signature identifying written file content."""
        return str(hash(content))

    @classmethod
    def _make_patch_signature(cls, hunks: list[dict]) -> str:
        """Return the signature identifying a structured patch.

        Patches live in the same per-file list as edits, so the signature is
        an edit signature over the stringified hunks plus a fixed marker.
        Recording and checking must both use this helper; otherwise a
        recorded patch can never be recognised as a duplicate.
        """
        return cls._make_edit_signature(str(hunks), cls._PATCH_MARKER)

    def would_duplicate_file_create(self, file_path: str, content: str) -> bool:
        """Return True if ``content`` was already recorded for ``file_path``."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_write_signature(content)
        return sig in self._file_writes.get(norm_path, [])

    def would_duplicate_edit(self, file_path: str, old_string: str, new_string: str) -> bool:
        """Return True if this exact edit was already recorded for the file."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_edit_signature(old_string, new_string)
        return sig in self._files_edited.get(norm_path, [])

    def would_duplicate_patch(self, file_path: str, hunks: list[dict]) -> bool:
        """Return True if this exact patch was already recorded for the file."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_patch_signature(hunks)
        return sig in self._files_edited.get(norm_path, [])

    def would_duplicate_command(self, command: str) -> bool:
        """Return True if the whitespace-normalized command was recorded."""
        return self._normalize_command(command) in self._commands_run

    def would_duplicate_mkdir(self, dir_path: str) -> bool:
        """Return True if ``dir_path`` was already recorded as created."""
        return self._normalize_path(dir_path) in self._dirs_created

    def record_file_create(self, file_path: str, content: str) -> None:
        """Remember that ``content`` was written to ``file_path``."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_write_signature(content)
        self._file_writes.setdefault(norm_path, []).append(sig)

    def record_edit(self, file_path: str, old_string: str, new_string: str) -> None:
        """Remember that ``old_string -> new_string`` was applied to the file."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_edit_signature(old_string, new_string)
        self._files_edited.setdefault(norm_path, []).append(sig)

    def record_patch(self, file_path: str, hunks: list[dict]) -> None:
        """Remember that the structured patch ``hunks`` was applied to the file."""
        norm_path = self._normalize_path(file_path)
        sig = self._make_patch_signature(hunks)
        self._files_edited.setdefault(norm_path, []).append(sig)

    def record_command(self, command: str) -> None:
        """Remember a shell command; also note the directory of a ``mkdir``."""
        norm_cmd = self._normalize_command(command)
        self._commands_run.add(norm_cmd)

        # Everything after ``mkdir [-p]`` is taken as one directory argument.
        mkdir_match = re.match(r'mkdir\s+(-p\s+)?(.+)', norm_cmd)
        if mkdir_match:
            dir_path = mkdir_match.group(2).strip().strip('"\'')
            self._dirs_created.add(self._normalize_path(dir_path))

    def record_mkdir(self, dir_path: str) -> None:
        """Remember that ``dir_path`` was created."""
        self._dirs_created.add(self._normalize_path(dir_path))

    def check_tool_call(self, tool_name: str, arguments: dict) -> tuple[bool, str]:
        """Report whether a tool call repeats an already recorded action.

        Returns:
            ``(True, reason)`` for a duplicate write, edit or patch;
            ``(False, "")`` otherwise, including for every other tool.
        """
        if tool_name == "write":
            file_path = arguments.get("file_path", "")
            content = arguments.get("content", "")
            if self.would_duplicate_file_create(file_path, content):
                return True, f"Same file content already written: {file_path}"

        elif tool_name == "edit":
            file_path = arguments.get("file_path", "")
            old_string = arguments.get("old_string", "")
            new_string = arguments.get("new_string", "")
            if self.would_duplicate_edit(file_path, old_string, new_string):
                return True, f"Same edit already applied to: {file_path}"

        elif tool_name == "patch":
            file_path = arguments.get("file_path", "")
            hunks = arguments.get("hunks", [])
            if isinstance(hunks, list) and self.would_duplicate_patch(file_path, hunks):
                return True, f"Same patch already applied to: {file_path}"

        # Bash commands intentionally skip exact-command dedupe here.
        # Re-running the same shell probe after a filesystem change is often valid,
        # and higher-level loop detection is a safer backstop than blocking `ls`.
        return False, ""

    def record_tool_call(self, tool_name: str, arguments: dict) -> None:
        """Append the call to the action sequence and record its details.

        Every tool name enters the bounded action sequence. Write, edit and
        patch calls are recorded only when ``file_path`` is non-empty, and
        bash calls only when ``command`` is non-empty.
        """
        self._action_sequence.append(tool_name)
        if len(self._action_sequence) > self.MAX_SEQUENCE_LENGTH:
            self._action_sequence.pop(0)

        if tool_name == "write":
            file_path = arguments.get("file_path", "")
            if file_path:
                self.record_file_create(file_path, arguments.get("content", ""))

        elif tool_name == "edit":
            file_path = arguments.get("file_path", "")
            if file_path:
                self.record_edit(
                    file_path,
                    arguments.get("old_string", ""),
                    arguments.get("new_string", ""),
                )

        elif tool_name == "patch":
            file_path = arguments.get("file_path", "")
            if file_path:
                self.record_patch(file_path, arguments.get("hunks", []))

        elif tool_name == "bash":
            command = arguments.get("command", "")
            if command:
                self.record_command(command)

    def detect_loop(self) -> tuple[bool, str]:
        """Detect a tool-name pattern repeated back-to-back at the sequence tail.

        Returns:
            ``(True, description)`` for the shortest repeating tail pattern
            found, otherwise ``(False, "")``.
        """
        seq = self._action_sequence
        if len(seq) < self.LOOP_PATTERN_MIN * self.LOOP_REPEAT_THRESHOLD:
            return False, ""

        # Candidate patterns are capped at five tool names.
        for pattern_len in range(self.LOOP_PATTERN_MIN, min(6, len(seq) // 2 + 1)):
            pattern = seq[-pattern_len:]
            repeats = 1
            # Walk backwards in pattern-sized strides, counting consecutive
            # copies of the tail pattern until the first mismatch.
            for i in range(len(seq) - pattern_len * 2, -1, -pattern_len):
                if seq[i:i + pattern_len] == pattern:
                    repeats += 1
                else:
                    break

            if repeats >= self.LOOP_REPEAT_THRESHOLD:
                pattern_str = " → ".join(pattern)
                return True, f"Repeating pattern detected ({repeats}x): {pattern_str}"

        return False, ""

    @staticmethod
    def _normalize_response(response: str) -> str:
        """Reduce a response to a comparable form.

        The text is stripped, lower-cased and cut to 200 characters; then
        slash-led path-like tokens and digit runs are replaced by
        placeholders so responses differing only in those still compare equal.
        """
        normalized = response.strip().lower()[:200]
        normalized = re.sub(r'/[\w/.-]+', '<PATH>', normalized)
        normalized = re.sub(r'\d+', '<NUM>', normalized)
        return normalized

    def record_response(self, response: str) -> None:
        """Store the normalized response, keeping only the most recent ones."""
        self._response_history.append(self._normalize_response(response))
        if len(self._response_history) > self.MAX_RESPONSE_HISTORY:
            self._response_history.pop(0)

    def detect_text_loop(self, response: str) -> tuple[bool, str]:
        """Check ``response`` against recorded responses for repetition.

        Three checks run in order: exact match of the normalized text, a
        known filler phrase that also appears in earlier responses, and high
        word overlap with the most recent responses. Nothing is reported
        until at least two responses have been recorded.

        Returns:
            ``(True, reason)`` when repetition is detected, else ``(False, "")``.
        """
        if len(self._response_history) < 2:
            return False, ""

        normalized = self._normalize_response(response)
        exact_matches = sum(1 for r in self._response_history if r == normalized)
        if exact_matches >= 2:
            return True, f"Agent repeated the same response {exact_matches + 1} times"

        repetitive_phrases = [
            "apologies for any confusion",
            "let me proceed",
            "i will now use the",
        ]
        response_lower = response.lower()
        for phrase in repetitive_phrases:
            if phrase in response_lower:
                phrase_count = sum(1 for r in self._response_history if phrase in r)
                if phrase_count >= 2:
                    return True, f"Agent is stuck repeating '{phrase}'"

        current_words = set(normalized.split())
        similarity_matches = 0
        for prev in self._response_history[-3:]:
            prev_words = set(prev.split())
            # Very short responses are skipped; their overlap is too noisy.
            if len(current_words) > 10 and len(prev_words) > 10:
                overlap = len(current_words & prev_words)
                similarity = overlap / max(len(current_words), len(prev_words))
                if similarity > 0.85:
                    similarity_matches += 1

        if similarity_matches >= 2:
            return True, "Agent responses are highly repetitive"

        return False, ""

    def reset_response_history(self) -> None:
        """Clear response history between turns to prevent cross-turn false positives."""
        self._response_history.clear()
| 220 | |
| 221 | |
@dataclass
class ValidationResult:
    """Outcome of checking one tool call before it is executed."""

    # Whether the call may proceed.
    valid: bool
    # Short explanation of the finding; empty when there is nothing to report.
    reason: str = ""
    # Optional hint describing what to do instead.
    suggestion: str = ""
    # Severity label; this file uses "warning", "error" and "block".
    severity: str = "warning"
| 231 | |
class PreActionValidator:
    """Validate tool calls before execution to catch problematic actions.

    Checks are pattern-based heuristics on the raw arguments. Within each
    validator, findings that reject the call are evaluated before findings
    that only warn, so a warning can never mask a rejection.
    """

    # Shell patterns that cause the command to be rejected (severity "block").
    DANGEROUS_PATTERNS = [
        (r'rm\s+(-[rf]+\s+)?/', "Dangerous: removing from root directory"),
        (r'rm\s+-rf\s+~', "Dangerous: removing home directory"),
        (r'>\s*/dev/sd[a-z]', "Dangerous: writing directly to disk device"),
        (r'mkfs\.', "Dangerous: formatting filesystem"),
        (r'dd\s+.*of=/dev/', "Dangerous: dd to device"),
        (r'chmod\s+-R\s+777\s+/', "Dangerous: making everything world-writable"),
        (r':\(\)\s*\{\s*:\|:\s*&\s*\}\s*;', "Dangerous: fork bomb"),
    ]

    # Shell patterns that are allowed but reported (severity "warning").
    SUSPICIOUS_PATTERNS = [
        (r'rm\s+-rf\s+', "Warning: recursive force delete"),
        (r'>\s*/etc/', "Warning: overwriting system config"),
        (r'curl\s+.*\|\s*sh', "Warning: piping curl to shell"),
        (r'wget\s+.*\|\s*sh', "Warning: piping wget to shell"),
        (r'eval\s+', "Warning: using eval"),
        (r'sudo\s+', "Warning: using sudo"),
    ]

    # Programs that need an interactive terminal (severity "error").
    INTERACTIVE_PATTERNS = [
        (r'\bnano\b', "nano requires interactive terminal"),
        (r'\bvim?\b', "vim requires interactive terminal"),
        (r'\bemacs\b', "emacs requires interactive terminal"),
        (r'\bless\b', "less requires interactive terminal"),
        (r'\bmore\b', "more requires interactive terminal"),
        (r'\btop\b', "top requires interactive terminal"),
        (r'\bhtop\b', "htop requires interactive terminal"),
    ]

    # Path prefixes that the write tool must never target.
    SENSITIVE_PATH_PREFIXES = (
        '/etc/', '/usr/', '/bin/', '/sbin/', '/boot/', '/sys/', '/proc/',
    )

    def validate(self, tool_name: str, arguments: dict) -> ValidationResult:
        """Validate ``arguments`` for ``tool_name``.

        Tools without a dedicated validator are accepted unconditionally.
        """
        if tool_name == "bash":
            return self._validate_bash(arguments)
        if tool_name == "write":
            return self._validate_write(arguments)
        if tool_name == "edit":
            return self._validate_edit(arguments)
        if tool_name == "patch":
            return self._validate_patch(arguments)
        if tool_name == "read":
            return self._validate_read(arguments)
        if tool_name in ("glob", "grep"):
            return self._validate_search(tool_name, arguments)
        return ValidationResult(valid=True)

    def _validate_bash(self, arguments: dict) -> ValidationResult:
        """Validate a shell command.

        Order of checks: empty command, dangerous patterns (block),
        interactive programs (error), suspicious patterns (warning only).
        """
        command = arguments.get("command", "")

        if not command or not command.strip():
            return ValidationResult(
                valid=False,
                reason="Empty command",
                suggestion="Provide a valid command to execute",
                severity="error",
            )

        for pattern, reason in self.DANGEROUS_PATTERNS:
            if re.search(pattern, command):
                return ValidationResult(
                    valid=False,
                    reason=reason,
                    suggestion="This command is too dangerous to execute",
                    severity="block",
                )

        # Checked before the warning-only patterns: otherwise a command such
        # as `sudo vim file` would return early with a warning and never be
        # rejected for needing an interactive terminal.
        for pattern, reason in self.INTERACTIVE_PATTERNS:
            if re.search(pattern, command):
                return ValidationResult(
                    valid=False,
                    reason=reason,
                    suggestion=(
                        "Use non-interactive alternatives (cat, head, tail for viewing; "
                        "sed for editing)"
                    ),
                    severity="error",
                )

        for pattern, reason in self.SUSPICIOUS_PATTERNS:
            if re.search(pattern, command):
                return ValidationResult(valid=True, reason=reason, severity="warning")

        return ValidationResult(valid=True)

    def _check_file_path(self, file_path: object) -> ValidationResult:
        """Reject a missing or blank path, then apply the generic path checks.

        The value is coerced with ``str`` so a non-string path yields a
        validation result instead of an ``AttributeError``.
        """
        if not file_path or not str(file_path).strip():
            return ValidationResult(
                valid=False,
                reason="Empty file path",
                suggestion="Provide a valid file path",
                severity="error",
            )
        return self._validate_path(str(file_path))

    def _validate_write(self, arguments: dict) -> ValidationResult:
        """Validate a file write.

        System-directory targets are rejected before the empty-content
        warning is considered, so an empty write cannot slip past the block.
        """
        file_path = arguments.get("file_path", "")
        content = arguments.get("content", "")

        path_result = self._check_file_path(file_path)
        if not path_result.valid:
            return path_result

        # NOTE(review): this is a plain prefix test on the path as given; it
        # does not normalise `..` segments or repeated slashes.
        path_text = str(file_path)
        for sensitive in self.SENSITIVE_PATH_PREFIXES:
            if path_text.startswith(sensitive):
                return ValidationResult(
                    valid=False,
                    reason=f"Cannot write to system directory: {sensitive}",
                    suggestion="Write to a user directory instead",
                    severity="block",
                )

        if content is None or (isinstance(content, str) and not content.strip()):
            return ValidationResult(
                valid=True,
                reason="Writing empty content to file",
                severity="warning",
            )

        return ValidationResult(valid=True)

    def _validate_edit(self, arguments: dict) -> ValidationResult:
        """Validate a string-replacement edit.

        Rejects a blank or malformed path, a ``None`` old or new string, and
        an edit whose old and new strings are identical.
        """
        file_path = arguments.get("file_path", "")
        old_string = arguments.get("old_string", "")
        new_string = arguments.get("new_string", "")

        path_result = self._check_file_path(file_path)
        if not path_result.valid:
            return path_result

        if old_string is None:
            return ValidationResult(
                valid=False,
                reason="old_string is None",
                suggestion="Provide the text to replace (can be empty string for prepend)",
                severity="error",
            )

        if new_string is None:
            return ValidationResult(
                valid=False,
                reason="new_string is None",
                suggestion="Provide the replacement text (can be empty string for deletion)",
                severity="error",
            )

        if old_string == new_string:
            return ValidationResult(
                valid=False,
                reason="old_string and new_string are identical - no change would occur",
                suggestion="Provide different old and new strings",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_patch(self, arguments: dict) -> ValidationResult:
        """Validate a structured patch: a usable path plus a non-empty hunk list."""
        file_path = arguments.get("file_path", "")
        hunks = arguments.get("hunks", [])

        path_result = self._check_file_path(file_path)
        if not path_result.valid:
            return path_result

        if not isinstance(hunks, list) or not hunks:
            return ValidationResult(
                valid=False,
                reason="Patch hunks are missing",
                suggestion="Provide one or more structured patch hunks",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_read(self, arguments: dict) -> ValidationResult:
        """Validate a file read: only the path itself is checked."""
        return self._check_file_path(arguments.get("file_path", ""))

    def _validate_search(self, tool_name: str, arguments: dict) -> ValidationResult:
        """Validate a glob/grep call: the pattern must not be blank."""
        pattern = arguments.get("pattern", "")

        if not pattern or not pattern.strip():
            return ValidationResult(
                valid=False,
                reason=f"Empty {tool_name} pattern",
                suggestion="Provide a valid search pattern",
                severity="error",
            )

        return ValidationResult(valid=True)

    def _validate_path(self, file_path: str) -> ValidationResult:
        """Apply checks common to every path argument.

        Rejects paths containing a NUL byte and paths with excessive ``..``
        traversal. The traversal finding is invalid but labelled "warning".
        """
        if '\x00' in file_path:
            return ValidationResult(
                valid=False,
                reason="Path contains null byte",
                suggestion="Remove null bytes from path",
                severity="block",
            )

        if '/../../../' in file_path or file_path.count('..') > 5:
            return ValidationResult(
                valid=False,
                reason="Excessive path traversal",
                suggestion="Use a direct path instead",
                severity="warning",
            )

        return ValidationResult(valid=True)