Python · 7373 bytes Raw Blame History
1 """
2 Tests for command-line argument parsing.
3 """
4
5 import unittest
6 from unittest.mock import patch
7 from pathlib import Path
8 import sys
9 import io
10
11 from sultree.args import (
12 create_argument_parser,
13 parse_arguments,
14 validate_depth,
15 validate_pattern,
16 validate_selinux_pattern,
17 validate_path
18 )
19
20
class TestArgumentValidation(unittest.TestCase):
    """Test the standalone argument validation functions.

    Every valid and invalid input is checked inside its own ``subTest`` so a
    single failing case is reported without masking the remaining ones.
    """

    # NOTE(review): the invalid-input checks accept any Exception subclass, so
    # an unrelated error (e.g. a TypeError from a changed signature) would also
    # count as "rejected". If the validators raise a specific type (possibly
    # argparse.ArgumentTypeError, as argparse ``type=`` callables usually do),
    # narrow the expected exception -- confirm against sultree.args first.

    def test_validate_depth(self):
        """Check validate_depth converts numeric strings and rejects bad depths."""
        for raw, expected in (("5", 5), ("0", 0)):
            with self.subTest(value=raw):
                self.assertEqual(validate_depth(raw), expected)

        for label, raw in (
            ("negative depth", "-1"),
            ("not a number", "not_a_number"),
            ("too large", "1001"),
        ):
            with self.subTest(case=label):
                with self.assertRaises(Exception):
                    validate_depth(raw)

    def test_validate_pattern(self):
        """Check validate_pattern returns valid globs unchanged and rejects bad ones."""
        for pattern in ("*.txt", "test_*"):
            with self.subTest(pattern=pattern):
                self.assertEqual(validate_pattern(pattern), pattern)

        # Labels instead of raw values keep the 1001-char input out of reports.
        for label, pattern in (
            ("empty pattern", ""),
            ("too long", "x" * 1001),
        ):
            with self.subTest(case=label):
                with self.assertRaises(Exception):
                    validate_pattern(pattern)

    def test_validate_selinux_pattern(self):
        """Check SELinux patterns pass through and dangerous characters are refused."""
        for pattern in ("passwd_file_t", "system_u:object_r:*:s0"):
            with self.subTest(pattern=pattern):
                self.assertEqual(validate_selinux_pattern(pattern), pattern)

        # Patterns containing shell metacharacters (; | `) must be rejected.
        for pattern in ("test; rm -rf /", "test|malicious", "test`command`"):
            with self.subTest(pattern=pattern):
                with self.assertRaises(Exception):
                    validate_selinux_pattern(pattern)

    def test_validate_path(self):
        """Check validate_path accepts a normal path and rejects malformed ones."""
        path = validate_path("/etc/passwd")
        self.assertEqual(str(path), "/etc/passwd")

        for label, raw in (
            ("empty path", ""),
            ("too long", "x" * 5000),
            ("null byte", "/path/with/\x00/null"),
        ):
            with self.subTest(case=label):
                with self.assertRaises(Exception):
                    validate_path(raw)
82
83
class TestArgumentParser(unittest.TestCase):
    """Exercise the parser returned by create_argument_parser() directly."""

    def setUp(self):
        """Build a fresh parser for every test."""
        self.parser = create_argument_parser()

    def _parse(self, *argv):
        """Parse the given command-line words and return the resulting namespace."""
        return self.parser.parse_args(list(argv))

    def test_basic_parsing(self):
        """A lone directory argument leaves ``all`` and ``dirs_only`` unset."""
        ns = self._parse('/etc')

        self.assertEqual(len(ns.directories), 1)
        self.assertEqual(str(ns.directories[0]), '/etc')
        self.assertFalse(ns.all)
        self.assertFalse(ns.dirs_only)

    def test_flag_parsing(self):
        """Passing -a, -d and -l switches on all, dirs_only and follow_links."""
        ns = self._parse('-a', '-d', '-l', '/tmp')

        for option in ('all', 'dirs_only', 'follow_links'):
            self.assertTrue(getattr(ns, option))

    def test_pattern_parsing(self):
        """-P feeds include_patterns and -I feeds exclude_patterns."""
        ns = self._parse('-P', '*.txt', '-I', '*.bak', '/tmp')

        self.assertIn('*.txt', ns.include_patterns)
        self.assertIn('*.bak', ns.exclude_patterns)

    def test_selinux_pattern_parsing(self):
        """Repeated -S options accumulate into selinux_patterns."""
        ns = self._parse('-S', 'passwd_file_t', '-S', '*_exec_t', '/etc')

        self.assertEqual(len(ns.selinux_patterns), 2)
        for pattern in ('passwd_file_t', '*_exec_t'):
            self.assertIn(pattern, ns.selinux_patterns)

    def test_depth_parsing(self):
        """-L stores its integer argument in ``level``."""
        self.assertEqual(self._parse('-L', '3', '/tmp').level, 3)

    def test_multiple_directories(self):
        """Every positional directory ends up in ``directories`` as a Path."""
        targets = ['/etc', '/tmp', '/var']
        ns = self._parse(*targets)

        self.assertEqual(len(ns.directories), len(targets))
        for target in targets:
            self.assertIn(Path(target), ns.directories)
137
138
class TestParseArguments(unittest.TestCase):
    """Test the top-level parse_arguments() entry point.

    Path.exists / Path.is_dir and SELinux detection are patched so the
    outcome is controlled by the test rather than by the host filesystem.
    """

    # NOTE(review): 'sultree.selinux.is_selinux_enabled' is patched where it is
    # defined. If sultree.args does "from sultree.selinux import
    # is_selinux_enabled", that patch never reaches the copy sultree.args
    # calls and the target must be 'sultree.args.is_selinux_enabled' instead
    # -- confirm against sultree.args.

    def _stderr_from_failed_parse(self, argv):
        """Return the stderr text parse_arguments(argv) emits before exiting.

        Fails the test unless the call raises SystemExit. ``patch`` restores
        the real ``sys.stderr`` on exit, even when the assertion fails.
        """
        with patch('sys.stderr', new_callable=io.StringIO) as fake_stderr:
            with self.assertRaises(SystemExit):
                parse_arguments(argv)
        return fake_stderr.getvalue()

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_successful_parsing(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """Valid flags plus an existing directory map onto the options object."""
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = True

        args, options = parse_arguments(['-a', '-L', '2', '/etc'])

        self.assertTrue(options.show_all)
        self.assertEqual(options.max_depth, 2)
        self.assertEqual(len(args.directories), 1)

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_selinux_filter_creation(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """An -S pattern with SELinux available yields a selinux_filter."""
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = True

        args, options = parse_arguments(['-S', 'passwd_file_t', '/etc'])

        self.assertIsNotNone(options.selinux_filter)

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    def test_nonexistent_directory(self, mock_exists, mock_selinux_enabled):
        """A missing directory aborts parsing with a "does not exist" message."""
        mock_exists.return_value = False
        mock_selinux_enabled.return_value = True

        stderr_text = self._stderr_from_failed_parse(['/non/existent/path'])

        self.assertIn("does not exist", stderr_text)

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_selinux_unavailable(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """Requesting -S while SELinux is disabled aborts with an SELinux message."""
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = False

        stderr_text = self._stderr_from_failed_parse(['-S', 'passwd_file_t', '/etc'])

        self.assertIn("SELinux", stderr_text)
209
210
if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined in it.
    unittest.main(module='__main__')