Move rollback planning into runtime
- SHA
098f46739a5ca996b9e4bac80d03d467ac38b689- Parents
-
34effb0 - Tree
d8d389c
098f467
098f46739a5ca996b9e4bac80d03d467ac38b68934effb0
d8d389c| Status | File | + | - |
|---|---|---|---|
| M |
src/loader/agent/reasoning.py
|
11 | 329 |
| M |
src/loader/runtime/conversation.py
|
1 | 1 |
| M |
src/loader/runtime/events.py
|
1 | 2 |
| M |
src/loader/runtime/hooks.py
|
2 | 2 |
| A |
src/loader/runtime/rollback.py
|
331 | 0 |
| M |
src/loader/ui/adapter.py
|
1 | 2 |
| A |
tests/test_runtime_rollback.py
|
83 | 0 |
src/loader/agent/reasoning.pymodified@@ -11,9 +11,19 @@ enabled to improve the agent's decision-making: | ||
| 11 | 11 | |
| 12 | 12 | import re |
| 13 | 13 | from dataclasses import dataclass, field |
| 14 | -from enum import Enum, auto | |
| 14 | +from enum import Enum | |
| 15 | 15 | from typing import Any |
| 16 | 16 | |
| 17 | +from ..runtime.rollback import ( | |
| 18 | + RollbackAction, | |
| 19 | + RollbackPlan, | |
| 20 | + RollbackType, | |
| 21 | + create_rollback_plan_for_action, | |
| 22 | + execute_rollback, | |
| 23 | + get_undo_command, | |
| 24 | + is_destructive_tool, | |
| 25 | +) | |
| 26 | + | |
| 17 | 27 | |
| 18 | 28 | # === Query Classification === |
| 19 | 29 | |
@@ -904,331 +914,3 @@ def parse_completion_check(response: str, original_task: str) -> TaskCompletionC | ||
| 904 | 914 | ) |
| 905 | 915 | except json.JSONDecodeError: |
| 906 | 916 | return TaskCompletionCheck(original_task=original_task) |
| 907 | - | |
| 908 | - | |
| 909 | -# === Rollback Planning === | |
| 910 | - | |
| 911 | -class RollbackType(Enum): | |
| 912 | - """Types of rollback actions.""" | |
| 913 | - FILE_RESTORE = auto() # Restore file from backup | |
| 914 | - FILE_DELETE = auto() # Delete a created file | |
| 915 | - GIT_CHECKOUT = auto() # git checkout to restore | |
| 916 | - GIT_STASH_POP = auto() # git stash pop to restore | |
| 917 | - COMMAND_UNDO = auto() # Run an undo command | |
| 918 | - NO_ROLLBACK = auto() # Cannot be rolled back | |
| 919 | - | |
| 920 | - | |
| 921 | -@dataclass | |
| 922 | -class RollbackAction: | |
| 923 | - """A single rollback action.""" | |
| 924 | - type: RollbackType | |
| 925 | - description: str | |
| 926 | - file_path: str = "" | |
| 927 | - original_content: str = "" # For file restores | |
| 928 | - undo_command: str = "" # For command undos | |
| 929 | - executed: bool = False | |
| 930 | - | |
| 931 | - | |
| 932 | -@dataclass | |
| 933 | -class RollbackPlan: | |
| 934 | - """Plan for rolling back a series of actions.""" | |
| 935 | - actions: list[RollbackAction] = field(default_factory=list) | |
| 936 | - created_files: list[str] = field(default_factory=list) | |
| 937 | - modified_files: dict[str, str] = field(default_factory=dict) # path -> original content | |
| 938 | - git_stashed: bool = False | |
| 939 | - can_rollback: bool = True | |
| 940 | - | |
| 941 | - def add_file_creation(self, file_path: str) -> None: | |
| 942 | - """Track a file that was created (can be deleted to rollback).""" | |
| 943 | - self.created_files.append(file_path) | |
| 944 | - self.actions.append(RollbackAction( | |
| 945 | - type=RollbackType.FILE_DELETE, | |
| 946 | - description=f"Delete created file: {file_path}", | |
| 947 | - file_path=file_path, | |
| 948 | - )) | |
| 949 | - | |
| 950 | - def add_file_modification(self, file_path: str, original_content: str) -> None: | |
| 951 | - """Track a file modification (can restore original content).""" | |
| 952 | - if file_path not in self.modified_files: | |
| 953 | - self.modified_files[file_path] = original_content | |
| 954 | - self.actions.append(RollbackAction( | |
| 955 | - type=RollbackType.FILE_RESTORE, | |
| 956 | - description=f"Restore original: {file_path}", | |
| 957 | - file_path=file_path, | |
| 958 | - original_content=original_content, | |
| 959 | - )) | |
| 960 | - | |
| 961 | - def add_git_stash(self) -> None: | |
| 962 | - """Track that we stashed git changes.""" | |
| 963 | - if not self.git_stashed: | |
| 964 | - self.git_stashed = True | |
| 965 | - self.actions.append(RollbackAction( | |
| 966 | - type=RollbackType.GIT_STASH_POP, | |
| 967 | - description="Restore stashed changes: git stash pop", | |
| 968 | - )) | |
| 969 | - | |
| 970 | - def add_command_undo(self, description: str, undo_command: str) -> None: | |
| 971 | - """Track a command that can be undone.""" | |
| 972 | - self.actions.append(RollbackAction( | |
| 973 | - type=RollbackType.COMMAND_UNDO, | |
| 974 | - description=description, | |
| 975 | - undo_command=undo_command, | |
| 976 | - )) | |
| 977 | - | |
| 978 | - def add_no_rollback(self, description: str) -> None: | |
| 979 | - """Track an action that cannot be rolled back.""" | |
| 980 | - self.can_rollback = False | |
| 981 | - self.actions.append(RollbackAction( | |
| 982 | - type=RollbackType.NO_ROLLBACK, | |
| 983 | - description=f"Cannot undo: {description}", | |
| 984 | - )) | |
| 985 | - | |
| 986 | - def get_rollback_steps(self) -> list[str]: | |
| 987 | - """Get human-readable rollback steps (in reverse order).""" | |
| 988 | - steps = [] | |
| 989 | - for action in reversed(self.actions): | |
| 990 | - if action.type == RollbackType.FILE_DELETE: | |
| 991 | - steps.append(f"Delete: {action.file_path}") | |
| 992 | - elif action.type == RollbackType.FILE_RESTORE: | |
| 993 | - steps.append(f"Restore: {action.file_path}") | |
| 994 | - elif action.type == RollbackType.GIT_CHECKOUT: | |
| 995 | - steps.append(f"Git restore: {action.file_path}") | |
| 996 | - elif action.type == RollbackType.GIT_STASH_POP: | |
| 997 | - steps.append("Run: git stash pop") | |
| 998 | - elif action.type == RollbackType.COMMAND_UNDO: | |
| 999 | - steps.append(f"Run: {action.undo_command}") | |
| 1000 | - elif action.type == RollbackType.NO_ROLLBACK: | |
| 1001 | - steps.append(f"⚠ {action.description}") | |
| 1002 | - return steps | |
| 1003 | - | |
| 1004 | - def to_prompt(self) -> str: | |
| 1005 | - """Format rollback plan for display.""" | |
| 1006 | - if not self.actions: | |
| 1007 | - return "No rollback actions recorded." | |
| 1008 | - | |
| 1009 | - lines = ["Rollback plan:"] | |
| 1010 | - for i, step in enumerate(self.get_rollback_steps(), 1): | |
| 1011 | - lines.append(f" {i}. {step}") | |
| 1012 | - | |
| 1013 | - if not self.can_rollback: | |
| 1014 | - lines.append("\n⚠ Warning: Some actions cannot be undone!") | |
| 1015 | - | |
| 1016 | - return "\n".join(lines) | |
| 1017 | - | |
| 1018 | - | |
| 1019 | -def is_destructive_tool(tool_name: str, tool_args: dict) -> bool: | |
| 1020 | - """Check if a tool call is potentially destructive.""" | |
| 1021 | - if tool_name == "write": | |
| 1022 | - return True # Creating/overwriting files | |
| 1023 | - | |
| 1024 | - if tool_name == "edit": | |
| 1025 | - return True # Modifying files | |
| 1026 | - | |
| 1027 | - if tool_name == "patch": | |
| 1028 | - return True # Applying structured file edits | |
| 1029 | - | |
| 1030 | - if tool_name == "bash": | |
| 1031 | - command = tool_args.get("command", "").lower() | |
| 1032 | - destructive_patterns = [ | |
| 1033 | - "rm ", "rm -", "rmdir", # Delete | |
| 1034 | - "mv ", "rename", # Move/rename | |
| 1035 | - "> ", ">>", # Redirect/overwrite | |
| 1036 | - "chmod", "chown", # Permissions | |
| 1037 | - "git reset", "git checkout", # Git destructive | |
| 1038 | - "git clean", "git stash", | |
| 1039 | - "npm uninstall", "pip uninstall", # Package removal | |
| 1040 | - "drop ", "delete ", "truncate", # Database | |
| 1041 | - ] | |
| 1042 | - return any(p in command for p in destructive_patterns) | |
| 1043 | - | |
| 1044 | - return False | |
| 1045 | - | |
| 1046 | - | |
| 1047 | -def get_undo_command(command: str) -> str | None: | |
| 1048 | - """Get the undo command for a bash command, if possible.""" | |
| 1049 | - command_lower = command.lower().strip() | |
| 1050 | - | |
| 1051 | - # mkdir -> rmdir (only for empty dirs) | |
| 1052 | - if command_lower.startswith("mkdir "): | |
| 1053 | - dir_path = command.split("mkdir", 1)[1].strip().split()[0] | |
| 1054 | - return f"rmdir {dir_path}" | |
| 1055 | - | |
| 1056 | - # git stash -> git stash pop | |
| 1057 | - if "git stash" in command_lower and "pop" not in command_lower: | |
| 1058 | - return "git stash pop" | |
| 1059 | - | |
| 1060 | - # npm install -> npm uninstall (for specific packages) | |
| 1061 | - if "npm install " in command_lower or "npm i " in command_lower: | |
| 1062 | - # Extract package name | |
| 1063 | - parts = command.split() | |
| 1064 | - for i, part in enumerate(parts): | |
| 1065 | - if part in ("install", "i") and i + 1 < len(parts): | |
| 1066 | - pkg = parts[i + 1] | |
| 1067 | - if not pkg.startswith("-"): | |
| 1068 | - return f"npm uninstall {pkg}" | |
| 1069 | - | |
| 1070 | - # pip install -> pip uninstall | |
| 1071 | - if "pip install " in command_lower or "pip3 install " in command_lower: | |
| 1072 | - parts = command.split() | |
| 1073 | - for i, part in enumerate(parts): | |
| 1074 | - if part == "install" and i + 1 < len(parts): | |
| 1075 | - pkg = parts[i + 1] | |
| 1076 | - if not pkg.startswith("-"): | |
| 1077 | - return f"pip uninstall -y {pkg}" | |
| 1078 | - | |
| 1079 | - return None | |
| 1080 | - | |
| 1081 | - | |
| 1082 | -async def create_rollback_plan_for_action( | |
| 1083 | - tool_name: str, | |
| 1084 | - tool_args: dict, | |
| 1085 | - read_file_func, # async function to read file contents | |
| 1086 | -) -> RollbackAction | None: | |
| 1087 | - """Create a rollback action for a tool call. | |
| 1088 | - | |
| 1089 | - Args: | |
| 1090 | - tool_name: Name of the tool | |
| 1091 | - tool_args: Tool arguments | |
| 1092 | - read_file_func: Async function that reads file contents given a path | |
| 1093 | - | |
| 1094 | - Returns: | |
| 1095 | - RollbackAction or None if no rollback needed/possible | |
| 1096 | - """ | |
| 1097 | - import os | |
| 1098 | - | |
| 1099 | - if tool_name == "write": | |
| 1100 | - file_path = tool_args.get("file_path", "") | |
| 1101 | - if not file_path: | |
| 1102 | - return None | |
| 1103 | - | |
| 1104 | - # Check if file exists (we'd be overwriting) | |
| 1105 | - if os.path.exists(file_path): | |
| 1106 | - try: | |
| 1107 | - original = await read_file_func(file_path) | |
| 1108 | - return RollbackAction( | |
| 1109 | - type=RollbackType.FILE_RESTORE, | |
| 1110 | - description=f"Restore original: {file_path}", | |
| 1111 | - file_path=file_path, | |
| 1112 | - original_content=original, | |
| 1113 | - ) | |
| 1114 | - except Exception: | |
| 1115 | - return RollbackAction( | |
| 1116 | - type=RollbackType.NO_ROLLBACK, | |
| 1117 | - description=f"Could not backup: {file_path}", | |
| 1118 | - file_path=file_path, | |
| 1119 | - ) | |
| 1120 | - else: | |
| 1121 | - # New file - can delete to rollback | |
| 1122 | - return RollbackAction( | |
| 1123 | - type=RollbackType.FILE_DELETE, | |
| 1124 | - description=f"Delete created file: {file_path}", | |
| 1125 | - file_path=file_path, | |
| 1126 | - ) | |
| 1127 | - | |
| 1128 | - if tool_name == "edit": | |
| 1129 | - file_path = tool_args.get("file_path", "") | |
| 1130 | - if not file_path: | |
| 1131 | - return None | |
| 1132 | - | |
| 1133 | - try: | |
| 1134 | - original = await read_file_func(file_path) | |
| 1135 | - return RollbackAction( | |
| 1136 | - type=RollbackType.FILE_RESTORE, | |
| 1137 | - description=f"Restore original: {file_path}", | |
| 1138 | - file_path=file_path, | |
| 1139 | - original_content=original, | |
| 1140 | - ) | |
| 1141 | - except Exception: | |
| 1142 | - return RollbackAction( | |
| 1143 | - type=RollbackType.NO_ROLLBACK, | |
| 1144 | - description=f"Could not backup: {file_path}", | |
| 1145 | - file_path=file_path, | |
| 1146 | - ) | |
| 1147 | - | |
| 1148 | - if tool_name == "patch": | |
| 1149 | - file_path = tool_args.get("file_path", "") | |
| 1150 | - if not file_path: | |
| 1151 | - return None | |
| 1152 | - | |
| 1153 | - try: | |
| 1154 | - original = await read_file_func(file_path) | |
| 1155 | - return RollbackAction( | |
| 1156 | - type=RollbackType.FILE_RESTORE, | |
| 1157 | - description=f"Restore original: {file_path}", | |
| 1158 | - file_path=file_path, | |
| 1159 | - original_content=original, | |
| 1160 | - ) | |
| 1161 | - except Exception: | |
| 1162 | - return RollbackAction( | |
| 1163 | - type=RollbackType.NO_ROLLBACK, | |
| 1164 | - description=f"Could not backup: {file_path}", | |
| 1165 | - file_path=file_path, | |
| 1166 | - ) | |
| 1167 | - | |
| 1168 | - if tool_name == "bash": | |
| 1169 | - command = tool_args.get("command", "") | |
| 1170 | - undo = get_undo_command(command) | |
| 1171 | - if undo: | |
| 1172 | - return RollbackAction( | |
| 1173 | - type=RollbackType.COMMAND_UNDO, | |
| 1174 | - description=f"Undo with: {undo}", | |
| 1175 | - undo_command=undo, | |
| 1176 | - ) | |
| 1177 | - elif is_destructive_tool(tool_name, tool_args): | |
| 1178 | - return RollbackAction( | |
| 1179 | - type=RollbackType.NO_ROLLBACK, | |
| 1180 | - description=f"Cannot undo: {command[:50]}...", | |
| 1181 | - ) | |
| 1182 | - | |
| 1183 | - return None | |
| 1184 | - | |
| 1185 | - | |
| 1186 | -async def execute_rollback(plan: RollbackPlan, write_file_func, run_command_func) -> list[str]: | |
| 1187 | - """Execute a rollback plan. | |
| 1188 | - | |
| 1189 | - Args: | |
| 1190 | - plan: The rollback plan to execute | |
| 1191 | - write_file_func: Async function to write file contents | |
| 1192 | - run_command_func: Async function to run shell commands | |
| 1193 | - | |
| 1194 | - Returns: | |
| 1195 | - List of results/errors from rollback actions | |
| 1196 | - """ | |
| 1197 | - import os | |
| 1198 | - | |
| 1199 | - results = [] | |
| 1200 | - | |
| 1201 | - # Execute in reverse order | |
| 1202 | - for action in reversed(plan.actions): | |
| 1203 | - if action.executed: | |
| 1204 | - continue | |
| 1205 | - | |
| 1206 | - try: | |
| 1207 | - if action.type == RollbackType.FILE_DELETE: | |
| 1208 | - if os.path.exists(action.file_path): | |
| 1209 | - os.remove(action.file_path) | |
| 1210 | - results.append(f"✓ Deleted: {action.file_path}") | |
| 1211 | - action.executed = True | |
| 1212 | - | |
| 1213 | - elif action.type == RollbackType.FILE_RESTORE: | |
| 1214 | - await write_file_func(action.file_path, action.original_content) | |
| 1215 | - results.append(f"✓ Restored: {action.file_path}") | |
| 1216 | - action.executed = True | |
| 1217 | - | |
| 1218 | - elif action.type == RollbackType.COMMAND_UNDO: | |
| 1219 | - result = await run_command_func(action.undo_command) | |
| 1220 | - results.append(f"✓ Ran: {action.undo_command}") | |
| 1221 | - action.executed = True | |
| 1222 | - | |
| 1223 | - elif action.type == RollbackType.GIT_STASH_POP: | |
| 1224 | - result = await run_command_func("git stash pop") | |
| 1225 | - results.append("✓ Restored git stash") | |
| 1226 | - action.executed = True | |
| 1227 | - | |
| 1228 | - elif action.type == RollbackType.NO_ROLLBACK: | |
| 1229 | - results.append(f"⚠ Skipped (no rollback): {action.description}") | |
| 1230 | - | |
| 1231 | - except Exception as e: | |
| 1232 | - results.append(f"✗ Failed {action.description}: {e}") | |
| 1233 | - | |
| 1234 | - return results | |
src/loader/runtime/conversation.pymodified@@ -8,7 +8,6 @@ from pathlib import Path | ||
| 8 | 8 | from typing import Any |
| 9 | 9 | |
| 10 | 10 | from ..agent.reasoning import ( |
| 11 | - RollbackPlan, | |
| 12 | 11 | estimate_complexity, |
| 13 | 12 | get_token_budget, |
| 14 | 13 | ) |
@@ -23,6 +22,7 @@ from .finalization import TurnFinalizer, merge_usage | ||
| 23 | 22 | from .hooks import build_default_tool_hooks |
| 24 | 23 | from .phases import TurnPhase, TurnPhaseTracker |
| 25 | 24 | from .repair import ResponseRepairer |
| 25 | +from .rollback import RollbackPlan | |
| 26 | 26 | from .tool_batches import ToolBatchRunner |
| 27 | 27 | from .tracing import RuntimeTracer |
| 28 | 28 | from .workflow import ( |
src/loader/runtime/events.pymodified@@ -8,8 +8,6 @@ from typing import Any | ||
| 8 | 8 | from ..agent.reasoning import ( |
| 9 | 9 | ActionVerification, |
| 10 | 10 | ConfidenceAssessment, |
| 11 | - RollbackAction, | |
| 12 | - RollbackPlan, | |
| 13 | 11 | SelfCritique, |
| 14 | 12 | Subtask, |
| 15 | 13 | TaskCompletionCheck, |
@@ -17,6 +15,7 @@ from ..agent.reasoning import ( | ||
| 17 | 15 | ) |
| 18 | 16 | from ..llm.base import Message |
| 19 | 17 | from .dod import DefinitionOfDone |
| 18 | +from .rollback import RollbackAction, RollbackPlan | |
| 20 | 19 | from .tracing import RuntimeTraceEvent |
| 21 | 20 | |
| 22 | 21 | |
src/loader/runtime/hooks.pymodified@@ -7,13 +7,13 @@ from dataclasses import dataclass, field | ||
| 7 | 7 | from enum import StrEnum |
| 8 | 8 | from typing import Any, Protocol |
| 9 | 9 | |
| 10 | -from ..agent.reasoning import RollbackPlan, create_rollback_plan_for_action, is_destructive_tool | |
| 11 | 10 | from ..llm.base import ToolCall |
| 12 | -from ..runtime.safeguard_services import ActionTracker, PreActionValidator | |
| 13 | 11 | from ..tools.base import Tool, ToolRegistry |
| 14 | 12 | from ..tools.base import ToolResult as RegistryToolResult |
| 15 | 13 | from .memory import MemoryStore |
| 16 | 14 | from .permissions import PermissionOverride, PermissionPolicy |
| 15 | +from .rollback import RollbackPlan, create_rollback_plan_for_action, is_destructive_tool | |
| 16 | +from .safeguard_services import ActionTracker, PreActionValidator | |
| 17 | 17 | |
| 18 | 18 | |
| 19 | 19 | class HookEvent(StrEnum): |
src/loader/runtime/rollback.pyadded@@ -0,0 +1,331 @@ | ||
| 1 | +"""Runtime-owned rollback planning services.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | +import os | |
| 6 | +from dataclasses import dataclass, field | |
| 7 | +from enum import Enum, auto | |
| 8 | + | |
| 9 | + | |
| 10 | +class RollbackType(Enum): | |
| 11 | + """Types of rollback actions.""" | |
| 12 | + | |
| 13 | + FILE_RESTORE = auto() | |
| 14 | + FILE_DELETE = auto() | |
| 15 | + GIT_CHECKOUT = auto() | |
| 16 | + GIT_STASH_POP = auto() | |
| 17 | + COMMAND_UNDO = auto() | |
| 18 | + NO_ROLLBACK = auto() | |
| 19 | + | |
| 20 | + | |
| 21 | +@dataclass | |
| 22 | +class RollbackAction: | |
| 23 | + """A single rollback action.""" | |
| 24 | + | |
| 25 | + type: RollbackType | |
| 26 | + description: str | |
| 27 | + file_path: str = "" | |
| 28 | + original_content: str = "" | |
| 29 | + undo_command: str = "" | |
| 30 | + executed: bool = False | |
| 31 | + | |
| 32 | + | |
| 33 | +@dataclass | |
| 34 | +class RollbackPlan: | |
| 35 | + """Plan for rolling back a series of actions.""" | |
| 36 | + | |
| 37 | + actions: list[RollbackAction] = field(default_factory=list) | |
| 38 | + created_files: list[str] = field(default_factory=list) | |
| 39 | + modified_files: dict[str, str] = field(default_factory=dict) | |
| 40 | + git_stashed: bool = False | |
| 41 | + can_rollback: bool = True | |
| 42 | + | |
| 43 | + def add_file_creation(self, file_path: str) -> None: | |
| 44 | + """Track a file that was created (can be deleted to rollback).""" | |
| 45 | + | |
| 46 | + self.created_files.append(file_path) | |
| 47 | + self.actions.append( | |
| 48 | + RollbackAction( | |
| 49 | + type=RollbackType.FILE_DELETE, | |
| 50 | + description=f"Delete created file: {file_path}", | |
| 51 | + file_path=file_path, | |
| 52 | + ) | |
| 53 | + ) | |
| 54 | + | |
| 55 | + def add_file_modification(self, file_path: str, original_content: str) -> None: | |
| 56 | + """Track a file modification (can restore original content).""" | |
| 57 | + | |
| 58 | + if file_path not in self.modified_files: | |
| 59 | + self.modified_files[file_path] = original_content | |
| 60 | + self.actions.append( | |
| 61 | + RollbackAction( | |
| 62 | + type=RollbackType.FILE_RESTORE, | |
| 63 | + description=f"Restore original: {file_path}", | |
| 64 | + file_path=file_path, | |
| 65 | + original_content=original_content, | |
| 66 | + ) | |
| 67 | + ) | |
| 68 | + | |
| 69 | + def add_git_stash(self) -> None: | |
| 70 | + """Track that we stashed git changes.""" | |
| 71 | + | |
| 72 | + if not self.git_stashed: | |
| 73 | + self.git_stashed = True | |
| 74 | + self.actions.append( | |
| 75 | + RollbackAction( | |
| 76 | + type=RollbackType.GIT_STASH_POP, | |
| 77 | + description="Restore stashed changes: git stash pop", | |
| 78 | + ) | |
| 79 | + ) | |
| 80 | + | |
| 81 | + def add_command_undo(self, description: str, undo_command: str) -> None: | |
| 82 | + """Track a command that can be undone.""" | |
| 83 | + | |
| 84 | + self.actions.append( | |
| 85 | + RollbackAction( | |
| 86 | + type=RollbackType.COMMAND_UNDO, | |
| 87 | + description=description, | |
| 88 | + undo_command=undo_command, | |
| 89 | + ) | |
| 90 | + ) | |
| 91 | + | |
| 92 | + def add_no_rollback(self, description: str) -> None: | |
| 93 | + """Track an action that cannot be rolled back.""" | |
| 94 | + | |
| 95 | + self.can_rollback = False | |
| 96 | + self.actions.append( | |
| 97 | + RollbackAction( | |
| 98 | + type=RollbackType.NO_ROLLBACK, | |
| 99 | + description=f"Cannot undo: {description}", | |
| 100 | + ) | |
| 101 | + ) | |
| 102 | + | |
| 103 | + def get_rollback_steps(self) -> list[str]: | |
| 104 | + """Get human-readable rollback steps in reverse order.""" | |
| 105 | + | |
| 106 | + steps = [] | |
| 107 | + for action in reversed(self.actions): | |
| 108 | + if action.type == RollbackType.FILE_DELETE: | |
| 109 | + steps.append(f"Delete: {action.file_path}") | |
| 110 | + elif action.type == RollbackType.FILE_RESTORE: | |
| 111 | + steps.append(f"Restore: {action.file_path}") | |
| 112 | + elif action.type == RollbackType.GIT_CHECKOUT: | |
| 113 | + steps.append(f"Git restore: {action.file_path}") | |
| 114 | + elif action.type == RollbackType.GIT_STASH_POP: | |
| 115 | + steps.append("Run: git stash pop") | |
| 116 | + elif action.type == RollbackType.COMMAND_UNDO: | |
| 117 | + steps.append(f"Run: {action.undo_command}") | |
| 118 | + elif action.type == RollbackType.NO_ROLLBACK: | |
| 119 | + steps.append(f"⚠ {action.description}") | |
| 120 | + return steps | |
| 121 | + | |
| 122 | + def to_prompt(self) -> str: | |
| 123 | + """Format rollback plan for display.""" | |
| 124 | + | |
| 125 | + if not self.actions: | |
| 126 | + return "No rollback actions recorded." | |
| 127 | + | |
| 128 | + lines = ["Rollback plan:"] | |
| 129 | + for index, step in enumerate(self.get_rollback_steps(), 1): | |
| 130 | + lines.append(f" {index}. {step}") | |
| 131 | + | |
| 132 | + if not self.can_rollback: | |
| 133 | + lines.append("\n⚠ Warning: Some actions cannot be undone!") | |
| 134 | + | |
| 135 | + return "\n".join(lines) | |
| 136 | + | |
| 137 | + | |
| 138 | +def is_destructive_tool(tool_name: str, tool_args: dict) -> bool: | |
| 139 | + """Check if a tool call is potentially destructive.""" | |
| 140 | + | |
| 141 | + if tool_name in {"write", "edit", "patch"}: | |
| 142 | + return True | |
| 143 | + | |
| 144 | + if tool_name == "bash": | |
| 145 | + command = tool_args.get("command", "").lower() | |
| 146 | + destructive_patterns = [ | |
| 147 | + "rm ", | |
| 148 | + "rm -", | |
| 149 | + "rmdir", | |
| 150 | + "mv ", | |
| 151 | + "rename", | |
| 152 | + "> ", | |
| 153 | + ">>", | |
| 154 | + "chmod", | |
| 155 | + "chown", | |
| 156 | + "git reset", | |
| 157 | + "git checkout", | |
| 158 | + "git clean", | |
| 159 | + "git stash", | |
| 160 | + "npm uninstall", | |
| 161 | + "pip uninstall", | |
| 162 | + "drop ", | |
| 163 | + "delete ", | |
| 164 | + "truncate", | |
| 165 | + ] | |
| 166 | + return any(pattern in command for pattern in destructive_patterns) | |
| 167 | + | |
| 168 | + return False | |
| 169 | + | |
| 170 | + | |
| 171 | +def get_undo_command(command: str) -> str | None: | |
| 172 | + """Get the undo command for a bash command, if possible.""" | |
| 173 | + | |
| 174 | + command_lower = command.lower().strip() | |
| 175 | + | |
| 176 | + if command_lower.startswith("mkdir "): | |
| 177 | + dir_path = command.split("mkdir", 1)[1].strip().split()[0] | |
| 178 | + return f"rmdir {dir_path}" | |
| 179 | + | |
| 180 | + if "git stash" in command_lower and "pop" not in command_lower: | |
| 181 | + return "git stash pop" | |
| 182 | + | |
| 183 | + if "npm install " in command_lower or "npm i " in command_lower: | |
| 184 | + parts = command.split() | |
| 185 | + for index, part in enumerate(parts): | |
| 186 | + if part in ("install", "i") and index + 1 < len(parts): | |
| 187 | + package = parts[index + 1] | |
| 188 | + if not package.startswith("-"): | |
| 189 | + return f"npm uninstall {package}" | |
| 190 | + | |
| 191 | + if "pip install " in command_lower or "pip3 install " in command_lower: | |
| 192 | + parts = command.split() | |
| 193 | + for index, part in enumerate(parts): | |
| 194 | + if part == "install" and index + 1 < len(parts): | |
| 195 | + package = parts[index + 1] | |
| 196 | + if not package.startswith("-"): | |
| 197 | + return f"pip uninstall -y {package}" | |
| 198 | + | |
| 199 | + return None | |
| 200 | + | |
| 201 | + | |
| 202 | +async def create_rollback_plan_for_action( | |
| 203 | + tool_name: str, | |
| 204 | + tool_args: dict, | |
| 205 | + read_file_func, | |
| 206 | +) -> RollbackAction | None: | |
| 207 | + """Create a rollback action for a tool call.""" | |
| 208 | + | |
| 209 | + if tool_name == "write": | |
| 210 | + file_path = tool_args.get("file_path", "") | |
| 211 | + if not file_path: | |
| 212 | + return None | |
| 213 | + | |
| 214 | + if os.path.exists(file_path): | |
| 215 | + try: | |
| 216 | + original = await read_file_func(file_path) | |
| 217 | + return RollbackAction( | |
| 218 | + type=RollbackType.FILE_RESTORE, | |
| 219 | + description=f"Restore original: {file_path}", | |
| 220 | + file_path=file_path, | |
| 221 | + original_content=original, | |
| 222 | + ) | |
| 223 | + except Exception: | |
| 224 | + return RollbackAction( | |
| 225 | + type=RollbackType.NO_ROLLBACK, | |
| 226 | + description=f"Could not backup: {file_path}", | |
| 227 | + file_path=file_path, | |
| 228 | + ) | |
| 229 | + | |
| 230 | + return RollbackAction( | |
| 231 | + type=RollbackType.FILE_DELETE, | |
| 232 | + description=f"Delete created file: {file_path}", | |
| 233 | + file_path=file_path, | |
| 234 | + ) | |
| 235 | + | |
| 236 | + if tool_name == "edit": | |
| 237 | + file_path = tool_args.get("file_path", "") | |
| 238 | + if not file_path: | |
| 239 | + return None | |
| 240 | + | |
| 241 | + try: | |
| 242 | + original = await read_file_func(file_path) | |
| 243 | + return RollbackAction( | |
| 244 | + type=RollbackType.FILE_RESTORE, | |
| 245 | + description=f"Restore original: {file_path}", | |
| 246 | + file_path=file_path, | |
| 247 | + original_content=original, | |
| 248 | + ) | |
| 249 | + except Exception: | |
| 250 | + return RollbackAction( | |
| 251 | + type=RollbackType.NO_ROLLBACK, | |
| 252 | + description=f"Could not backup: {file_path}", | |
| 253 | + file_path=file_path, | |
| 254 | + ) | |
| 255 | + | |
| 256 | + if tool_name == "patch": | |
| 257 | + file_path = tool_args.get("file_path", "") | |
| 258 | + if not file_path: | |
| 259 | + return None | |
| 260 | + | |
| 261 | + try: | |
| 262 | + original = await read_file_func(file_path) | |
| 263 | + return RollbackAction( | |
| 264 | + type=RollbackType.FILE_RESTORE, | |
| 265 | + description=f"Restore original: {file_path}", | |
| 266 | + file_path=file_path, | |
| 267 | + original_content=original, | |
| 268 | + ) | |
| 269 | + except Exception: | |
| 270 | + return RollbackAction( | |
| 271 | + type=RollbackType.NO_ROLLBACK, | |
| 272 | + description=f"Could not backup: {file_path}", | |
| 273 | + file_path=file_path, | |
| 274 | + ) | |
| 275 | + | |
| 276 | + if tool_name == "bash": | |
| 277 | + command = tool_args.get("command", "") | |
| 278 | + undo = get_undo_command(command) | |
| 279 | + if undo: | |
| 280 | + return RollbackAction( | |
| 281 | + type=RollbackType.COMMAND_UNDO, | |
| 282 | + description=f"Undo with: {undo}", | |
| 283 | + undo_command=undo, | |
| 284 | + ) | |
| 285 | + if is_destructive_tool(tool_name, tool_args): | |
| 286 | + return RollbackAction( | |
| 287 | + type=RollbackType.NO_ROLLBACK, | |
| 288 | + description=f"Cannot undo: {command[:50]}...", | |
| 289 | + ) | |
| 290 | + | |
| 291 | + return None | |
| 292 | + | |
| 293 | + | |
| 294 | +async def execute_rollback(plan: RollbackPlan, write_file_func, run_command_func) -> list[str]: | |
| 295 | + """Execute a rollback plan.""" | |
| 296 | + | |
| 297 | + results = [] | |
| 298 | + | |
| 299 | + for action in reversed(plan.actions): | |
| 300 | + if action.executed: | |
| 301 | + continue | |
| 302 | + | |
| 303 | + try: | |
| 304 | + if action.type == RollbackType.FILE_DELETE: | |
| 305 | + if os.path.exists(action.file_path): | |
| 306 | + os.remove(action.file_path) | |
| 307 | + results.append(f"✓ Deleted: {action.file_path}") | |
| 308 | + action.executed = True | |
| 309 | + | |
| 310 | + elif action.type == RollbackType.FILE_RESTORE: | |
| 311 | + await write_file_func(action.file_path, action.original_content) | |
| 312 | + results.append(f"✓ Restored: {action.file_path}") | |
| 313 | + action.executed = True | |
| 314 | + | |
| 315 | + elif action.type == RollbackType.COMMAND_UNDO: | |
| 316 | + await run_command_func(action.undo_command) | |
| 317 | + results.append(f"✓ Ran: {action.undo_command}") | |
| 318 | + action.executed = True | |
| 319 | + | |
| 320 | + elif action.type == RollbackType.GIT_STASH_POP: | |
| 321 | + await run_command_func("git stash pop") | |
| 322 | + results.append("✓ Restored git stash") | |
| 323 | + action.executed = True | |
| 324 | + | |
| 325 | + elif action.type == RollbackType.NO_ROLLBACK: | |
| 326 | + results.append(f"⚠ Skipped (no rollback): {action.description}") | |
| 327 | + | |
| 328 | + except Exception as exc: | |
| 329 | + results.append(f"✗ Failed {action.description}: {exc}") | |
| 330 | + | |
| 331 | + return results | |
src/loader/ui/adapter.pymodified@@ -11,13 +11,12 @@ if TYPE_CHECKING: | ||
| 11 | 11 | from ..agent.reasoning import ( |
| 12 | 12 | ActionVerification, |
| 13 | 13 | ConfidenceAssessment, |
| 14 | - RollbackAction, | |
| 15 | - RollbackPlan, | |
| 16 | 14 | SelfCritique, |
| 17 | 15 | Subtask, |
| 18 | 16 | TaskCompletionCheck, |
| 19 | 17 | TaskDecomposition, |
| 20 | 18 | ) |
| 19 | + from ..runtime.rollback import RollbackAction, RollbackPlan | |
| 21 | 20 | |
| 22 | 21 | |
| 23 | 22 | # Custom Textual messages for TUI updates |
tests/test_runtime_rollback.pyadded@@ -0,0 +1,83 @@ | ||
| 1 | +"""Tests for runtime-owned rollback planning.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | +import pytest | |
| 6 | + | |
| 7 | +from loader.runtime.rollback import ( | |
| 8 | + RollbackPlan, | |
| 9 | + RollbackType, | |
| 10 | + create_rollback_plan_for_action, | |
| 11 | + execute_rollback, | |
| 12 | + get_undo_command, | |
| 13 | + is_destructive_tool, | |
| 14 | +) | |
| 15 | + | |
| 16 | + | |
| 17 | +def test_rollback_plan_formats_reverse_steps() -> None: | |
| 18 | + plan = RollbackPlan() | |
| 19 | + plan.add_file_creation("new.txt") | |
| 20 | + plan.add_file_modification("config.json", '{"a":1}') | |
| 21 | + | |
| 22 | + assert plan.get_rollback_steps() == [ | |
| 23 | + "Restore: config.json", | |
| 24 | + "Delete: new.txt", | |
| 25 | + ] | |
| 26 | + assert plan.to_prompt() == ( | |
| 27 | + "Rollback plan:\n" | |
| 28 | + " 1. Restore: config.json\n" | |
| 29 | + " 2. Delete: new.txt" | |
| 30 | + ) | |
| 31 | + | |
| 32 | + | |
| 33 | +def test_get_undo_command_handles_common_install_commands() -> None: | |
| 34 | + assert get_undo_command("mkdir docs") == "rmdir docs" | |
| 35 | + assert get_undo_command("npm install react") == "npm uninstall react" | |
| 36 | + assert get_undo_command("pip install pytest") == "pip uninstall -y pytest" | |
| 37 | + | |
| 38 | + | |
| 39 | +def test_is_destructive_tool_covers_patch_and_bash_patterns() -> None: | |
| 40 | + assert is_destructive_tool("patch", {"file_path": "notes.txt"}) is True | |
| 41 | + assert is_destructive_tool("bash", {"command": "git checkout -- README.md"}) is True | |
| 42 | + assert is_destructive_tool("bash", {"command": "ls -la"}) is False | |
| 43 | + | |
| 44 | + | |
| 45 | +@pytest.mark.asyncio | |
| 46 | +async def test_create_rollback_plan_for_new_write_returns_delete_action(tmp_path) -> None: | |
| 47 | + target = tmp_path / "notes.txt" | |
| 48 | + | |
| 49 | + async def read_file(_path: str) -> str: | |
| 50 | + raise AssertionError("new files should not be read for rollback") | |
| 51 | + | |
| 52 | + action = await create_rollback_plan_for_action( | |
| 53 | + "write", | |
| 54 | + {"file_path": str(target), "content": "alpha\n"}, | |
| 55 | + read_file, | |
| 56 | + ) | |
| 57 | + | |
| 58 | + assert action is not None | |
| 59 | + assert action.type == RollbackType.FILE_DELETE | |
| 60 | + assert action.file_path == str(target) | |
| 61 | + | |
| 62 | + | |
| 63 | +@pytest.mark.asyncio | |
| 64 | +async def test_execute_rollback_restores_file_and_marks_action(tmp_path) -> None: | |
| 65 | + target = tmp_path / "notes.txt" | |
| 66 | + plan = RollbackPlan() | |
| 67 | + plan.add_file_modification(str(target), "restored\n") | |
| 68 | + | |
| 69 | + writes: list[tuple[str, str]] = [] | |
| 70 | + commands: list[str] = [] | |
| 71 | + | |
| 72 | + async def write_file(path: str, content: str) -> None: | |
| 73 | + writes.append((path, content)) | |
| 74 | + | |
| 75 | + async def run_command(command: str) -> None: | |
| 76 | + commands.append(command) | |
| 77 | + | |
| 78 | + results = await execute_rollback(plan, write_file, run_command) | |
| 79 | + | |
| 80 | + assert results == [f"✓ Restored: {target}"] | |
| 81 | + assert writes == [(str(target), "restored\n")] | |
| 82 | + assert commands == [] | |
| 83 | + assert plan.actions[0].executed is True | |