Python · 13032 bytes Raw Blame History
1 """
PTY management for shell interactive testing.

Provides a high-level interface to spawn a shell in a pseudo-terminal
and interact with it programmatically.
6 """
7
8 import os
9 import re
10 import pexpect
11 from pathlib import Path
12 from typing import Optional, Union
13
14 from utils.keys import KEYS, get_key
15
16
class ShellPTY:
    """
    Manages a shell process running in a pseudo-terminal.

    This class provides methods to send input, receive output, and verify
    behavior for interactive testing. It wraps a ``pexpect.spawn`` child.
    Methods that talk to the shell raise ``RuntimeError("shell not started")``
    when ``start()`` has not been called or ``stop()`` has already run.
    """

    # Default prompt pattern - matches the shell prompt.
    # No anchors for reliability with pexpect's buffering.
    DEFAULT_PROMPT_PATTERN = r'> '

    # Unique marker for reliable command output detection.
    END_MARKER = "___BENSCH_CMD_END___"

    # ANSI CSI escape sequences (colors, cursor moves, erase-line) that a line
    # editor may interleave with the text it echoes back.
    _ANSI_CSI_PATTERN = re.compile(r"\x1b\[[0-9;?]*[A-Za-z]")

    def __init__(
        self,
        shell_path: Optional[str] = None,
        timeout: float = 5.0,
        prompt_pattern: Optional[str] = None,
        env: Optional[dict] = None,
        profile: Optional[dict] = None,
    ):
        """
        Initialize the PTY wrapper. No process is spawned until ``start()``.

        Args:
            shell_path: Path to shell binary (required before ``start()``)
            timeout: Default timeout for expect operations (seconds)
            prompt_pattern: Regex pattern to match the shell prompt; falls back
                to ``profile["prompt_pattern"]``, then DEFAULT_PROMPT_PATTERN
            env: Additional environment variables, applied last by ``start()``
            profile: Shell profile dict (from profile.py)
        """
        self.shell_path = shell_path
        self.timeout = timeout
        self.profile = profile or {}
        self.prompt_pattern = (
            prompt_pattern
            or self.profile.get("prompt_pattern")
            or self.DEFAULT_PROMPT_PATTERN
        )
        self.custom_env = env or {}
        self.child: Optional[pexpect.spawn] = None
        # Output captured by the most recent wait_for_prompt() call.
        self._output_buffer: str = ""

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_child(self) -> "pexpect.spawn":
        """Return the running child, or raise RuntimeError if not started."""
        if self.child is None:
            raise RuntimeError("shell not started")
        return self.child

    def _resolve_timeout(self, timeout: Optional[float]) -> float:
        """
        Return *timeout*, or the instance default when it is None.

        An explicit 0 is kept as-is (a non-blocking check). The previous
        ``timeout or self.timeout`` silently replaced it with the default.
        """
        return self.timeout if timeout is None else timeout

    @staticmethod
    def _apply_env(env: dict, overrides: Optional[dict]) -> None:
        """Copy *overrides* (may be None) into *env*, converting values to str."""
        for key, value in (overrides or {}).items():
            env[key] = str(value)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self, rc_file: Optional[str] = None) -> None:
        """
        Start the shell in a PTY and wait for its first prompt.

        Args:
            rc_file: Path to rc file, or None to use default, or "/dev/null"
                for no rc. NOTE(review): the path itself is never passed to
                the shell; any non-None value only applies the profile's
                ``rc_disable`` environment variables.

        Raises:
            ValueError: If no ``shell_path`` was configured.
            pexpect.TIMEOUT / pexpect.EOF: If the initial prompt never appears.
        """
        if not self.shell_path:
            raise ValueError("shell_path must be set before starting the shell")
        if self.child is not None:
            # Restarting without stop() would orphan the previous process/pty.
            self.stop()

        env = os.environ.copy()
        env["TERM"] = "xterm-256color"
        env["LANG"] = "en_US.UTF-8"
        env["LC_ALL"] = "en_US.UTF-8"

        if self.profile:
            # Profile-driven environment (test mode, history disable, rc disable).
            # `or {}` tolerates sections that are present but set to None.
            self._apply_env(env, self.profile.get("test_mode_env"))
            self._apply_env(env, (self.profile.get("history_disable") or {}).get("env"))
            if rc_file is not None:
                self._apply_env(env, (self.profile.get("rc_disable") or {}).get("env"))
        else:
            # Fallback: disable history
            env["HISTFILE"] = "/dev/null"

        # Caller-supplied variables are applied last so they win.
        self._apply_env(env, self.custom_env)

        self.child = pexpect.spawn(
            self.shell_path,
            encoding="utf-8",
            codec_errors="replace",  # Handle raw ANSI/highlight bytes without crashing
            timeout=self.timeout,
            env=env,
            dimensions=(24, 80),  # Standard terminal size
            echo=False,  # Disable tty echo; a line editor may still echo input itself
        )

        # Wait for initial prompt
        self.wait_for_prompt()

    def stop(self) -> int:
        """
        Stop the shell and return its exit code.

        Sends ``exit`` and waits up to one second for the pty to reach EOF.
        ``close(force=True)`` then closes the pty, signals the shell if it is
        still alive (escalating to SIGKILL), and reaps it.

        Returns:
            The shell's exit status. Returns 0 if the shell ended from a signal
            (no exit status exists), and -1 if the shell was never started.
        """
        if self.child is None:
            return -1

        child = self.child
        # Detach first so the wrapper reports "not running" even if cleanup fails.
        self.child = None

        try:
            # Try graceful exit first.
            child.sendline("exit")
            child.expect(pexpect.EOF, timeout=1)
        except (pexpect.TIMEOUT, pexpect.EOF, OSError):
            # The shell ignored `exit`, or its pty is already gone (e.g. the
            # test ended the session itself). close() below escalates.
            pass
        finally:
            try:
                # close() closes the pty descriptor exactly once and reaps the
                # child, which populates exitstatus/signalstatus. Do not also
                # os.close(child_fd) by hand: that double-closes a descriptor
                # number that may already have been reused by another file.
                child.close(force=True)
            except (pexpect.ExceptionPexpect, OSError):
                # Best effort: the child could not be terminated or reaped,
                # and there is nothing further this wrapper can do about it.
                pass

        return child.exitstatus if child.exitstatus is not None else 0

    # ------------------------------------------------------------------
    # Input
    # ------------------------------------------------------------------

    def clear_buffer(self, max_duration: float = 1.0) -> None:
        """
        Discard pending output so the next expect starts from fresh data.

        First drops text pexpect has already read but not yet matched. Then it
        reads from the pty until a 10 ms read times out, EOF is reached, or
        *max_duration* seconds pass. The time cap stops a shell that prints
        continuously from hanging the caller.

        Args:
            max_duration: Upper bound, in seconds, on time spent draining.
        """
        if self.child is None:
            return
        import time

        # read_nonblocking() returns freshly read pty data and leaves pexpect's
        # own buffer untouched, so that buffer is reset explicitly.
        self.child.buffer = self.child.string_type()
        deadline = time.monotonic() + max_duration
        try:
            while time.monotonic() < deadline:
                self.child.read_nonblocking(size=1024, timeout=0.01)
        except (pexpect.TIMEOUT, pexpect.EOF):
            pass

    def send(self, text: str) -> None:
        """
        Send text without newline.

        Args:
            text: Text to send to the shell
        """
        self._require_child().send(text)

    def send_line(self, text: str) -> None:
        """
        Send text followed by Enter.

        Args:
            text: Text to send to the shell
        """
        self._require_child().sendline(text)

    def send_key(self, key_name: str) -> None:
        """
        Send a special key by name.

        Args:
            key_name: Name of the key (e.g., "Up", "C-a", "Enter"), resolved
                to its byte sequence by ``utils.keys.get_key``
        """
        self._require_child().send(get_key(key_name))

    def send_keys(self, *key_names: str) -> None:
        """
        Send multiple keys in sequence.

        Args:
            *key_names: Names of keys to send
        """
        for key in key_names:
            self.send_key(key)

    # ------------------------------------------------------------------
    # Output / matching
    # ------------------------------------------------------------------

    def wait_for_prompt(self, timeout: Optional[float] = None) -> str:
        """
        Wait for the shell prompt to appear.

        Args:
            timeout: Timeout in seconds (instance default if None; 0 is honored)

        Returns:
            Output received before the prompt
        """
        child = self._require_child()
        child.expect(self.prompt_pattern, timeout=self._resolve_timeout(timeout))
        output = child.before or ""
        self._output_buffer = output
        return output

    def expect(
        self,
        pattern: Union[str, list],
        timeout: Optional[float] = None
    ) -> int:
        """
        Wait for a pattern in the output.

        Args:
            pattern: Regex pattern or list of patterns
            timeout: Timeout in seconds (instance default if None; 0 is honored)

        Returns:
            Index of matched pattern (if list) or 0

        Raises:
            pexpect.TIMEOUT: If pattern not found within timeout
            pexpect.EOF: If process terminates
        """
        return self._require_child().expect(
            pattern, timeout=self._resolve_timeout(timeout)
        )

    def expect_exact(self, text: str, timeout: Optional[float] = None) -> None:
        """
        Wait for exact text in output.

        Args:
            text: Exact text to find
            timeout: Timeout in seconds (instance default if None; 0 is honored)

        Raises:
            pexpect.TIMEOUT: If text not found within timeout
        """
        self._require_child().expect_exact(
            text, timeout=self._resolve_timeout(timeout)
        )

    def get_output(self) -> str:
        """
        Get output from the last expect operation.

        Returns:
            Output that appeared before the matched pattern ("" if not started)
        """
        if self.child is None:
            return ""
        return self.child.before or ""

    def get_clean_output(self) -> str:
        """
        Get cleaned output, filtering out terminal redraw noise.

        Returns:
            Output before the last match, minus prompt-redraw lines and short
            partial-prompt lines ("" if not started). Line endings are kept.
        """
        if self.child is None:
            return ""

        clean_lines = []
        for line in (self.child.before or "").split('\n'):
            # Skip lines that are mostly prompt redraws.
            if ':: ~' in line and line.count('@') > 1:
                continue
            # Skip short lines ending in a prompt ("> "). Only the trailing
            # '\r' is removed before the check; strip() would also remove the
            # space, so the "> " suffix could never match.
            if line.rstrip('\r').endswith('> ') and len(line.strip()) < 10:
                continue
            # Everything else, including blank lines, is kept.
            clean_lines.append(line)

        return '\n'.join(clean_lines)

    def get_match(self) -> str:
        """
        Get the text that matched the last expect pattern.

        Returns:
            Matched text ("" if not started)
        """
        if self.child is None:
            return ""
        return self.child.after or ""

    def _is_echoed_marker_command(self, before: str) -> bool:
        """
        Tell whether a marker match was the shell echoing ``echo <marker>``.

        *before* is the text preceding the match. Its last (partial) line ends
        in ``echo`` (after removing ANSI sequences) when the match came from
        the typed command being echoed, not from the command's output.
        """
        last_line = before.rpartition('\n')[2]
        visible = self._ANSI_CSI_PATTERN.sub("", last_line)
        return visible.rstrip().endswith("echo")

    def _extract_command_output(self, output: str, command: str) -> str:
        """
        Strip marker lines, prompt lines and the echoed command from *output*.

        Args:
            output: Raw text captured before the end marker
            command: The command that was sent

        Returns:
            Remaining lines joined with '\\n', outer whitespace stripped
        """
        typed = command.strip()
        lines = []
        for line in output.split('\n'):
            if self.END_MARKER in line:
                continue
            # NOTE(review): heuristic prompt filter; it also drops genuine
            # output lines that start with one of these characters.
            if line.strip().startswith(('>', '$', '#', '%')):
                continue
            # Echoed command: the line is the command itself, or a prompt
            # followed by it. Whole-token match so output lines that merely
            # contain the command text (e.g. "tools" for `ls`) survive.
            visible = self._ANSI_CSI_PATTERN.sub("", line).strip()
            if typed and (visible == typed or visible.endswith(" " + typed)):
                continue
            lines.append(line)
        return '\n'.join(lines).strip()

    def run_command(self, command: str, timeout: Optional[float] = None) -> str:
        """
        Run a command and return its output.

        Uses a unique marker for reliable output detection instead of
        prompt matching, which can be unreliable with terminal redraws.
        The echoed ``echo <marker>`` line is skipped, and the real marker is
        consumed, so no marker is left behind for the next call.

        Args:
            command: Shell command to run
            timeout: Timeout for each marker wait (instance default if None)

        Returns:
            Command output (excluding the prompt and marker)

        Raises:
            pexpect.TIMEOUT: If the marker output never appears
        """
        # Send command followed by marker echo
        self.send_line(command)
        self.send_line(f"echo {self.END_MARKER}")

        marker = re.escape(self.END_MARKER)
        self.expect(marker, timeout=timeout)
        output = self.get_output()

        if self._is_echoed_marker_command(output):
            # Drop the partial "<prompt> echo " line that preceded the match.
            output = output.rpartition('\n')[0]
            # Consume markers until one is not preceded by "echo": that one is
            # the echo's real output. Looping also covers line editors that
            # repaint the accepted line and so echo the marker again.
            self.expect(marker, timeout=timeout)
            while self._is_echoed_marker_command(self.get_output()):
                self.expect(marker, timeout=timeout)

        return self._extract_command_output(output, command)

    # ------------------------------------------------------------------
    # Control keys and terminal state
    # ------------------------------------------------------------------

    def interrupt(self) -> None:
        """Send Ctrl+C to interrupt current operation."""
        self.send_key("C-c")

    def suspend(self) -> None:
        """Send Ctrl+Z to suspend current operation."""
        self.send_key("C-z")

    def eof(self) -> None:
        """Send Ctrl+D (EOF)."""
        self.send_key("C-d")

    @property
    def is_running(self) -> bool:
        """Check if shell is still running."""
        if self.child is None:
            return False
        return self.child.isalive()

    def set_terminal_size(self, rows: int, cols: int) -> None:
        """
        Change the terminal dimensions (triggers SIGWINCH).

        Args:
            rows: Number of rows
            cols: Number of columns
        """
        self._require_child().setwinsize(rows, cols)
412
413
class ShellTestSession:
    """
    Context manager for shell test sessions.

    Builds a ShellPTY from the constructor's keyword arguments, starts it on
    entry and stops it on exit. Exceptions raised in the ``with`` body are
    never suppressed.

    Usage:
        with ShellTestSession(shell_path="/path/to/shell") as sh:
            sh.send_line("echo hello")
            output = sh.wait_for_prompt()
            assert "hello" in output
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Forwarded unchanged to ``ShellPTY(...)``.
        """
        self.kwargs = kwargs
        self.pty: Optional[ShellPTY] = None

    def __enter__(self) -> ShellPTY:
        self.pty = ShellPTY(**self.kwargs)
        try:
            self.pty.start()
        except BaseException:
            # Python does not call __exit__ when __enter__ raises, so a shell
            # that spawned but never showed its prompt must be stopped here.
            try:
                self.pty.stop()
            except Exception:
                # Best effort only: the start() failure is the error to report.
                pass
            raise
        return self.pty

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.pty is not None:
            self.pty.stop()
        # Never swallow exceptions from the with-body.
        return False
438
439
# Convenience function for quick testing
def quick_test(command: str, shell_path: str = None) -> str:
    """
    Run one command in a throwaway shell session and return its output.

    The shell is started for this call only. The ``with`` block stops it
    again before the function returns, even if the command raises.

    Args:
        command: Command to run
        shell_path: Path to shell binary

    Returns:
        Command output, as produced by ``ShellPTY.run_command``
    """
    session = ShellTestSession(shell_path=shell_path)
    with session as shell:
        result = shell.run_command(command)
    return result