Python · 12224 bytes Raw Blame History
1 """Shared file-mutation preview models and Rich render helpers."""
2
3 from __future__ import annotations
4
5 from dataclasses import dataclass
6 from pathlib import Path
7 from typing import Any
8
9 from rich import box
10 from rich.console import Group
11 from rich.panel import Panel
12 from rich.text import Text
13
14 from ..tools.fs_safety import (
15 StructuredPatchHunk,
16 coerce_structured_patch_payload,
17 make_structured_patch,
18 parse_unified_diff_patch,
19 )
20
# Tool names that are treated as changing file content.
FILE_MUTATION_TOOLS = {
    "write",
    "edit",
    "patch",
}

# Text appended (after "... ") to a rendered diff that was cut short.
DIFF_TRUNCATION_NOTICE = (
    "truncated for display; full result preserved in session"
)
23
24
25 @dataclass(slots=True)
26 class FileMutationPreview:
27 """Normalized preview data for file-mutation tools."""
28
29 tool_name: str
30 file_path: str
31 operation: str
32 structured_patch: list[StructuredPatchHunk]
33 old_text: str | None = None
34 new_text: str | None = None
35 added_lines: int = 0
36 removed_lines: int = 0
37 context_lines: int = 0
38
39 @property
40 def hunk_count(self) -> int:
41 return len(self.structured_patch)
42
43 def to_dict(self) -> dict[str, Any]:
44 """Serialize the preview for runtime/UI event payloads."""
45
46 return {
47 "tool_name": self.tool_name,
48 "file_path": self.file_path,
49 "operation": self.operation,
50 "structured_patch": [hunk.to_dict() for hunk in self.structured_patch],
51 "old_text": self.old_text,
52 "new_text": self.new_text,
53 "added_lines": self.added_lines,
54 "removed_lines": self.removed_lines,
55 "context_lines": self.context_lines,
56 }
57
58 @classmethod
59 def from_dict(cls, payload: dict[str, Any]) -> FileMutationPreview:
60 """Deserialize one preview payload."""
61
62 return cls(
63 tool_name=str(payload.get("tool_name", "")),
64 file_path=str(payload.get("file_path", "")),
65 operation=str(payload.get("operation", "update")),
66 structured_patch=_coerce_patch_hunks(payload.get("structured_patch")),
67 old_text=_coerce_optional_text(payload.get("old_text")),
68 new_text=_coerce_optional_text(payload.get("new_text")),
69 added_lines=int(payload.get("added_lines", 0)),
70 removed_lines=int(payload.get("removed_lines", 0)),
71 context_lines=int(payload.get("context_lines", 0)),
72 )
73
74
def is_file_mutation_tool(tool_name: str) -> bool:
    """Tell whether ``tool_name`` is one of the file-mutating tools."""
    is_mutator = tool_name in FILE_MUTATION_TOOLS
    return is_mutator
79
80
def build_file_mutation_preview(
    tool_name: str,
    *,
    tool_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> FileMutationPreview | None:
    """Build a normalized preview from tool args and/or result metadata.

    Values found in ``metadata`` are preferred over those in ``tool_args``.
    Hunks are resolved in this order: an explicit structured patch; for the
    ``patch`` tool a ``hunks`` entry, then a raw unified diff; for
    ``write``/``edit`` a diff computed from the old and new text.

    Returns:
        The preview, or ``None`` when ``tool_name`` is not a file-mutation
        tool or no hunks could be obtained.
    """

    if not is_file_mutation_tool(tool_name):
        return None

    args = tool_args or {}
    info = metadata or {}
    file_path = _extract_file_path(info) or _extract_file_path(args)

    structured_patch = (
        _coerce_patch_hunks(info.get("structured_patch"))
        or _coerce_patch_hunks(info.get("structuredPatch"))
        or _coerce_patch_hunks(args.get("structured_patch"))
        or _coerce_patch_hunks(args.get("structuredPatch"))
    )
    if not structured_patch and tool_name == "patch":
        # Patch-only fallbacks: pre-split hunks first, raw diff text second.
        structured_patch = (
            _coerce_patch_hunks(info.get("hunks"))
            or _coerce_patch_hunks(args.get("hunks"))
            or _coerce_raw_patch_hunks(info)
            or _coerce_raw_patch_hunks(args)
        )

    # A plain ``a or b`` would discard an explicit empty string from
    # metadata when args has nothing, turning "replace with nothing" into
    # "no text known" and suppressing the preview.
    old_text = _pick_preview_text(
        _extract_old_text(tool_name, info),
        _extract_old_text(tool_name, args),
    )
    new_text = _pick_preview_text(
        _extract_new_text(tool_name, info),
        _extract_new_text(tool_name, args),
    )

    if not structured_patch and tool_name in {"write", "edit"} and new_text is not None:
        structured_patch = make_structured_patch(old_text or "", new_text)

    if not structured_patch:
        return None

    operation = _determine_operation(tool_name, info, old_text)
    added_lines, removed_lines, context_lines = _count_patch_lines(structured_patch)
    return FileMutationPreview(
        tool_name=tool_name,
        file_path=file_path or "",
        operation=operation,
        structured_patch=structured_patch,
        old_text=old_text,
        new_text=new_text,
        added_lines=added_lines,
        removed_lines=removed_lines,
        context_lines=context_lines,
    )


def _pick_preview_text(primary: str | None, fallback: str | None) -> str | None:
    """Return ``primary`` if non-empty, else ``fallback``, else ``primary``.

    Matches ``primary or fallback`` in every case except one: an empty
    ``primary`` with a ``None`` fallback yields ``""`` rather than ``None``.
    """
    if primary:
        return primary
    return fallback if fallback is not None else primary
131
132
def build_file_mutation_preview_dict(
    tool_name: str,
    *,
    tool_args: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
    """Build a preview and return it in serialized (dict) form.

    Returns ``None`` whenever no preview could be built.
    """
    built = build_file_mutation_preview(tool_name, tool_args=tool_args, metadata=metadata)
    if built is None:
        return None
    return built.to_dict()
147
148
def render_file_mutation_preview(
    preview: FileMutationPreview | dict[str, Any],
    *,
    border_style: str = "cyan",
    title: str = "Diff",
    max_lines: int = 40,
    max_chars: int = 6_000,
) -> Group:
    """Render a preview as a Rich group: summary text above a diff panel.

    A dict ``preview`` is first deserialized with
    ``FileMutationPreview.from_dict``. ``max_lines`` and ``max_chars`` bound
    how much of the diff is placed inside the panel.
    """
    if isinstance(preview, dict):
        model = FileMutationPreview.from_dict(preview)
    else:
        model = preview

    summary = _render_preview_summary(model)
    diff_body = _render_preview_text(model, max_lines=max_lines, max_chars=max_chars)
    diff_panel = Panel(
        diff_body,
        title=title,
        border_style=border_style,
        box=box.SQUARE,
        expand=True,
    )
    return Group(summary, diff_panel)
178
179
180 def _coerce_optional_text(value: Any) -> str | None:
181 if value is None:
182 return None
183 return str(value)
184
185
def _coerce_patch_hunks(value: Any) -> list[StructuredPatchHunk]:
    """Convert a loosely-typed patch payload into a list of hunks.

    Items that are already ``StructuredPatchHunk`` instances are kept;
    dict items are converted with ``StructuredPatchHunk.from_dict``.
    Anything else, and any dict whose conversion raises ``TypeError`` or
    ``ValueError``, is dropped.
    """
    collected: list[StructuredPatchHunk] = []
    candidates = coerce_structured_patch_payload(value)
    if not candidates:
        return collected

    for candidate in candidates:
        if isinstance(candidate, StructuredPatchHunk):
            collected.append(candidate)
            continue
        if not isinstance(candidate, dict):
            continue
        try:
            parsed = StructuredPatchHunk.from_dict(candidate)
        except (TypeError, ValueError):
            # Best effort: a malformed hunk must not sink the whole preview.
            continue
        collected.append(parsed)
    return collected
201
202
203 def _coerce_raw_patch_hunks(payload: dict[str, Any]) -> list[StructuredPatchHunk]:
204 for key in ("patch", "diff", "patch_text"):
205 value = payload.get(key)
206 if not isinstance(value, str) or not value.strip():
207 continue
208 try:
209 return parse_unified_diff_patch(value)
210 except ValueError:
211 continue
212 return []
213
214
215 def _extract_file_path(payload: dict[str, Any]) -> str | None:
216 for key in ("file_path", "filePath", "path", "filename", "file"):
217 value = payload.get(key)
218 if value:
219 return str(value)
220 return None
221
222
223 def _extract_old_text(tool_name: str, payload: dict[str, Any]) -> str | None:
224 for key in ("original_file", "originalFile", "old_string", "oldString", "old"):
225 value = payload.get(key)
226 if value is not None:
227 return str(value)
228 if tool_name == "write":
229 return ""
230 return None
231
232
233 def _extract_new_text(tool_name: str, payload: dict[str, Any]) -> str | None:
234 if tool_name == "write":
235 for key in ("content", "contents", "text", "data"):
236 value = payload.get(key)
237 if value is not None:
238 return str(value)
239 for key in ("content", "new_string", "newString", "new", "replacement", "replace"):
240 value = payload.get(key)
241 if value is not None:
242 return str(value)
243 return None
244
245
246 def _determine_operation(tool_name: str, metadata: dict[str, Any], old_text: str | None) -> str:
247 if tool_name == "patch":
248 return "patch"
249 kind = metadata.get("kind") or metadata.get("type")
250 if kind in {"create", "update"}:
251 return str(kind)
252 if tool_name == "write":
253 return "update" if old_text else "create"
254 return "update"
255
256
257 def _count_patch_lines(hunks: list[StructuredPatchHunk]) -> tuple[int, int, int]:
258 added = 0
259 removed = 0
260 context = 0
261 for hunk in hunks:
262 for raw_line in hunk.lines:
263 prefix = raw_line[:1]
264 if prefix == "+":
265 added += 1
266 elif prefix == "-":
267 removed += 1
268 elif prefix == " ":
269 context += 1
270 return added, removed, context
271
272
def _render_preview_summary(preview: FileMutationPreview) -> Text:
    """Render the heading shown above the diff.

    Layout: ``<Action>(<file name>)``, then the full path on its own dim
    line when one is known, then a stats line with the added / removed /
    context line counts (each omitted when zero) and the hunk count.
    Unrecognised operations are presented like ``"update"``.
    """
    presentation = {
        "create": ("Create", "bold green"),
        "update": ("Update", "bold cyan"),
        "patch": ("Patch", "bold yellow"),
    }
    action_label, action_style = presentation.get(
        preview.operation, presentation["update"]
    )

    if preview.file_path:
        # Paths like "/" or "." have an empty final component; show the raw
        # path rather than an empty "()".
        filename = Path(preview.file_path).name or preview.file_path
    else:
        # No parentheses in the placeholder: the f-string below adds them,
        # and doubling up rendered as "Update((unknown file))".
        filename = "unknown file"

    text = Text()
    text.append(action_label, style=action_style)
    text.append(f"({filename})")
    if preview.file_path:
        text.append(f"\n{preview.file_path}", style="dim")

    stats: list[tuple[str, str]] = []
    if preview.added_lines:
        stats.append((f"+{preview.added_lines}", "green"))
    if preview.removed_lines:
        stats.append((f"-{preview.removed_lines}", "red"))
    if preview.context_lines:
        stats.append((f"{preview.context_lines} context", "dim"))
    hunk_suffix = "" if preview.hunk_count == 1 else "s"
    stats.append((f"{preview.hunk_count} hunk{hunk_suffix}", "dim"))

    text.append("\n")
    for index, (label, style) in enumerate(stats):
        if index:
            # Separator between stats, not before the first one.
            text.append(" ")
        text.append(label, style=style)
    return text
308
309
def _render_preview_text(
    preview: FileMutationPreview,
    *,
    max_lines: int,
    max_chars: int,
) -> Text:
    """Assemble the diff body, stopping at the line or character budget.

    Rendered rows are appended until adding the next one would exceed
    ``max_lines`` rows or ``max_chars`` characters; a dim truncation notice
    is added when rows were left out. An empty result is replaced by a dim
    "No diff available" placeholder.
    """
    body = Text()
    used_chars = 0
    was_cut = False

    rows = _iter_rendered_patch_lines(preview.structured_patch)
    for emitted, row in enumerate(rows):
        row_chars = len(row.plain)
        over_budget = emitted >= max_lines or used_chars + row_chars > max_chars
        if over_budget:
            was_cut = True
            break
        body.append_text(row)
        used_chars += row_chars

    if was_cut:
        # Keep the notice on a line of its own.
        if body.plain and not body.plain.endswith("\n"):
            body.append("\n")
        body.append(f"... {DIFF_TRUNCATION_NOTICE}", style="dim")

    if not body.plain:
        body.append("No diff available", style="dim")

    return body
340
341
def _iter_rendered_patch_lines(
    hunks: list[StructuredPatchHunk],
) -> list[Text]:
    """Render hunks as a list of newline-terminated Rich ``Text`` rows.

    Each hunk contributes one dim ``@@ -old,count +new,count @@`` header
    row followed by one row per hunk line. A row is laid out as
    ``<old no.> <new no.> <marker><content>``. The old number is shown for
    context and removed lines, the new number for context and added lines,
    and a blank of the same width otherwise. Added lines are green,
    removed lines red, context lines dim; lines without a recognised
    marker are shown whole and unstyled.
    """
    # Blank gutter cells and the blank marker are derived from the same
    # widths as their populated counterparts so that columns line up on
    # every row kind.
    number_width = 4
    blank_number = " " * number_width
    blank_marker = " " * len("+ ")

    # prefix -> (marker text, marker style, content style)
    marker_table = {
        "+": ("+ ", "green", "green"),
        "-": ("- ", "red", "red"),
        " ": (blank_marker, "dim", "dim"),
    }
    unmarked = (blank_marker, "dim", None)

    rendered: list[Text] = []
    for hunk in hunks:
        rendered.append(
            Text(
                f"@@ -{hunk.old_start},{hunk.old_lines} "
                f"+{hunk.new_start},{hunk.new_lines} @@\n",
                style="dim",
            )
        )
        old_line = hunk.old_start
        new_line = hunk.new_start
        for raw_line in hunk.lines:
            prefix = raw_line[:1] if raw_line[:1] in {" ", "+", "-"} else ""
            # Only strip the first character when it really is a marker.
            content = raw_line[1:] if prefix else raw_line

            old_label = blank_number
            new_label = blank_number
            if prefix in {" ", "-"}:
                old_label = f"{old_line:>{number_width}}"
                old_line += 1
            if prefix in {" ", "+"}:
                new_label = f"{new_line:>{number_width}}"
                new_line += 1

            display = Text()
            display.append(old_label, style="dim")
            display.append(" ")
            display.append(new_label, style="dim")
            display.append(" ")

            marker, marker_style, content_style = marker_table.get(prefix, unmarked)
            display.append(marker, style=marker_style)
            display.append(content, style=content_style)
            display.append("\n")
            rendered.append(display)
    return rendered
388 return rendered