Rust · 13454 bytes Raw Blame History
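//! TLS certificate management for HyprKVM
//!
//! Handles self-signed certificate generation, loading certificates and private keys
//! from PEM files, SHA-256 fingerprint calculation, and construction of TLS acceptors
//! and connectors (the latter verifying servers by fingerprint pinning or TOFU).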
4
5 use std::fs;
6 use std::io::{self, BufReader};
7 use std::path::Path;
8 use std::sync::Arc;
9
10 use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256};
11 use rustls::pki_types::{CertificateDer, PrivateKeyDer};
12 use rustls::{ClientConfig, ServerConfig};
13 use rustls_pemfile::{certs, private_key};
14 use sha2::{Digest, Sha256};
15 use tokio_rustls::{TlsAcceptor, TlsConnector};
16
17 /// TLS-related errors
18 #[derive(Debug, thiserror::Error)]
19 pub enum TlsError {
20 #[error("IO error: {0}")]
21 Io(#[from] io::Error),
22
23 #[error("Certificate generation error: {0}")]
24 CertGen(String),
25
26 #[error("Certificate loading error: {0}")]
27 CertLoad(String),
28
29 #[error("TLS configuration error: {0}")]
30 Config(String),
31
32 #[error("Certificate verification failed: {0}")]
33 Verification(String),
34 }
35
36 /// Certificate fingerprint (SHA-256)
37 #[derive(Debug, Clone, PartialEq, Eq)]
38 pub struct Fingerprint(pub [u8; 32]);
39
40 impl Fingerprint {
41 /// Calculate fingerprint from DER-encoded certificate
42 pub fn from_der(der: &[u8]) -> Self {
43 let mut hasher = Sha256::new();
44 hasher.update(der);
45 let result = hasher.finalize();
46 let mut bytes = [0u8; 32];
47 bytes.copy_from_slice(&result);
48 Self(bytes)
49 }
50
51 /// Convert to hex string (colon-separated, uppercase)
52 pub fn to_hex(&self) -> String {
53 self.0
54 .iter()
55 .map(|b| format!("{:02X}", b))
56 .collect::<Vec<_>>()
57 .join(":")
58 }
59
60 /// Parse from hex string (colon-separated or continuous)
61 pub fn from_hex(s: &str) -> Result<Self, TlsError> {
62 let clean: String = s.chars().filter(|c| c.is_ascii_hexdigit()).collect();
63 if clean.len() != 64 {
64 return Err(TlsError::Verification(format!(
65 "Invalid fingerprint length: expected 64 hex chars, got {}",
66 clean.len()
67 )));
68 }
69
70 let mut bytes = [0u8; 32];
71 for (i, chunk) in clean.as_bytes().chunks(2).enumerate() {
72 let hex = std::str::from_utf8(chunk).unwrap();
73 bytes[i] = u8::from_str_radix(hex, 16)
74 .map_err(|e| TlsError::Verification(format!("Invalid hex: {}", e)))?;
75 }
76
77 Ok(Self(bytes))
78 }
79 }
80
81 impl std::fmt::Display for Fingerprint {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "{}", self.to_hex())
84 }
85 }
86
87 /// Generate a self-signed ECDSA P-256 certificate
88 pub fn generate_certificate(machine_name: &str) -> Result<(String, String), TlsError> {
89 tracing::info!("Generating new self-signed certificate for '{}'", machine_name);
90
91 // Generate ECDSA P-256 key pair
92 let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
93 .map_err(|e| TlsError::CertGen(format!("Failed to generate key pair: {}", e)))?;
94
95 // Build certificate parameters
96 let mut params = CertificateParams::default();
97
98 // Set distinguished name
99 let mut dn = DistinguishedName::new();
100 dn.push(DnType::CommonName, machine_name);
101 dn.push(DnType::OrganizationName, "HyprKVM");
102 params.distinguished_name = dn;
103
104 // Set validity (10 years)
105 params.not_before = time::OffsetDateTime::now_utc();
106 params.not_after = params.not_before + time::Duration::days(3650);
107
108 // Add Subject Alternative Names
109 params.subject_alt_names = vec![
110 rcgen::SanType::DnsName(machine_name.try_into().map_err(|e| {
111 TlsError::CertGen(format!("Invalid machine name for SAN: {}", e))
112 })?),
113 ];
114
115 // Generate certificate
116 let cert = params
117 .self_signed(&key_pair)
118 .map_err(|e| TlsError::CertGen(format!("Failed to generate certificate: {}", e)))?;
119
120 // Serialize to PEM
121 let cert_pem = cert.pem();
122 let key_pem = key_pair.serialize_pem();
123
124 tracing::info!("Certificate generated successfully");
125
126 Ok((cert_pem, key_pem))
127 }
128
129 /// Ensure certificate exists, generating if needed
130 pub fn ensure_certificate(
131 cert_path: &Path,
132 key_path: &Path,
133 machine_name: &str,
134 ) -> Result<(), TlsError> {
135 if cert_path.exists() && key_path.exists() {
136 tracing::debug!("Certificate files exist: {:?}, {:?}", cert_path, key_path);
137 return Ok(());
138 }
139
140 tracing::info!("Certificate files not found, generating new certificate");
141
142 let (cert_pem, key_pem) = generate_certificate(machine_name)?;
143
144 // Ensure parent directories exist
145 if let Some(parent) = cert_path.parent() {
146 fs::create_dir_all(parent)?;
147 }
148 if let Some(parent) = key_path.parent() {
149 fs::create_dir_all(parent)?;
150 }
151
152 // Write certificate
153 fs::write(cert_path, &cert_pem)?;
154 tracing::info!("Wrote certificate to {:?}", cert_path);
155
156 // Write private key with restrictive permissions
157 #[cfg(unix)]
158 {
159 use std::os::unix::fs::OpenOptionsExt;
160 let mut opts = fs::OpenOptions::new();
161 opts.write(true).create(true).truncate(true).mode(0o600);
162 let mut file = opts.open(key_path)?;
163 std::io::Write::write_all(&mut file, key_pem.as_bytes())?;
164 }
165 #[cfg(not(unix))]
166 {
167 fs::write(key_path, &key_pem)?;
168 }
169 tracing::info!("Wrote private key to {:?}", key_path);
170
171 Ok(())
172 }
173
174 /// Load certificate chain from PEM file
175 pub fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
176 let file = fs::File::open(path)
177 .map_err(|e| TlsError::CertLoad(format!("Failed to open {:?}: {}", path, e)))?;
178 let mut reader = BufReader::new(file);
179
180 let certs: Vec<CertificateDer<'static>> = certs(&mut reader)
181 .collect::<Result<Vec<_>, _>>()
182 .map_err(|e| TlsError::CertLoad(format!("Failed to parse certificates: {}", e)))?;
183
184 if certs.is_empty() {
185 return Err(TlsError::CertLoad("No certificates found in file".into()));
186 }
187
188 Ok(certs)
189 }
190
191 /// Load private key from PEM file
192 pub fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
193 let file = fs::File::open(path)
194 .map_err(|e| TlsError::CertLoad(format!("Failed to open {:?}: {}", path, e)))?;
195 let mut reader = BufReader::new(file);
196
197 private_key(&mut reader)
198 .map_err(|e| TlsError::CertLoad(format!("Failed to parse private key: {}", e)))?
199 .ok_or_else(|| TlsError::CertLoad("No private key found in file".into()))
200 }
201
202 /// Get fingerprint of certificate file
203 pub fn get_cert_fingerprint(path: &Path) -> Result<Fingerprint, TlsError> {
204 let certs = load_certs(path)?;
205 if let Some(cert) = certs.first() {
206 Ok(Fingerprint::from_der(cert.as_ref()))
207 } else {
208 Err(TlsError::CertLoad("No certificate in file".into()))
209 }
210 }
211
212 /// Create a TLS acceptor for server-side connections
213 pub fn create_tls_acceptor(cert_path: &Path, key_path: &Path) -> Result<TlsAcceptor, TlsError> {
214 let certs = load_certs(cert_path)?;
215 let key = load_private_key(key_path)?;
216
217 let config = ServerConfig::builder()
218 .with_no_client_auth()
219 .with_single_cert(certs, key)
220 .map_err(|e| TlsError::Config(format!("Failed to create server config: {}", e)))?;
221
222 Ok(TlsAcceptor::from(Arc::new(config)))
223 }
224
225 /// Custom certificate verifier that supports TOFU and fingerprint pinning
226 #[derive(Debug)]
227 pub struct HyprkvmCertVerifier {
228 /// Expected fingerprint (if pinned)
229 expected_fingerprint: Option<Fingerprint>,
230 /// Callback to handle TOFU (returns true to accept, false to reject)
231 /// If None, TOFU is disabled
232 tofu_enabled: bool,
233 }
234
235 impl HyprkvmCertVerifier {
236 /// Create verifier with pinned fingerprint
237 #[allow(dead_code)]
238 pub fn with_fingerprint(fingerprint: Fingerprint) -> Self {
239 Self {
240 expected_fingerprint: Some(fingerprint),
241 tofu_enabled: false,
242 }
243 }
244
245 /// Create verifier with TOFU enabled
246 #[allow(dead_code)]
247 pub fn with_tofu() -> Self {
248 Self {
249 expected_fingerprint: None,
250 tofu_enabled: true,
251 }
252 }
253
254 /// Create verifier with fingerprint, falling back to TOFU if not set
255 pub fn new(fingerprint: Option<Fingerprint>, tofu_enabled: bool) -> Self {
256 Self {
257 expected_fingerprint: fingerprint,
258 tofu_enabled,
259 }
260 }
261 }
262
263 impl rustls::client::danger::ServerCertVerifier for HyprkvmCertVerifier {
264 fn verify_server_cert(
265 &self,
266 end_entity: &CertificateDer<'_>,
267 _intermediates: &[CertificateDer<'_>],
268 _server_name: &rustls::pki_types::ServerName<'_>,
269 _ocsp_response: &[u8],
270 _now: rustls::pki_types::UnixTime,
271 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
272 let fingerprint = Fingerprint::from_der(end_entity.as_ref());
273
274 if let Some(ref expected) = self.expected_fingerprint {
275 // Fingerprint pinning mode
276 if fingerprint == *expected {
277 tracing::debug!("Certificate fingerprint matches pinned value");
278 return Ok(rustls::client::danger::ServerCertVerified::assertion());
279 } else {
280 tracing::error!(
281 "Certificate fingerprint mismatch! Expected: {}, Got: {}",
282 expected,
283 fingerprint
284 );
285 return Err(rustls::Error::General(
286 "Certificate fingerprint mismatch".into(),
287 ));
288 }
289 }
290
291 if self.tofu_enabled {
292 // TOFU mode - accept any certificate but log for manual verification
293 tracing::warn!(
294 "TOFU: Accepting certificate with fingerprint: {}",
295 fingerprint
296 );
297 tracing::warn!("TOFU: Add this fingerprint to config to pin the certificate");
298 return Ok(rustls::client::danger::ServerCertVerified::assertion());
299 }
300
301 Err(rustls::Error::General(
302 "No verification method configured".into(),
303 ))
304 }
305
306 fn verify_tls12_signature(
307 &self,
308 message: &[u8],
309 cert: &CertificateDer<'_>,
310 dss: &rustls::DigitallySignedStruct,
311 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
312 rustls::crypto::verify_tls12_signature(
313 message,
314 cert,
315 dss,
316 &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
317 )
318 }
319
320 fn verify_tls13_signature(
321 &self,
322 message: &[u8],
323 cert: &CertificateDer<'_>,
324 dss: &rustls::DigitallySignedStruct,
325 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
326 rustls::crypto::verify_tls13_signature(
327 message,
328 cert,
329 dss,
330 &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
331 )
332 }
333
334 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
335 rustls::crypto::aws_lc_rs::default_provider()
336 .signature_verification_algorithms
337 .supported_schemes()
338 }
339 }
340
341 /// Create a TLS connector for client-side connections
342 pub fn create_tls_connector(
343 expected_fingerprint: Option<&str>,
344 tofu_enabled: bool,
345 ) -> Result<TlsConnector, TlsError> {
346 let fingerprint = match expected_fingerprint {
347 Some(fp) => Some(Fingerprint::from_hex(fp)?),
348 None => None,
349 };
350
351 let verifier = HyprkvmCertVerifier::new(fingerprint, tofu_enabled);
352
353 let config = ClientConfig::builder()
354 .dangerous()
355 .with_custom_certificate_verifier(Arc::new(verifier))
356 .with_no_client_auth();
357
358 Ok(TlsConnector::from(Arc::new(config)))
359 }
360
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::client::danger::ServerCertVerifier;
    use rustls::pki_types::{ServerName, UnixTime};
    use tempfile::tempdir;

    /// Generates a certificate in a temporary directory and returns its first DER certificate.
    fn generated_leaf() -> CertificateDer<'static> {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();
        load_certs(&cert_path).unwrap().remove(0)
    }

    /// Runs the verifier's server-certificate check; name/time are placeholders it ignores.
    fn verify(
        verifier: &HyprkvmCertVerifier,
        cert: &CertificateDer<'_>,
    ) -> Result<(), rustls::Error> {
        let name = ServerName::try_from("test-machine").unwrap();
        verifier
            .verify_server_cert(cert, &[], &name, &[], UnixTime::now())
            .map(|_| ())
    }

    #[test]
    fn test_fingerprint_roundtrip() {
        let bytes = [0xABu8; 32];
        let fp = Fingerprint(bytes);
        let hex = fp.to_hex();
        let parsed = Fingerprint::from_hex(&hex).unwrap();
        assert_eq!(fp, parsed);
    }

    #[test]
    fn test_fingerprint_from_hex_accepts_continuous_lowercase() {
        let parsed = Fingerprint::from_hex(&"ab".repeat(32)).unwrap();
        assert_eq!(parsed, Fingerprint([0xAB; 32]));
    }

    #[test]
    fn test_fingerprint_from_hex_rejects_wrong_length() {
        assert!(Fingerprint::from_hex("").is_err());
        assert!(Fingerprint::from_hex("AB:CD").is_err());
        assert!(Fingerprint::from_hex(&"AB".repeat(33)).is_err());
    }

    #[test]
    fn test_connector_rejects_malformed_fingerprint() {
        assert!(create_tls_connector(Some("not-a-fingerprint"), false).is_err());
    }

    #[test]
    fn test_generate_certificate() {
        let (cert_pem, key_pem) = generate_certificate("test-machine").unwrap();
        assert!(cert_pem.contains("-----BEGIN CERTIFICATE-----"));
        assert!(key_pem.contains("-----BEGIN PRIVATE KEY-----"));
    }

    #[test]
    fn test_ensure_certificate() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();

        assert!(cert_path.exists());
        assert!(key_path.exists());
        assert!(load_private_key(&key_path).is_ok());

        // Second call should not regenerate
        let cert_content = fs::read_to_string(&cert_path).unwrap();
        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();
        let cert_content2 = fs::read_to_string(&cert_path).unwrap();
        assert_eq!(cert_content, cert_content2);
    }

    #[cfg(unix)]
    #[test]
    fn test_private_key_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();

        let mode = fs::metadata(&key_path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    #[test]
    fn test_fingerprint_calculation() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();

        let fp = get_cert_fingerprint(&cert_path).unwrap();
        assert_eq!(fp.to_hex().len(), 95); // 32 bytes * 2 + 31 colons
    }

    #[test]
    fn test_verifier_pinned_fingerprint() {
        let leaf = generated_leaf();

        let pinned = HyprkvmCertVerifier::with_fingerprint(Fingerprint::from_der(leaf.as_ref()));
        assert!(verify(&pinned, &leaf).is_ok());

        let wrong = HyprkvmCertVerifier::with_fingerprint(Fingerprint([0u8; 32]));
        assert!(verify(&wrong, &leaf).is_err());
    }

    #[test]
    fn test_verifier_pin_takes_precedence_over_tofu() {
        let leaf = generated_leaf();
        let verifier = HyprkvmCertVerifier::new(Some(Fingerprint([0u8; 32])), true);
        assert!(verify(&verifier, &leaf).is_err());
    }

    #[test]
    fn test_verifier_tofu_and_unconfigured() {
        let leaf = generated_leaf();
        assert!(verify(&HyprkvmCertVerifier::with_tofu(), &leaf).is_ok());
        assert!(verify(&HyprkvmCertVerifier::new(None, false), &leaf).is_err());
    }
}
412