| 1 | """ |
| 2 | Tests for SELinux functionality. |
| 3 | """ |
| 4 | |
| 5 | import unittest |
| 6 | from pathlib import Path |
| 7 | from unittest.mock import patch, MagicMock |
| 8 | import subprocess |
| 9 | |
| 10 | from sultree.selinux import ( |
| 11 | SELinuxContext, |
| 12 | SELinuxFilter, |
| 13 | get_selinux_context, |
| 14 | is_selinux_enabled, |
| 15 | SELinuxQueryError, |
| 16 | clear_selinux_cache, |
| 17 | get_cache_info |
| 18 | ) |
| 19 | |
| 20 | |
class TestSELinuxContext(unittest.TestCase):
    """Unit tests for parsing and pattern-matching SELinux context strings."""

    def test_valid_context_parsing(self):
        """A four-field context is split into user, role, type and level."""
        raw = "system_u:object_r:passwd_file_t:s0"
        parsed = SELinuxContext(raw)

        # Compare every parsed component at once; `raw` must be kept verbatim.
        self.assertEqual(
            (parsed.user, parsed.role, parsed.type, parsed.level, parsed.raw),
            ("system_u", "object_r", "passwd_file_t", "s0", raw),
        )

    def test_context_with_complex_level(self):
        """An MLS level that itself contains colons stays in `level`."""
        parsed = SELinuxContext("user_u:user_r:user_t:s0:c0.c1023")

        # Everything after the third colon belongs to the level field.
        self.assertEqual(
            (parsed.user, parsed.role, parsed.type, parsed.level),
            ("user_u", "user_r", "user_t", "s0:c0.c1023"),
        )

    def test_invalid_context_format(self):
        """Malformed contexts are rejected with ValueError."""
        # "user:role" has too few fields.
        for bad_context in ("invalid", "user:role"):
            with self.subTest(context=bad_context):
                with self.assertRaises(ValueError):
                    SELinuxContext(bad_context)

    def test_pattern_matching_exact(self):
        """A literal pattern matches a single context component."""
        ctx = SELinuxContext("system_u:object_r:passwd_file_t:s0")

        for pattern in ("passwd_file_t", "system_u"):
            with self.subTest(pattern=pattern):
                self.assertTrue(ctx.matches_pattern(pattern))
        self.assertFalse(ctx.matches_pattern("shadow_t"))

    def test_pattern_matching_wildcard(self):
        """`*` wildcards match within a component and across the full context."""
        ctx = SELinuxContext("system_u:object_r:httpd_exec_t:s0")

        for pattern in ("httpd_*", "*_exec_t", "*_u:*:*_t:*"):
            with self.subTest(pattern=pattern):
                self.assertTrue(ctx.matches_pattern(pattern))
        self.assertFalse(ctx.matches_pattern("ssh_*"))
| 67 | |
| 68 | |
class TestSELinuxFilter(unittest.TestCase):
    """Test SELinux filtering functionality."""

    def test_filter_initialization(self):
        """Test filter initialization with patterns."""
        patterns = ["passwd_file_t", "*_exec_t", "httpd_*"]
        selinux_filter = SELinuxFilter(patterns)

        self.assertEqual(len(selinux_filter.patterns), 3)
        self.assertIn("passwd_file_t", selinux_filter.patterns)

    def test_filter_sanitization(self):
        """Test pattern sanitization.

        Shell metacharacters must not survive sanitization, and the benign
        pattern must still be present. Without the second check, a filter
        that discarded every pattern would pass the metacharacter checks
        vacuously, because the loop would have nothing to iterate over.
        """
        dangerous_patterns = [
            "passwd_file_t; rm -rf /",  # Command injection attempt
            "test_t|malicious",  # Pipe attempt
            "normal_pattern_t",  # Normal pattern
        ]

        selinux_filter = SELinuxFilter(dangerous_patterns)

        # A plain pattern is kept as-is (see test_filter_initialization).
        self.assertIn("normal_pattern_t", selinux_filter.patterns)

        # Should have sanitized dangerous patterns
        for pattern in selinux_filter.patterns:
            with self.subTest(pattern=pattern):
                self.assertNotIn(';', pattern)
                self.assertNotIn('|', pattern)
                self.assertNotIn('`', pattern)

    @patch('sultree.selinux.get_selinux_context')
    def test_filter_matching(self, mock_get_context):
        """Test file matching against patterns."""
        # Mock SELinux context whose matches_pattern always reports a match.
        mock_context = MagicMock()
        mock_context.matches_pattern.return_value = True
        mock_get_context.return_value = mock_context

        selinux_filter = SELinuxFilter(["passwd_file_t"])
        test_path = Path("/etc/passwd")

        result = selinux_filter.matches(test_path)

        self.assertTrue(result)
        # The filter looks up the context exactly once, for the given path.
        mock_get_context.assert_called_once_with(test_path)
| 111 | |
| 112 | |
class TestSELinuxUtilities(unittest.TestCase):
    """Test SELinux utility functions."""

    def setUp(self):
        # sultree.selinux exposes clear_selinux_cache, so context lookups may
        # be cached at module level. Start and finish every test with an empty
        # cache so a result stored by one test (e.g. the successful
        # /etc/passwd lookup) cannot bypass the mocked subprocess.run in
        # another test.
        clear_selinux_cache()
        self.addCleanup(clear_selinux_cache)

    @patch('subprocess.run')
    def test_get_selinux_context_success(self, mock_run):
        """Test successful SELinux context retrieval."""
        # Mock successful getfattr output
        mock_result = MagicMock()
        mock_result.returncode = 0
        mock_result.stdout = "system_u:object_r:passwd_file_t:s0"
        mock_run.return_value = mock_result

        context = get_selinux_context(Path("/etc/passwd"))

        self.assertIsNotNone(context)
        self.assertEqual(context.type, "passwd_file_t")

        # Verify secure command construction: an argv sequence, never a shell.
        args, kwargs = mock_run.call_args
        command = args[0]
        self.assertEqual(command[0], 'getfattr')
        self.assertIn('--only-values', command)
        self.assertIn('-n', command)
        self.assertIn('security.selinux', command)
        self.assertFalse(kwargs.get('shell', False))

    @patch('subprocess.run')
    def test_get_selinux_context_timeout(self, mock_run):
        """Test timeout handling in context retrieval."""
        mock_run.side_effect = subprocess.TimeoutExpired('getfattr', 5)

        with self.assertRaises(SELinuxQueryError):
            get_selinux_context(Path("/etc/passwd"))

    @patch('subprocess.run')
    def test_get_selinux_context_no_context(self, mock_run):
        """Test handling when no SELinux context exists."""
        mock_result = MagicMock()
        mock_result.returncode = 1  # getfattr returns 1 for no attribute
        mock_result.stdout = ""
        mock_run.return_value = mock_result

        context = get_selinux_context(Path("/etc/passwd"))

        self.assertIsNone(context)

    def test_path_validation(self):
        """Test path validation in get_selinux_context."""
        # Test with non-existent path
        non_existent = Path("/non/existent/path")
        context = get_selinux_context(non_existent)

        self.assertIsNone(context)

    @patch('sultree.selinux.get_selinux_context')
    def test_is_selinux_enabled(self, mock_get_context):
        """Test SELinux availability detection."""
        # Test when SELinux is available
        mock_get_context.return_value = MagicMock()
        self.assertTrue(is_selinux_enabled())

        # Drop any cached state before flipping the mocked result.
        # NOTE(review): assumes clear_selinux_cache also resets whatever
        # is_selinux_enabled memoizes, if anything — confirm in sultree.selinux.
        clear_selinux_cache()

        # Test when SELinux is not available
        mock_get_context.return_value = None
        self.assertFalse(is_selinux_enabled())
| 176 | |
| 177 | |
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()