| 1 | """ |
| 2 | Tests for command-line argument parsing. |
| 3 | """ |
| 4 | |
| 5 | import unittest |
| 6 | from unittest.mock import patch |
| 7 | from pathlib import Path |
| 8 | import sys |
| 9 | import io |
| 10 | |
| 11 | from sultree.args import ( |
| 12 | create_argument_parser, |
| 13 | parse_arguments, |
| 14 | validate_depth, |
| 15 | validate_pattern, |
| 16 | validate_selinux_pattern, |
| 17 | validate_path |
| 18 | ) |
| 19 | |
| 20 | |
class TestArgumentValidation(unittest.TestCase):
    """Test the standalone argument validation functions.

    Every accepted or rejected input is run inside its own ``subTest`` so a
    single failing case is reported individually instead of hiding the
    cases that follow it.
    """

    def _assert_rejected(self, validator, bad_inputs):
        """Assert that ``validator`` raises for each ``(value, reason)`` pair.

        Args:
            validator: The validation callable under test.
            bad_inputs: Iterable of ``(value, reason)`` tuples; ``reason`` is
                a short label shown in the sub-test report.
        """
        for value, reason in bad_inputs:
            with self.subTest(validator=validator.__name__, reason=reason):
                # NOTE(review): the concrete exception type raised by the
                # validators is not visible from this file, so the broad
                # ``Exception`` check is kept -- tighten once confirmed.
                with self.assertRaises(Exception):
                    validator(value)

    def test_validate_depth(self):
        """Test depth validation."""
        # Valid depths are returned as integers.
        for text, expected in (("5", 5), ("0", 0)):
            with self.subTest(text=text):
                self.assertEqual(validate_depth(text), expected)

        self._assert_rejected(validate_depth, [
            ("-1", "negative depth"),
            ("not_a_number", "not a number"),
            ("1001", "too large"),
        ])

    def test_validate_pattern(self):
        """Test pattern validation."""
        # Valid patterns are returned unchanged.
        for pattern in ("*.txt", "test_*"):
            with self.subTest(pattern=pattern):
                self.assertEqual(validate_pattern(pattern), pattern)

        self._assert_rejected(validate_pattern, [
            ("", "empty pattern"),
            ("x" * 1001, "too long"),
        ])

    def test_validate_selinux_pattern(self):
        """Test SELinux pattern validation."""
        # Valid patterns are returned unchanged.
        for pattern in ("passwd_file_t", "system_u:object_r:*:s0"):
            with self.subTest(pattern=pattern):
                self.assertEqual(validate_selinux_pattern(pattern), pattern)

        # Patterns containing shell metacharacters must be rejected.
        self._assert_rejected(validate_selinux_pattern, [
            ("test; rm -rf /", "semicolon"),
            ("test|malicious", "pipe"),
            ("test`command`", "backticks"),
        ])

    def test_validate_path(self):
        """Test path validation."""
        # A valid path round-trips through ``str`` unchanged.
        path = validate_path("/etc/passwd")
        self.assertEqual(str(path), "/etc/passwd")

        self._assert_rejected(validate_path, [
            ("", "empty path"),
            ("x" * 5000, "too long"),
            ("/path/with/\x00/null", "null byte"),
        ])
| 82 | |
| 83 | |
class TestArgumentParser(unittest.TestCase):
    """Checks on the parser object built by ``create_argument_parser``."""

    def setUp(self):
        """Build a fresh parser before every test."""
        self.parser = create_argument_parser()

    def test_basic_parsing(self):
        """A lone directory parses with the boolean flags left off."""
        namespace = self.parser.parse_args(['/etc'])
        directories = namespace.directories

        self.assertEqual(len(directories), 1)
        self.assertEqual(str(directories[0]), '/etc')
        self.assertFalse(namespace.all)
        self.assertFalse(namespace.dirs_only)

    def test_flag_parsing(self):
        """The ``-a``, ``-d`` and ``-l`` switches set their attributes."""
        argv = ['-a', '-d', '-l', '/tmp']
        namespace = self.parser.parse_args(argv)

        self.assertTrue(namespace.all)
        self.assertTrue(namespace.dirs_only)
        self.assertTrue(namespace.follow_links)

    def test_pattern_parsing(self):
        """``-P`` feeds include patterns and ``-I`` feeds exclude patterns."""
        argv = ['-P', '*.txt', '-I', '*.bak', '/tmp']
        namespace = self.parser.parse_args(argv)

        self.assertIn('*.txt', namespace.include_patterns)
        self.assertIn('*.bak', namespace.exclude_patterns)

    def test_selinux_pattern_parsing(self):
        """Repeated ``-S`` options accumulate into ``selinux_patterns``."""
        argv = ['-S', 'passwd_file_t', '-S', '*_exec_t', '/etc']
        collected = self.parser.parse_args(argv).selinux_patterns

        self.assertEqual(len(collected), 2)
        for wanted in ('passwd_file_t', '*_exec_t'):
            self.assertIn(wanted, collected)

    def test_depth_parsing(self):
        """``-L`` stores the depth limit as an integer on ``level``."""
        namespace = self.parser.parse_args(['-L', '3', '/tmp'])

        self.assertEqual(namespace.level, 3)

    def test_multiple_directories(self):
        """Several positional directories are all collected as paths."""
        roots = ['/etc', '/tmp', '/var']
        directories = self.parser.parse_args(roots).directories

        self.assertEqual(len(directories), 3)
        for root in roots:
            self.assertIn(Path(root), directories)
| 137 | |
| 138 | |
class TestParseArguments(unittest.TestCase):
    """Test the main ``parse_arguments`` function.

    ``pathlib.Path.exists`` / ``pathlib.Path.is_dir`` and
    ``sultree.selinux.is_selinux_enabled`` are patched so the tests do not
    depend on the host filesystem or SELinux state. ``patch`` decorators are
    applied bottom-up, so the mock arguments arrive in reverse decorator
    order.
    """

    def _assert_exits_with_stderr(self, argv, expected_fragment):
        """Assert ``parse_arguments(argv)`` exits and reports on stderr.

        Args:
            argv: Argument list handed to ``parse_arguments``.
            expected_fragment: Text that must appear in the captured stderr.
        """
        # ``patch`` restores the real ``sys.stderr`` on exit even if the
        # assertion inside fails, replacing a manual swap + try/finally.
        with patch('sys.stderr', new_callable=io.StringIO) as captured:
            with self.assertRaises(SystemExit):
                parse_arguments(argv)

        self.assertIn(expected_fragment, captured.getvalue())

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_successful_parsing(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """Test successful argument parsing."""
        # Pretend every path exists and is a directory.
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = True

        args, options = parse_arguments(['-a', '-L', '2', '/etc'])

        self.assertTrue(options.show_all)
        self.assertEqual(options.max_depth, 2)
        self.assertEqual(len(args.directories), 1)

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_selinux_filter_creation(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """Test SELinux filter creation."""
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = True

        # Only the options object matters here; the namespace is discarded.
        _, options = parse_arguments(['-S', 'passwd_file_t', '/etc'])

        self.assertIsNotNone(options.selinux_filter)

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    def test_nonexistent_directory(self, mock_exists, mock_selinux_enabled):
        """Test handling of non-existent directories."""
        mock_exists.return_value = False
        mock_selinux_enabled.return_value = True

        self._assert_exits_with_stderr(['/non/existent/path'], "does not exist")

    @patch('sultree.selinux.is_selinux_enabled')
    @patch('pathlib.Path.exists')
    @patch('pathlib.Path.is_dir')
    def test_selinux_unavailable(self, mock_is_dir, mock_exists, mock_selinux_enabled):
        """Test handling when SELinux is unavailable."""
        mock_exists.return_value = True
        mock_is_dir.return_value = True
        mock_selinux_enabled.return_value = False

        self._assert_exits_with_stderr(['-S', 'passwd_file_t', '/etc'], "SELinux")
| 209 | |
| 210 | |
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()