tenseleyflow/sultree / b8f423f

Browse files

test files

Authored by espadonne
SHA
b8f423f96eb4a678dac88189a6a19c7e6cef6a51
Parents
5423f57
Tree
525dfcb

5 changed files

StatusFile+-
A tests/__init__.py 1 0
A tests/__pycache__/test_args.cpython-39.pyc bin
A tests/__pycache__/test_selinux.cpython-39.pyc bin
A tests/test_args.py 212 0
A tests/test_selinux.py 177 0
tests/__init__.pyadded
@@ -0,0 +1,1 @@
1
+"""Test suite for sultree."""
tests/__pycache__/test_args.cpython-39.pycadded
Binary file changed.
tests/__pycache__/test_selinux.cpython-39.pycadded
Binary file changed.
tests/test_args.pyadded
@@ -0,0 +1,212 @@
1
+"""
2
+Tests for command-line argument parsing.
3
+"""
4
+
5
+import unittest
6
+from unittest.mock import patch
7
+from pathlib import Path
8
+import sys
9
+import io
10
+
11
+from sultree.args import (
12
+    create_argument_parser,
13
+    parse_arguments,
14
+    validate_depth,
15
+    validate_pattern,
16
+    validate_selinux_pattern,
17
+    validate_path
18
+)
19
+
20
+
21
+class TestArgumentValidation(unittest.TestCase):
22
+    """Test argument validation functions."""
23
+    
24
+    def test_validate_depth(self):
25
+        """Test depth validation."""
26
+        self.assertEqual(validate_depth("5"), 5)
27
+        self.assertEqual(validate_depth("0"), 0)
28
+        
29
+        with self.assertRaises(Exception):
30
+            validate_depth("-1")  # Negative depth
31
+            
32
+        with self.assertRaises(Exception):
33
+            validate_depth("not_a_number")
34
+            
35
+        with self.assertRaises(Exception):
36
+            validate_depth("1001")  # Too large
37
+    
38
+    def test_validate_pattern(self):
39
+        """Test pattern validation."""
40
+        # Valid patterns
41
+        self.assertEqual(validate_pattern("*.txt"), "*.txt")
42
+        self.assertEqual(validate_pattern("test_*"), "test_*")
43
+        
44
+        # Invalid patterns
45
+        with self.assertRaises(Exception):
46
+            validate_pattern("")  # Empty pattern
47
+            
48
+        with self.assertRaises(Exception):
49
+            validate_pattern("x" * 1001)  # Too long
50
+    
51
+    def test_validate_selinux_pattern(self):
52
+        """Test SELinux pattern validation."""
53
+        # Valid patterns
54
+        self.assertEqual(validate_selinux_pattern("passwd_file_t"), "passwd_file_t")
55
+        self.assertEqual(validate_selinux_pattern("system_u:object_r:*:s0"), "system_u:object_r:*:s0")
56
+        
57
+        # Invalid patterns with dangerous characters
58
+        with self.assertRaises(Exception):
59
+            validate_selinux_pattern("test; rm -rf /")
60
+            
61
+        with self.assertRaises(Exception):
62
+            validate_selinux_pattern("test|malicious")
63
+            
64
+        with self.assertRaises(Exception):
65
+            validate_selinux_pattern("test`command`")
66
+    
67
+    def test_validate_path(self):
68
+        """Test path validation."""
69
+        # Valid paths
70
+        path = validate_path("/etc/passwd")
71
+        self.assertEqual(str(path), "/etc/passwd")
72
+        
73
+        # Invalid paths
74
+        with self.assertRaises(Exception):
75
+            validate_path("")  # Empty path
76
+            
77
+        with self.assertRaises(Exception):
78
+            validate_path("x" * 5000)  # Too long
79
+            
80
+        with self.assertRaises(Exception):
81
+            validate_path("/path/with/\x00/null")  # Null bytes
82
+
83
+
84
+class TestArgumentParser(unittest.TestCase):
85
+    """Test argument parser functionality."""
86
+    
87
+    def setUp(self):
88
+        """Set up test fixtures."""
89
+        self.parser = create_argument_parser()
90
+    
91
+    def test_basic_parsing(self):
92
+        """Test basic argument parsing."""
93
+        args = self.parser.parse_args(['/etc'])
94
+        
95
+        self.assertEqual(len(args.directories), 1)
96
+        self.assertEqual(str(args.directories[0]), '/etc')
97
+        self.assertFalse(args.all)
98
+        self.assertFalse(args.dirs_only)
99
+    
100
+    def test_flag_parsing(self):
101
+        """Test various flags."""
102
+        args = self.parser.parse_args(['-a', '-d', '-l', '/tmp'])
103
+        
104
+        self.assertTrue(args.all)
105
+        self.assertTrue(args.dirs_only)
106
+        self.assertTrue(args.follow_links)
107
+    
108
+    def test_pattern_parsing(self):
109
+        """Test pattern argument parsing."""
110
+        args = self.parser.parse_args(['-P', '*.txt', '-I', '*.bak', '/tmp'])
111
+        
112
+        self.assertIn('*.txt', args.include_patterns)
113
+        self.assertIn('*.bak', args.exclude_patterns)
114
+    
115
+    def test_selinux_pattern_parsing(self):
116
+        """Test SELinux pattern parsing."""
117
+        args = self.parser.parse_args(['-S', 'passwd_file_t', '-S', '*_exec_t', '/etc'])
118
+        
119
+        self.assertEqual(len(args.selinux_patterns), 2)
120
+        self.assertIn('passwd_file_t', args.selinux_patterns)
121
+        self.assertIn('*_exec_t', args.selinux_patterns)
122
+    
123
+    def test_depth_parsing(self):
124
+        """Test depth limit parsing."""
125
+        args = self.parser.parse_args(['-L', '3', '/tmp'])
126
+        
127
+        self.assertEqual(args.level, 3)
128
+    
129
+    def test_multiple_directories(self):
130
+        """Test multiple directory arguments."""
131
+        args = self.parser.parse_args(['/etc', '/tmp', '/var'])
132
+        
133
+        self.assertEqual(len(args.directories), 3)
134
+        self.assertIn(Path('/etc'), args.directories)
135
+        self.assertIn(Path('/tmp'), args.directories)
136
+        self.assertIn(Path('/var'), args.directories)
137
+
138
+
139
+class TestParseArguments(unittest.TestCase):
140
+    """Test the main parse_arguments function."""
141
+    
142
+    @patch('sultree.selinux.is_selinux_enabled')
143
+    @patch('pathlib.Path.exists')
144
+    @patch('pathlib.Path.is_dir')
145
+    def test_successful_parsing(self, mock_is_dir, mock_exists, mock_selinux_enabled):
146
+        """Test successful argument parsing."""
147
+        # Mock path validation
148
+        mock_exists.return_value = True
149
+        mock_is_dir.return_value = True
150
+        mock_selinux_enabled.return_value = True
151
+        
152
+        args, options = parse_arguments(['-a', '-L', '2', '/etc'])
153
+        
154
+        self.assertTrue(options.show_all)
155
+        self.assertEqual(options.max_depth, 2)
156
+        self.assertEqual(len(args.directories), 1)
157
+    
158
+    @patch('sultree.selinux.is_selinux_enabled')
159
+    @patch('pathlib.Path.exists')  
160
+    @patch('pathlib.Path.is_dir')
161
+    def test_selinux_filter_creation(self, mock_is_dir, mock_exists, mock_selinux_enabled):
162
+        """Test SELinux filter creation."""
163
+        mock_exists.return_value = True
164
+        mock_is_dir.return_value = True
165
+        mock_selinux_enabled.return_value = True
166
+        
167
+        args, options = parse_arguments(['-S', 'passwd_file_t', '/etc'])
168
+        
169
+        self.assertIsNotNone(options.selinux_filter)
170
+    
171
+    @patch('sultree.selinux.is_selinux_enabled')
172
+    @patch('pathlib.Path.exists')
173
+    def test_nonexistent_directory(self, mock_exists, mock_selinux_enabled):
174
+        """Test handling of non-existent directories."""
175
+        mock_exists.return_value = False
176
+        mock_selinux_enabled.return_value = True
177
+        
178
+        # Capture stderr to check error message
179
+        old_stderr = sys.stderr
180
+        sys.stderr = captured_output = io.StringIO()
181
+        
182
+        try:
183
+            with self.assertRaises(SystemExit):
184
+                parse_arguments(['/non/existent/path'])
185
+        finally:
186
+            sys.stderr = old_stderr
187
+        
188
+        self.assertIn("does not exist", captured_output.getvalue())
189
+    
190
+    @patch('sultree.selinux.is_selinux_enabled')
191
+    @patch('pathlib.Path.exists')
192
+    @patch('pathlib.Path.is_dir')
193
+    def test_selinux_unavailable(self, mock_is_dir, mock_exists, mock_selinux_enabled):
194
+        """Test handling when SELinux is unavailable."""
195
+        mock_exists.return_value = True
196
+        mock_is_dir.return_value = True
197
+        mock_selinux_enabled.return_value = False
198
+        
199
+        old_stderr = sys.stderr
200
+        sys.stderr = captured_output = io.StringIO()
201
+        
202
+        try:
203
+            with self.assertRaises(SystemExit):
204
+                parse_arguments(['-S', 'passwd_file_t', '/etc'])
205
+        finally:
206
+            sys.stderr = old_stderr
207
+        
208
+        self.assertIn("SELinux", captured_output.getvalue())
209
+
210
+
211
+if __name__ == '__main__':
212
+    unittest.main()
tests/test_selinux.pyadded
@@ -0,0 +1,177 @@
1
+"""
2
+Tests for SELinux functionality.
3
+"""
4
+
5
+import unittest
6
+from pathlib import Path
7
+from unittest.mock import patch, MagicMock
8
+import subprocess
9
+
10
+from sultree.selinux import (
11
+    SELinuxContext,
12
+    SELinuxFilter,
13
+    get_selinux_context,
14
+    is_selinux_enabled,
15
+    SELinuxQueryError
16
+)
17
+
18
+
19
+class TestSELinuxContext(unittest.TestCase):
20
+    """Test SELinux context parsing."""
21
+    
22
+    def test_valid_context_parsing(self):
23
+        """Test parsing valid SELinux contexts."""
24
+        context = SELinuxContext("system_u:object_r:passwd_file_t:s0")
25
+        
26
+        self.assertEqual(context.user, "system_u")
27
+        self.assertEqual(context.role, "object_r")
28
+        self.assertEqual(context.type, "passwd_file_t")
29
+        self.assertEqual(context.level, "s0")
30
+        self.assertEqual(context.raw, "system_u:object_r:passwd_file_t:s0")
31
+    
32
+    def test_context_with_complex_level(self):
33
+        """Test context with complex MLS level."""
34
+        context = SELinuxContext("user_u:user_r:user_t:s0:c0.c1023")
35
+        
36
+        self.assertEqual(context.user, "user_u")
37
+        self.assertEqual(context.role, "user_r")
38
+        self.assertEqual(context.type, "user_t")
39
+        self.assertEqual(context.level, "s0:c0.c1023")
40
+    
41
+    def test_invalid_context_format(self):
42
+        """Test handling of invalid context formats."""
43
+        with self.assertRaises(ValueError):
44
+            SELinuxContext("invalid")
45
+            
46
+        with self.assertRaises(ValueError):
47
+            SELinuxContext("user:role")  # Too few fields
48
+    
49
+    def test_pattern_matching_exact(self):
50
+        """Test exact pattern matching."""
51
+        context = SELinuxContext("system_u:object_r:passwd_file_t:s0")
52
+        
53
+        self.assertTrue(context.matches_pattern("passwd_file_t"))
54
+        self.assertTrue(context.matches_pattern("system_u"))
55
+        self.assertFalse(context.matches_pattern("shadow_t"))
56
+    
57
+    def test_pattern_matching_wildcard(self):
58
+        """Test wildcard pattern matching."""
59
+        context = SELinuxContext("system_u:object_r:httpd_exec_t:s0")
60
+        
61
+        self.assertTrue(context.matches_pattern("httpd_*"))
62
+        self.assertTrue(context.matches_pattern("*_exec_t"))
63
+        self.assertTrue(context.matches_pattern("*_u:*:*_t:*"))
64
+        self.assertFalse(context.matches_pattern("ssh_*"))
65
+
66
+
67
+class TestSELinuxFilter(unittest.TestCase):
68
+    """Test SELinux filtering functionality."""
69
+    
70
+    def test_filter_initialization(self):
71
+        """Test filter initialization with patterns."""
72
+        patterns = ["passwd_file_t", "*_exec_t", "httpd_*"]
73
+        selinux_filter = SELinuxFilter(patterns)
74
+        
75
+        self.assertEqual(len(selinux_filter.patterns), 3)
76
+        self.assertIn("passwd_file_t", selinux_filter.patterns)
77
+    
78
+    def test_filter_sanitization(self):
79
+        """Test pattern sanitization."""
80
+        dangerous_patterns = [
81
+            "passwd_file_t; rm -rf /",  # Command injection attempt
82
+            "test_t|malicious",         # Pipe attempt
83
+            "normal_pattern_t"          # Normal pattern
84
+        ]
85
+        
86
+        selinux_filter = SELinuxFilter(dangerous_patterns)
87
+        
88
+        # Should have sanitized dangerous patterns
89
+        for pattern in selinux_filter.patterns:
90
+            self.assertNotIn(';', pattern)
91
+            self.assertNotIn('|', pattern)
92
+            self.assertNotIn('`', pattern)
93
+    
94
+    @patch('sultree.selinux.get_selinux_context')
95
+    def test_filter_matching(self, mock_get_context):
96
+        """Test file matching against patterns."""
97
+        # Mock SELinux context
98
+        mock_context = MagicMock()
99
+        mock_context.matches_pattern.return_value = True
100
+        mock_get_context.return_value = mock_context
101
+        
102
+        selinux_filter = SELinuxFilter(["passwd_file_t"])
103
+        test_path = Path("/etc/passwd")
104
+        
105
+        result = selinux_filter.matches(test_path)
106
+        
107
+        self.assertTrue(result)
108
+        mock_get_context.assert_called_once_with(test_path)
109
+
110
+
111
+class TestSELinuxUtilities(unittest.TestCase):
112
+    """Test SELinux utility functions."""
113
+    
114
+    @patch('subprocess.run')
115
+    def test_get_selinux_context_success(self, mock_run):
116
+        """Test successful SELinux context retrieval."""
117
+        # Mock successful getfattr output
118
+        mock_result = MagicMock()
119
+        mock_result.returncode = 0
120
+        mock_result.stdout = "system_u:object_r:passwd_file_t:s0"
121
+        mock_run.return_value = mock_result
122
+        
123
+        context = get_selinux_context(Path("/etc/passwd"))
124
+        
125
+        self.assertIsNotNone(context)
126
+        self.assertEqual(context.type, "passwd_file_t")
127
+        
128
+        # Verify secure command construction
129
+        args, kwargs = mock_run.call_args
130
+        command = args[0]
131
+        self.assertEqual(command[0], 'getfattr')
132
+        self.assertIn('--only-values', command)
133
+        self.assertIn('-n', command)
134
+        self.assertIn('security.selinux', command)
135
+    
136
+    @patch('subprocess.run')
137
+    def test_get_selinux_context_timeout(self, mock_run):
138
+        """Test timeout handling in context retrieval."""
139
+        mock_run.side_effect = subprocess.TimeoutExpired('getfattr', 5)
140
+        
141
+        with self.assertRaises(SELinuxQueryError):
142
+            get_selinux_context(Path("/etc/passwd"))
143
+    
144
+    @patch('subprocess.run')
145
+    def test_get_selinux_context_no_context(self, mock_run):
146
+        """Test handling when no SELinux context exists."""
147
+        mock_result = MagicMock()
148
+        mock_result.returncode = 1  # getfattr returns 1 for no attribute
149
+        mock_result.stdout = ""
150
+        mock_run.return_value = mock_result
151
+        
152
+        context = get_selinux_context(Path("/etc/passwd"))
153
+        
154
+        self.assertIsNone(context)
155
+    
156
+    def test_path_validation(self):
157
+        """Test path validation in get_selinux_context."""
158
+        # Test with non-existent path
159
+        non_existent = Path("/non/existent/path")
160
+        context = get_selinux_context(non_existent)
161
+        
162
+        self.assertIsNone(context)
163
+        
164
+    @patch('sultree.selinux.get_selinux_context')
165
+    def test_is_selinux_enabled(self, mock_get_context):
166
+        """Test SELinux availability detection."""
167
+        # Test when SELinux is available
168
+        mock_get_context.return_value = MagicMock()
169
+        self.assertTrue(is_selinux_enabled())
170
+        
171
+        # Test when SELinux is not available
172
+        mock_get_context.return_value = None
173
+        self.assertFalse(is_selinux_enabled())
174
+
175
+
176
+if __name__ == '__main__':
177
+    unittest.main()