tenseleyflow/hyprkvm / ce63a95

Browse files

feat: add TLS infrastructure for encrypted peer connections

Implements TLS support with certificate pinning and TOFU:

- tls.rs: Certificate generation (ECDSA P-256), loading, fingerprint
calculation (SHA-256), TLS acceptor/connector with custom verifier
- known_hosts.rs: TOFU storage for trusted fingerprints, persisted to
~/.config/hyprkvm/known_hosts.toml
- transport.rs: Updated to support both plain TCP and TLS connections
with Stream enum wrapping TcpStream/TlsStream variants
- Added sha2 and time dependencies for crypto and cert validity

The transport layer now provides:
- Server::bind_tls() for TLS-enabled server
- connect_tls() for TLS client connections with fingerprint pinning
- Peer certificate fingerprint extraction after handshake

Integration with main.rs pending - this commit provides the foundation.
Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
ce63a953cf1e6ae263b876ec26ac0b5df86fe34e
Parents
29b886d
Tree
fd04ed5

7 changed files

StatusFile+-
M Cargo.lock 73 0
M Cargo.toml 3 1
M hyprkvm-daemon/Cargo.toml 2 0
A hyprkvm-daemon/src/network/known_hosts.rs 226 0
M hyprkvm-daemon/src/network/mod.rs 9 1
A hyprkvm-daemon/src/network/tls.rs 409 0
M hyprkvm-daemon/src/network/transport.rs 198 13
Cargo.lockmodified
@@ -125,6 +125,15 @@ dependencies = [
125125
  "wyz",
126126
 ]
127127
 
128
+[[package]]
129
+name = "block-buffer"
130
+version = "0.10.4"
131
+source = "registry+https://github.com/rust-lang/crates.io-index"
132
+checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
133
+dependencies = [
134
+ "generic-array",
135
+]
136
+
128137
 [[package]]
129138
 name = "bytemuck"
130139
 version = "1.24.0"
@@ -259,6 +268,15 @@ dependencies = [
259268
  "crossbeam-utils",
260269
 ]
261270
 
271
+[[package]]
272
+name = "cpufeatures"
273
+version = "0.2.17"
274
+source = "registry+https://github.com/rust-lang/crates.io-index"
275
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
276
+dependencies = [
277
+ "libc",
278
+]
279
+
262280
 [[package]]
263281
 name = "crossbeam-channel"
264282
 version = "0.5.15"
@@ -274,6 +292,16 @@ version = "0.8.21"
274292
 source = "registry+https://github.com/rust-lang/crates.io-index"
275293
 checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
276294
 
295
+[[package]]
296
+name = "crypto-common"
297
+version = "0.1.7"
298
+source = "registry+https://github.com/rust-lang/crates.io-index"
299
+checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
300
+dependencies = [
301
+ "generic-array",
302
+ "typenum",
303
+]
304
+
277305
 [[package]]
278306
 name = "cursor-icon"
279307
 version = "1.2.0"
@@ -289,6 +317,16 @@ dependencies = [
289317
  "powerfmt",
290318
 ]
291319
 
320
+[[package]]
321
+name = "digest"
322
+version = "0.10.7"
323
+source = "registry+https://github.com/rust-lang/crates.io-index"
324
+checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
325
+dependencies = [
326
+ "block-buffer",
327
+ "crypto-common",
328
+]
329
+
292330
 [[package]]
293331
 name = "dirs"
294332
 version = "5.0.1"
@@ -369,6 +407,16 @@ version = "2.0.0"
369407
 source = "registry+https://github.com/rust-lang/crates.io-index"
370408
 checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
371409
 
410
+[[package]]
411
+name = "generic-array"
412
+version = "0.14.7"
413
+source = "registry+https://github.com/rust-lang/crates.io-index"
414
+checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
415
+dependencies = [
416
+ "typenum",
417
+ "version_check",
418
+]
419
+
372420
 [[package]]
373421
 name = "getrandom"
374422
 version = "0.2.16"
@@ -459,8 +507,10 @@ dependencies = [
459507
  "rustls-pemfile",
460508
  "serde",
461509
  "serde_json",
510
+ "sha2",
462511
  "smithay-client-toolkit",
463512
  "thiserror 1.0.69",
513
+ "time",
464514
  "tokio",
465515
  "tokio-rustls",
466516
  "toml",
@@ -950,6 +1000,17 @@ dependencies = [
9501000
  "serde",
9511001
 ]
9521002
 
1003
+[[package]]
1004
+name = "sha2"
1005
+version = "0.10.9"
1006
+source = "registry+https://github.com/rust-lang/crates.io-index"
1007
+checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
1008
+dependencies = [
1009
+ "cfg-if",
1010
+ "cpufeatures",
1011
+ "digest",
1012
+]
1013
+
9531014
 [[package]]
9541015
 name = "sharded-slab"
9551016
 version = "0.1.7"
@@ -1286,6 +1347,12 @@ dependencies = [
12861347
  "tracing-log",
12871348
 ]
12881349
 
1350
+[[package]]
1351
+name = "typenum"
1352
+version = "1.19.0"
1353
+source = "registry+https://github.com/rust-lang/crates.io-index"
1354
+checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
1355
+
12891356
 [[package]]
12901357
 name = "unicode-ident"
12911358
 version = "1.0.22"
@@ -1310,6 +1377,12 @@ version = "0.1.1"
13101377
 source = "registry+https://github.com/rust-lang/crates.io-index"
13111378
 checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
13121379
 
1380
+[[package]]
1381
+name = "version_check"
1382
+version = "0.9.5"
1383
+source = "registry+https://github.com/rust-lang/crates.io-index"
1384
+checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
1385
+
13131386
 [[package]]
13141387
 name = "wasi"
13151388
 version = "0.11.1+wasi-snapshot-preview1"
Cargo.tomlmodified
@@ -44,9 +44,11 @@ smithay-client-toolkit = "0.19"
4444
 
4545
 # TLS
4646
 tokio-rustls = "0.26"
47
-rustls = "0.23"
47
+rustls = { version = "0.23", features = ["aws_lc_rs"] }
4848
 rustls-pemfile = "2"
4949
 rcgen = "0.13"
50
+sha2 = "0.10"
51
+time = "0.3"
5052
 
5153
 # Utilities
5254
 dirs = "5"
hyprkvm-daemon/Cargo.tomlmodified
@@ -45,6 +45,8 @@ tokio-rustls = { workspace = true }
4545
 rustls = { workspace = true }
4646
 rustls-pemfile = { workspace = true }
4747
 rcgen = { workspace = true }
48
+sha2 = { workspace = true }
49
+time = { workspace = true }
4850
 
4951
 # Utilities
5052
 dirs = { workspace = true }
hyprkvm-daemon/src/network/known_hosts.rsadded
@@ -0,0 +1,226 @@
1
+//! Known hosts management for TOFU (Trust On First Use)
2
+//!
3
+//! Stores trusted certificate fingerprints for peer machines.
4
+
5
+use std::collections::HashMap;
6
+use std::fs;
7
+use std::path::{Path, PathBuf};
8
+
9
+use serde::{Deserialize, Serialize};
10
+
11
+use super::tls::Fingerprint;
12
+
13
+/// Known hosts storage
14
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15
+pub struct KnownHosts {
16
+    /// Map of machine name to trusted fingerprint
17
+    #[serde(default)]
18
+    pub hosts: HashMap<String, KnownHost>,
19
+}
20
+
21
+/// A known host entry
22
+#[derive(Debug, Clone, Serialize, Deserialize)]
23
+pub struct KnownHost {
24
+    /// Certificate fingerprint (SHA-256, hex encoded)
25
+    pub fingerprint: String,
26
+    /// When this host was first seen
27
+    #[serde(default = "default_timestamp")]
28
+    pub first_seen: String,
29
+    /// When this host was last seen
30
+    #[serde(default = "default_timestamp")]
31
+    pub last_seen: String,
32
+    /// Optional notes
33
+    #[serde(default, skip_serializing_if = "Option::is_none")]
34
+    pub notes: Option<String>,
35
+}
36
+
37
+fn default_timestamp() -> String {
38
+    chrono_lite_now()
39
+}
40
+
41
+/// Simple timestamp without chrono dependency
42
+fn chrono_lite_now() -> String {
43
+    use std::time::{SystemTime, UNIX_EPOCH};
44
+    let duration = SystemTime::now()
45
+        .duration_since(UNIX_EPOCH)
46
+        .unwrap_or_default();
47
+    format!("{}", duration.as_secs())
48
+}
49
+
50
+impl KnownHosts {
51
+    /// Get the default known hosts file path
52
+    pub fn default_path() -> PathBuf {
53
+        dirs::config_dir()
54
+            .map(|d| d.join("hyprkvm").join("known_hosts.toml"))
55
+            .unwrap_or_else(|| PathBuf::from("known_hosts.toml"))
56
+    }
57
+
58
+    /// Load known hosts from file, or create empty if not exists
59
+    pub fn load(path: &Path) -> Result<Self, KnownHostsError> {
60
+        if !path.exists() {
61
+            tracing::debug!("Known hosts file not found, starting fresh");
62
+            return Ok(Self::default());
63
+        }
64
+
65
+        let content = fs::read_to_string(path)
66
+            .map_err(|e| KnownHostsError::Io(format!("Failed to read {:?}: {}", path, e)))?;
67
+
68
+        toml::from_str(&content)
69
+            .map_err(|e| KnownHostsError::Parse(format!("Failed to parse {:?}: {}", path, e)))
70
+    }
71
+
72
+    /// Save known hosts to file
73
+    pub fn save(&self, path: &Path) -> Result<(), KnownHostsError> {
74
+        // Ensure parent directory exists
75
+        if let Some(parent) = path.parent() {
76
+            fs::create_dir_all(parent)
77
+                .map_err(|e| KnownHostsError::Io(format!("Failed to create directory: {}", e)))?;
78
+        }
79
+
80
+        let content = toml::to_string_pretty(self)
81
+            .map_err(|e| KnownHostsError::Serialize(e.to_string()))?;
82
+
83
+        fs::write(path, content)
84
+            .map_err(|e| KnownHostsError::Io(format!("Failed to write {:?}: {}", path, e)))?;
85
+
86
+        tracing::debug!("Saved known hosts to {:?}", path);
87
+        Ok(())
88
+    }
89
+
90
+    /// Check if a host is known with the given fingerprint
91
+    pub fn is_trusted(&self, machine_name: &str, fingerprint: &Fingerprint) -> TrustStatus {
92
+        match self.hosts.get(machine_name) {
93
+            None => TrustStatus::Unknown,
94
+            Some(host) => {
95
+                if host.fingerprint == fingerprint.to_hex() {
96
+                    TrustStatus::Trusted
97
+                } else {
98
+                    TrustStatus::Changed {
99
+                        old_fingerprint: host.fingerprint.clone(),
100
+                        new_fingerprint: fingerprint.to_hex(),
101
+                    }
102
+                }
103
+            }
104
+        }
105
+    }
106
+
107
+    /// Add or update a trusted host
108
+    pub fn trust_host(&mut self, machine_name: &str, fingerprint: &Fingerprint) {
109
+        let now = chrono_lite_now();
110
+        let fp_hex = fingerprint.to_hex();
111
+
112
+        if let Some(existing) = self.hosts.get_mut(machine_name) {
113
+            existing.fingerprint = fp_hex;
114
+            existing.last_seen = now;
115
+            tracing::info!("Updated known host: {}", machine_name);
116
+        } else {
117
+            self.hosts.insert(
118
+                machine_name.to_string(),
119
+                KnownHost {
120
+                    fingerprint: fp_hex,
121
+                    first_seen: now.clone(),
122
+                    last_seen: now,
123
+                    notes: None,
124
+                },
125
+            );
126
+            tracing::info!("Added new known host: {}", machine_name);
127
+        }
128
+    }
129
+
130
+    /// Remove a host from known hosts
131
+    pub fn remove_host(&mut self, machine_name: &str) -> bool {
132
+        self.hosts.remove(machine_name).is_some()
133
+    }
134
+
135
+    /// Get fingerprint for a known host
136
+    pub fn get_fingerprint(&self, machine_name: &str) -> Option<Fingerprint> {
137
+        self.hosts
138
+            .get(machine_name)
139
+            .and_then(|h| Fingerprint::from_hex(&h.fingerprint).ok())
140
+    }
141
+
142
+    /// Update last_seen timestamp for a host
143
+    pub fn touch(&mut self, machine_name: &str) {
144
+        if let Some(host) = self.hosts.get_mut(machine_name) {
145
+            host.last_seen = chrono_lite_now();
146
+        }
147
+    }
148
+}
149
+
150
+/// Result of checking trust status
151
+#[derive(Debug, Clone, PartialEq)]
152
+pub enum TrustStatus {
153
+    /// Host is not in known_hosts
154
+    Unknown,
155
+    /// Host is known and fingerprint matches
156
+    Trusted,
157
+    /// Host is known but fingerprint changed (potential MITM!)
158
+    Changed {
159
+        old_fingerprint: String,
160
+        new_fingerprint: String,
161
+    },
162
+}
163
+
164
+impl TrustStatus {
165
+    pub fn is_trusted(&self) -> bool {
166
+        matches!(self, TrustStatus::Trusted)
167
+    }
168
+
169
+    pub fn is_unknown(&self) -> bool {
170
+        matches!(self, TrustStatus::Unknown)
171
+    }
172
+
173
+    pub fn is_changed(&self) -> bool {
174
+        matches!(self, TrustStatus::Changed { .. })
175
+    }
176
+}
177
+
178
+#[derive(Debug, thiserror::Error)]
179
+pub enum KnownHostsError {
180
+    #[error("IO error: {0}")]
181
+    Io(String),
182
+
183
+    #[error("Parse error: {0}")]
184
+    Parse(String),
185
+
186
+    #[error("Serialize error: {0}")]
187
+    Serialize(String),
188
+}
189
+
190
+#[cfg(test)]
191
+mod tests {
192
+    use super::*;
193
+    use tempfile::tempdir;
194
+
195
+    #[test]
196
+    fn test_trust_workflow() {
197
+        let mut known_hosts = KnownHosts::default();
198
+        let fp = Fingerprint([0xAB; 32]);
199
+
200
+        // Initially unknown
201
+        assert!(known_hosts.is_trusted("test-machine", &fp).is_unknown());
202
+
203
+        // Trust it
204
+        known_hosts.trust_host("test-machine", &fp);
205
+        assert!(known_hosts.is_trusted("test-machine", &fp).is_trusted());
206
+
207
+        // Different fingerprint should be detected
208
+        let fp2 = Fingerprint([0xCD; 32]);
209
+        assert!(known_hosts.is_trusted("test-machine", &fp2).is_changed());
210
+    }
211
+
212
+    #[test]
213
+    fn test_save_load() {
214
+        let dir = tempdir().unwrap();
215
+        let path = dir.path().join("known_hosts.toml");
216
+
217
+        let mut known_hosts = KnownHosts::default();
218
+        let fp = Fingerprint([0xAB; 32]);
219
+        known_hosts.trust_host("test-machine", &fp);
220
+
221
+        known_hosts.save(&path).unwrap();
222
+
223
+        let loaded = KnownHosts::load(&path).unwrap();
224
+        assert!(loaded.is_trusted("test-machine", &fp).is_trusted());
225
+    }
226
+}
hyprkvm-daemon/src/network/mod.rsmodified
@@ -2,8 +2,16 @@
22
 //!
33
 //! Handles peer-to-peer connections between HyprKVM instances.
44
 
5
+#[allow(dead_code)]
6
+pub mod known_hosts;
57
 #[allow(dead_code)]
68
 pub mod peer;
9
+pub mod tls;
710
 pub mod transport;
811
 
9
-pub use transport::{connect, FramedConnection, Server};
12
+pub use known_hosts::{KnownHosts, TrustStatus};
13
+pub use tls::{
14
+    create_tls_acceptor, create_tls_connector, ensure_certificate, get_cert_fingerprint,
15
+    Fingerprint, TlsError,
16
+};
17
+pub use transport::{connect, connect_tls, FramedConnection, Server, TransportError};
hyprkvm-daemon/src/network/tls.rsadded
@@ -0,0 +1,409 @@
1
+//! TLS certificate management for HyprKVM
2
+//!
3
+//! Handles certificate generation, loading, and fingerprint calculation.
4
+
5
+use std::fs;
6
+use std::io::{self, BufReader};
7
+use std::path::Path;
8
+use std::sync::Arc;
9
+
10
+use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair, PKCS_ECDSA_P256_SHA256};
11
+use rustls::pki_types::{CertificateDer, PrivateKeyDer};
12
+use rustls::{ClientConfig, ServerConfig};
13
+use rustls_pemfile::{certs, private_key};
14
+use sha2::{Digest, Sha256};
15
+use tokio_rustls::{TlsAcceptor, TlsConnector};
16
+
17
+/// TLS-related errors
18
+#[derive(Debug, thiserror::Error)]
19
+pub enum TlsError {
20
+    #[error("IO error: {0}")]
21
+    Io(#[from] io::Error),
22
+
23
+    #[error("Certificate generation error: {0}")]
24
+    CertGen(String),
25
+
26
+    #[error("Certificate loading error: {0}")]
27
+    CertLoad(String),
28
+
29
+    #[error("TLS configuration error: {0}")]
30
+    Config(String),
31
+
32
+    #[error("Certificate verification failed: {0}")]
33
+    Verification(String),
34
+}
35
+
36
+/// Certificate fingerprint (SHA-256)
37
+#[derive(Debug, Clone, PartialEq, Eq)]
38
+pub struct Fingerprint(pub [u8; 32]);
39
+
40
+impl Fingerprint {
41
+    /// Calculate fingerprint from DER-encoded certificate
42
+    pub fn from_der(der: &[u8]) -> Self {
43
+        let mut hasher = Sha256::new();
44
+        hasher.update(der);
45
+        let result = hasher.finalize();
46
+        let mut bytes = [0u8; 32];
47
+        bytes.copy_from_slice(&result);
48
+        Self(bytes)
49
+    }
50
+
51
+    /// Convert to hex string (colon-separated, uppercase)
52
+    pub fn to_hex(&self) -> String {
53
+        self.0
54
+            .iter()
55
+            .map(|b| format!("{:02X}", b))
56
+            .collect::<Vec<_>>()
57
+            .join(":")
58
+    }
59
+
60
+    /// Parse from hex string (colon-separated or continuous)
61
+    pub fn from_hex(s: &str) -> Result<Self, TlsError> {
62
+        let clean: String = s.chars().filter(|c| c.is_ascii_hexdigit()).collect();
63
+        if clean.len() != 64 {
64
+            return Err(TlsError::Verification(format!(
65
+                "Invalid fingerprint length: expected 64 hex chars, got {}",
66
+                clean.len()
67
+            )));
68
+        }
69
+
70
+        let mut bytes = [0u8; 32];
71
+        for (i, chunk) in clean.as_bytes().chunks(2).enumerate() {
72
+            let hex = std::str::from_utf8(chunk).unwrap();
73
+            bytes[i] = u8::from_str_radix(hex, 16)
74
+                .map_err(|e| TlsError::Verification(format!("Invalid hex: {}", e)))?;
75
+        }
76
+
77
+        Ok(Self(bytes))
78
+    }
79
+}
80
+
81
+impl std::fmt::Display for Fingerprint {
82
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83
+        write!(f, "{}", self.to_hex())
84
+    }
85
+}
86
+
87
+/// Generate a self-signed ECDSA P-256 certificate
88
+pub fn generate_certificate(machine_name: &str) -> Result<(String, String), TlsError> {
89
+    tracing::info!("Generating new self-signed certificate for '{}'", machine_name);
90
+
91
+    // Generate ECDSA P-256 key pair
92
+    let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)
93
+        .map_err(|e| TlsError::CertGen(format!("Failed to generate key pair: {}", e)))?;
94
+
95
+    // Build certificate parameters
96
+    let mut params = CertificateParams::default();
97
+
98
+    // Set distinguished name
99
+    let mut dn = DistinguishedName::new();
100
+    dn.push(DnType::CommonName, machine_name);
101
+    dn.push(DnType::OrganizationName, "HyprKVM");
102
+    params.distinguished_name = dn;
103
+
104
+    // Set validity (10 years)
105
+    params.not_before = time::OffsetDateTime::now_utc();
106
+    params.not_after = params.not_before + time::Duration::days(3650);
107
+
108
+    // Add Subject Alternative Names
109
+    params.subject_alt_names = vec![
110
+        rcgen::SanType::DnsName(machine_name.try_into().map_err(|e| {
111
+            TlsError::CertGen(format!("Invalid machine name for SAN: {}", e))
112
+        })?),
113
+    ];
114
+
115
+    // Generate certificate
116
+    let cert = params
117
+        .self_signed(&key_pair)
118
+        .map_err(|e| TlsError::CertGen(format!("Failed to generate certificate: {}", e)))?;
119
+
120
+    // Serialize to PEM
121
+    let cert_pem = cert.pem();
122
+    let key_pem = key_pair.serialize_pem();
123
+
124
+    tracing::info!("Certificate generated successfully");
125
+
126
+    Ok((cert_pem, key_pem))
127
+}
128
+
129
+/// Ensure certificate exists, generating if needed
130
+pub fn ensure_certificate(
131
+    cert_path: &Path,
132
+    key_path: &Path,
133
+    machine_name: &str,
134
+) -> Result<(), TlsError> {
135
+    if cert_path.exists() && key_path.exists() {
136
+        tracing::debug!("Certificate files exist: {:?}, {:?}", cert_path, key_path);
137
+        return Ok(());
138
+    }
139
+
140
+    tracing::info!("Certificate files not found, generating new certificate");
141
+
142
+    let (cert_pem, key_pem) = generate_certificate(machine_name)?;
143
+
144
+    // Ensure parent directories exist
145
+    if let Some(parent) = cert_path.parent() {
146
+        fs::create_dir_all(parent)?;
147
+    }
148
+    if let Some(parent) = key_path.parent() {
149
+        fs::create_dir_all(parent)?;
150
+    }
151
+
152
+    // Write certificate
153
+    fs::write(cert_path, &cert_pem)?;
154
+    tracing::info!("Wrote certificate to {:?}", cert_path);
155
+
156
+    // Write private key with restrictive permissions
157
+    #[cfg(unix)]
158
+    {
159
+        use std::os::unix::fs::OpenOptionsExt;
160
+        let mut opts = fs::OpenOptions::new();
161
+        opts.write(true).create(true).truncate(true).mode(0o600);
162
+        let mut file = opts.open(key_path)?;
163
+        std::io::Write::write_all(&mut file, key_pem.as_bytes())?;
164
+    }
165
+    #[cfg(not(unix))]
166
+    {
167
+        fs::write(key_path, &key_pem)?;
168
+    }
169
+    tracing::info!("Wrote private key to {:?}", key_path);
170
+
171
+    Ok(())
172
+}
173
+
174
+/// Load certificate chain from PEM file
175
+pub fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
176
+    let file = fs::File::open(path)
177
+        .map_err(|e| TlsError::CertLoad(format!("Failed to open {:?}: {}", path, e)))?;
178
+    let mut reader = BufReader::new(file);
179
+
180
+    let certs: Vec<CertificateDer<'static>> = certs(&mut reader)
181
+        .collect::<Result<Vec<_>, _>>()
182
+        .map_err(|e| TlsError::CertLoad(format!("Failed to parse certificates: {}", e)))?;
183
+
184
+    if certs.is_empty() {
185
+        return Err(TlsError::CertLoad("No certificates found in file".into()));
186
+    }
187
+
188
+    Ok(certs)
189
+}
190
+
191
+/// Load private key from PEM file
192
+pub fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
193
+    let file = fs::File::open(path)
194
+        .map_err(|e| TlsError::CertLoad(format!("Failed to open {:?}: {}", path, e)))?;
195
+    let mut reader = BufReader::new(file);
196
+
197
+    private_key(&mut reader)
198
+        .map_err(|e| TlsError::CertLoad(format!("Failed to parse private key: {}", e)))?
199
+        .ok_or_else(|| TlsError::CertLoad("No private key found in file".into()))
200
+}
201
+
202
+/// Get fingerprint of certificate file
203
+pub fn get_cert_fingerprint(path: &Path) -> Result<Fingerprint, TlsError> {
204
+    let certs = load_certs(path)?;
205
+    if let Some(cert) = certs.first() {
206
+        Ok(Fingerprint::from_der(cert.as_ref()))
207
+    } else {
208
+        Err(TlsError::CertLoad("No certificate in file".into()))
209
+    }
210
+}
211
+
212
+/// Create a TLS acceptor for server-side connections
213
+pub fn create_tls_acceptor(cert_path: &Path, key_path: &Path) -> Result<TlsAcceptor, TlsError> {
214
+    let certs = load_certs(cert_path)?;
215
+    let key = load_private_key(key_path)?;
216
+
217
+    let config = ServerConfig::builder()
218
+        .with_no_client_auth()
219
+        .with_single_cert(certs, key)
220
+        .map_err(|e| TlsError::Config(format!("Failed to create server config: {}", e)))?;
221
+
222
+    Ok(TlsAcceptor::from(Arc::new(config)))
223
+}
224
+
225
+/// Custom certificate verifier that supports TOFU and fingerprint pinning
226
+#[derive(Debug)]
227
+pub struct HyprkvmCertVerifier {
228
+    /// Expected fingerprint (if pinned)
229
+    expected_fingerprint: Option<Fingerprint>,
230
+    /// Callback to handle TOFU (returns true to accept, false to reject)
231
+    /// If None, TOFU is disabled
232
+    tofu_enabled: bool,
233
+}
234
+
235
+impl HyprkvmCertVerifier {
236
+    /// Create verifier with pinned fingerprint
237
+    pub fn with_fingerprint(fingerprint: Fingerprint) -> Self {
238
+        Self {
239
+            expected_fingerprint: Some(fingerprint),
240
+            tofu_enabled: false,
241
+        }
242
+    }
243
+
244
+    /// Create verifier with TOFU enabled
245
+    pub fn with_tofu() -> Self {
246
+        Self {
247
+            expected_fingerprint: None,
248
+            tofu_enabled: true,
249
+        }
250
+    }
251
+
252
+    /// Create verifier with fingerprint, falling back to TOFU if not set
253
+    pub fn new(fingerprint: Option<Fingerprint>, tofu_enabled: bool) -> Self {
254
+        Self {
255
+            expected_fingerprint: fingerprint,
256
+            tofu_enabled,
257
+        }
258
+    }
259
+}
260
+
261
+impl rustls::client::danger::ServerCertVerifier for HyprkvmCertVerifier {
262
+    fn verify_server_cert(
263
+        &self,
264
+        end_entity: &CertificateDer<'_>,
265
+        _intermediates: &[CertificateDer<'_>],
266
+        _server_name: &rustls::pki_types::ServerName<'_>,
267
+        _ocsp_response: &[u8],
268
+        _now: rustls::pki_types::UnixTime,
269
+    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
270
+        let fingerprint = Fingerprint::from_der(end_entity.as_ref());
271
+
272
+        if let Some(ref expected) = self.expected_fingerprint {
273
+            // Fingerprint pinning mode
274
+            if fingerprint == *expected {
275
+                tracing::debug!("Certificate fingerprint matches pinned value");
276
+                return Ok(rustls::client::danger::ServerCertVerified::assertion());
277
+            } else {
278
+                tracing::error!(
279
+                    "Certificate fingerprint mismatch! Expected: {}, Got: {}",
280
+                    expected,
281
+                    fingerprint
282
+                );
283
+                return Err(rustls::Error::General(
284
+                    "Certificate fingerprint mismatch".into(),
285
+                ));
286
+            }
287
+        }
288
+
289
+        if self.tofu_enabled {
290
+            // TOFU mode - accept any certificate but log for manual verification
291
+            tracing::warn!(
292
+                "TOFU: Accepting certificate with fingerprint: {}",
293
+                fingerprint
294
+            );
295
+            tracing::warn!("TOFU: Add this fingerprint to config to pin the certificate");
296
+            return Ok(rustls::client::danger::ServerCertVerified::assertion());
297
+        }
298
+
299
+        Err(rustls::Error::General(
300
+            "No verification method configured".into(),
301
+        ))
302
+    }
303
+
304
+    fn verify_tls12_signature(
305
+        &self,
306
+        message: &[u8],
307
+        cert: &CertificateDer<'_>,
308
+        dss: &rustls::DigitallySignedStruct,
309
+    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310
+        rustls::crypto::verify_tls12_signature(
311
+            message,
312
+            cert,
313
+            dss,
314
+            &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
315
+        )
316
+    }
317
+
318
+    fn verify_tls13_signature(
319
+        &self,
320
+        message: &[u8],
321
+        cert: &CertificateDer<'_>,
322
+        dss: &rustls::DigitallySignedStruct,
323
+    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
324
+        rustls::crypto::verify_tls13_signature(
325
+            message,
326
+            cert,
327
+            dss,
328
+            &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
329
+        )
330
+    }
331
+
332
+    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
333
+        rustls::crypto::aws_lc_rs::default_provider()
334
+            .signature_verification_algorithms
335
+            .supported_schemes()
336
+    }
337
+}
338
+
339
+/// Create a TLS connector for client-side connections
340
+pub fn create_tls_connector(
341
+    expected_fingerprint: Option<&str>,
342
+    tofu_enabled: bool,
343
+) -> Result<TlsConnector, TlsError> {
344
+    let fingerprint = match expected_fingerprint {
345
+        Some(fp) => Some(Fingerprint::from_hex(fp)?),
346
+        None => None,
347
+    };
348
+
349
+    let verifier = HyprkvmCertVerifier::new(fingerprint, tofu_enabled);
350
+
351
+    let config = ClientConfig::builder()
352
+        .dangerous()
353
+        .with_custom_certificate_verifier(Arc::new(verifier))
354
+        .with_no_client_auth();
355
+
356
+    Ok(TlsConnector::from(Arc::new(config)))
357
+}
358
+
359
+#[cfg(test)]
360
+mod tests {
361
+    use super::*;
362
+    use tempfile::tempdir;
363
+
364
+    #[test]
365
+    fn test_fingerprint_roundtrip() {
366
+        let bytes = [0xABu8; 32];
367
+        let fp = Fingerprint(bytes);
368
+        let hex = fp.to_hex();
369
+        let parsed = Fingerprint::from_hex(&hex).unwrap();
370
+        assert_eq!(fp, parsed);
371
+    }
372
+
373
+    #[test]
374
+    fn test_generate_certificate() {
375
+        let (cert_pem, key_pem) = generate_certificate("test-machine").unwrap();
376
+        assert!(cert_pem.contains("-----BEGIN CERTIFICATE-----"));
377
+        assert!(key_pem.contains("-----BEGIN PRIVATE KEY-----"));
378
+    }
379
+
380
+    #[test]
381
+    fn test_ensure_certificate() {
382
+        let dir = tempdir().unwrap();
383
+        let cert_path = dir.path().join("cert.pem");
384
+        let key_path = dir.path().join("key.pem");
385
+
386
+        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();
387
+
388
+        assert!(cert_path.exists());
389
+        assert!(key_path.exists());
390
+
391
+        // Second call should not regenerate
392
+        let cert_content = fs::read_to_string(&cert_path).unwrap();
393
+        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();
394
+        let cert_content2 = fs::read_to_string(&cert_path).unwrap();
395
+        assert_eq!(cert_content, cert_content2);
396
+    }
397
+
398
+    #[test]
399
+    fn test_fingerprint_calculation() {
400
+        let dir = tempdir().unwrap();
401
+        let cert_path = dir.path().join("cert.pem");
402
+        let key_path = dir.path().join("key.pem");
403
+
404
+        ensure_certificate(&cert_path, &key_path, "test-machine").unwrap();
405
+
406
+        let fp = get_cert_fingerprint(&cert_path).unwrap();
407
+        assert_eq!(fp.to_hex().len(), 95); // 32 bytes * 2 + 31 colons
408
+    }
409
+}
hyprkvm-daemon/src/network/transport.rsmodified
@@ -1,49 +1,151 @@
1
-//! TCP transport layer for HyprKVM
1
+//! TCP/TLS transport layer for HyprKVM
22
 //!
3
-//! Provides basic message framing and transport over TCP.
3
+//! Provides message framing and transport over TCP with optional TLS encryption.
44
 
55
 #![allow(dead_code)]
66
 
77
 use std::io;
88
 use std::net::SocketAddr;
9
+use std::path::Path;
10
+use std::pin::Pin;
11
+use std::sync::Arc;
12
+use std::task::{Context, Poll};
913
 
1014
 use bytes::{Buf, BufMut, BytesMut};
11
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
15
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
1216
 use tokio::net::{TcpListener, TcpStream};
17
+use tokio_rustls::client::TlsStream as ClientTlsStream;
18
+use tokio_rustls::server::TlsStream as ServerTlsStream;
19
+use tokio_rustls::{TlsAcceptor, TlsConnector};
1320
 
1421
 use hyprkvm_common::protocol::Message;
1522
 
23
+use super::tls::{self, Fingerprint};
24
+
1625
 /// Maximum message size (1MB)
1726
 const MAX_MESSAGE_SIZE: u32 = 1024 * 1024;
1827
 
1928
 /// Frame header size (4 bytes for length)
2029
 const FRAME_HEADER_SIZE: usize = 4;
2130
 
31
+/// Stream type - either plain TCP or TLS-wrapped
32
+pub enum Stream {
33
+    Plain(TcpStream),
34
+    TlsClient(ClientTlsStream<TcpStream>),
35
+    TlsServer(ServerTlsStream<TcpStream>),
36
+}
37
+
38
+impl AsyncRead for Stream {
39
+    fn poll_read(
40
+        self: Pin<&mut Self>,
41
+        cx: &mut Context<'_>,
42
+        buf: &mut ReadBuf<'_>,
43
+    ) -> Poll<io::Result<()>> {
44
+        match self.get_mut() {
45
+            Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
46
+            Stream::TlsClient(s) => Pin::new(s).poll_read(cx, buf),
47
+            Stream::TlsServer(s) => Pin::new(s).poll_read(cx, buf),
48
+        }
49
+    }
50
+}
51
+
52
+impl AsyncWrite for Stream {
53
+    fn poll_write(
54
+        self: Pin<&mut Self>,
55
+        cx: &mut Context<'_>,
56
+        buf: &[u8],
57
+    ) -> Poll<io::Result<usize>> {
58
+        match self.get_mut() {
59
+            Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
60
+            Stream::TlsClient(s) => Pin::new(s).poll_write(cx, buf),
61
+            Stream::TlsServer(s) => Pin::new(s).poll_write(cx, buf),
62
+        }
63
+    }
64
+
65
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
66
+        match self.get_mut() {
67
+            Stream::Plain(s) => Pin::new(s).poll_flush(cx),
68
+            Stream::TlsClient(s) => Pin::new(s).poll_flush(cx),
69
+            Stream::TlsServer(s) => Pin::new(s).poll_flush(cx),
70
+        }
71
+    }
72
+
73
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74
+        match self.get_mut() {
75
+            Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
76
+            Stream::TlsClient(s) => Pin::new(s).poll_shutdown(cx),
77
+            Stream::TlsServer(s) => Pin::new(s).poll_shutdown(cx),
78
+        }
79
+    }
80
+}
81
+
2282
 /// A framed connection that can send and receive Messages
2383
 pub struct FramedConnection {
24
-    stream: TcpStream,
84
+    stream: Stream,
2585
     read_buf: BytesMut,
2686
     write_buf: BytesMut,
2787
     remote_addr: SocketAddr,
88
+    /// Peer certificate fingerprint (if TLS)
89
+    peer_fingerprint: Option<Fingerprint>,
2890
 }
2991
 
3092
 impl FramedConnection {
31
-    /// Create a new framed connection from a TcpStream
93
+    /// Create a new framed connection from a plain TcpStream
3294
     pub fn new(stream: TcpStream) -> io::Result<Self> {
3395
         let remote_addr = stream.peer_addr()?;
3496
         Ok(Self {
35
-            stream,
97
+            stream: Stream::Plain(stream),
3698
             read_buf: BytesMut::with_capacity(8192),
3799
             write_buf: BytesMut::with_capacity(8192),
38100
             remote_addr,
101
+            peer_fingerprint: None,
39102
         })
40103
     }
41104
 
105
+    /// Create a new framed connection from a client TLS stream
106
+    pub fn from_tls_client(
107
+        stream: ClientTlsStream<TcpStream>,
108
+        remote_addr: SocketAddr,
109
+        peer_fingerprint: Option<Fingerprint>,
110
+    ) -> Self {
111
+        Self {
112
+            stream: Stream::TlsClient(stream),
113
+            read_buf: BytesMut::with_capacity(8192),
114
+            write_buf: BytesMut::with_capacity(8192),
115
+            remote_addr,
116
+            peer_fingerprint,
117
+        }
118
+    }
119
+
120
+    /// Create a new framed connection from a server TLS stream
121
+    pub fn from_tls_server(
122
+        stream: ServerTlsStream<TcpStream>,
123
+        remote_addr: SocketAddr,
124
+    ) -> Self {
125
+        Self {
126
+            stream: Stream::TlsServer(stream),
127
+            read_buf: BytesMut::with_capacity(8192),
128
+            write_buf: BytesMut::with_capacity(8192),
129
+            remote_addr,
130
+            peer_fingerprint: None,
131
+        }
132
+    }
133
+
42134
     /// Get the remote address
43135
     pub fn remote_addr(&self) -> SocketAddr {
44136
         self.remote_addr
45137
     }
46138
 
139
+    /// Get the peer's certificate fingerprint (if TLS connection)
140
+    pub fn peer_fingerprint(&self) -> Option<&Fingerprint> {
141
+        self.peer_fingerprint.as_ref()
142
+    }
143
+
144
+    /// Check if this is a TLS connection
145
+    pub fn is_tls(&self) -> bool {
146
+        !matches!(self.stream, Stream::Plain(_))
147
+    }
148
+
47149
     /// Send a message
48150
     pub async fn send(&mut self, msg: &Message) -> Result<(), TransportError> {
49151
         // Serialize the message
@@ -137,22 +239,44 @@ impl FramedConnection {
137239
     }
138240
 }
139241
 
140
-/// TCP server that accepts connections
242
+/// TLS-enabled TCP server that accepts connections
141243
 pub struct Server {
142244
     listener: TcpListener,
143245
     local_addr: SocketAddr,
246
+    tls_acceptor: Option<TlsAcceptor>,
144247
 }
145248
 
146249
 impl Server {
147
-    /// Bind to an address and start listening
250
+    /// Bind to an address and start listening (plain TCP)
148251
     pub async fn bind(addr: SocketAddr) -> Result<Self, TransportError> {
149252
         let listener = TcpListener::bind(addr).await.map_err(TransportError::Io)?;
150253
         let local_addr = listener.local_addr().map_err(TransportError::Io)?;
151254
 
152
-        tracing::info!("Server listening on {}", local_addr);
255
+        tracing::info!("Server listening on {} (plain TCP)", local_addr);
153256
         Ok(Self {
154257
             listener,
155258
             local_addr,
259
+            tls_acceptor: None,
260
+        })
261
+    }
262
+
263
+    /// Bind to an address with TLS
264
+    pub async fn bind_tls(
265
+        addr: SocketAddr,
266
+        cert_path: &Path,
267
+        key_path: &Path,
268
+    ) -> Result<Self, TransportError> {
269
+        let listener = TcpListener::bind(addr).await.map_err(TransportError::Io)?;
270
+        let local_addr = listener.local_addr().map_err(TransportError::Io)?;
271
+
272
+        let acceptor = tls::create_tls_acceptor(cert_path, key_path)
273
+            .map_err(|e| TransportError::Tls(e.to_string()))?;
274
+
275
+        tracing::info!("Server listening on {} (TLS)", local_addr);
276
+        Ok(Self {
277
+            listener,
278
+            local_addr,
279
+            tls_acceptor: Some(acceptor),
156280
         })
157281
     }
158282
 
@@ -161,19 +285,36 @@ impl Server {
161285
         self.local_addr
162286
     }
163287
 
288
+    /// Check if TLS is enabled
289
+    pub fn is_tls(&self) -> bool {
290
+        self.tls_acceptor.is_some()
291
+    }
292
+
164293
     /// Accept a new connection
165294
     pub async fn accept(&self) -> Result<FramedConnection, TransportError> {
166295
         let (stream, addr) = self.listener.accept().await.map_err(TransportError::Io)?;
167296
         // Disable Nagle's algorithm for low-latency input forwarding
168297
         stream.set_nodelay(true).map_err(TransportError::Io)?;
169
-        tracing::info!("Accepted connection from {}", addr);
170
-        FramedConnection::new(stream).map_err(TransportError::Io)
298
+
299
+        if let Some(ref acceptor) = self.tls_acceptor {
300
+            tracing::debug!("Performing TLS handshake with {}", addr);
301
+            let tls_stream = acceptor
302
+                .accept(stream)
303
+                .await
304
+                .map_err(|e| TransportError::Tls(format!("TLS handshake failed: {}", e)))?;
305
+
306
+            tracing::info!("Accepted TLS connection from {}", addr);
307
+            Ok(FramedConnection::from_tls_server(tls_stream, addr))
308
+        } else {
309
+            tracing::info!("Accepted connection from {}", addr);
310
+            FramedConnection::new(stream).map_err(TransportError::Io)
311
+        }
171312
     }
172313
 }
173314
 
174
-/// Connect to a remote server
315
+/// Connect to a remote server (plain TCP)
175316
 pub async fn connect(addr: SocketAddr) -> Result<FramedConnection, TransportError> {
176
-    tracing::info!("Connecting to {}", addr);
317
+    tracing::info!("Connecting to {} (plain TCP)", addr);
177318
     let stream = TcpStream::connect(addr).await.map_err(TransportError::Io)?;
178319
     // Disable Nagle's algorithm for low-latency input forwarding
179320
     stream.set_nodelay(true).map_err(TransportError::Io)?;
@@ -181,11 +322,55 @@ pub async fn connect(addr: SocketAddr) -> Result<FramedConnection, TransportErro
181322
     FramedConnection::new(stream).map_err(TransportError::Io)
182323
 }
183324
 
325
+/// Connect to a remote server with TLS
326
+pub async fn connect_tls(
327
+    addr: SocketAddr,
328
+    server_name: &str,
329
+    expected_fingerprint: Option<&str>,
330
+    tofu_enabled: bool,
331
+) -> Result<FramedConnection, TransportError> {
332
+    tracing::info!("Connecting to {} (TLS, server_name={})", addr, server_name);
333
+
334
+    let stream = TcpStream::connect(addr).await.map_err(TransportError::Io)?;
335
+    stream.set_nodelay(true).map_err(TransportError::Io)?;
336
+
337
+    let connector = tls::create_tls_connector(expected_fingerprint, tofu_enabled)
338
+        .map_err(|e| TransportError::Tls(e.to_string()))?;
339
+
340
+    let server_name = rustls::pki_types::ServerName::try_from(server_name.to_string())
341
+        .map_err(|e| TransportError::Tls(format!("Invalid server name: {}", e)))?;
342
+
343
+    tracing::debug!("Performing TLS handshake with {}", addr);
344
+    let tls_stream = connector
345
+        .connect(server_name, stream)
346
+        .await
347
+        .map_err(|e| TransportError::Tls(format!("TLS handshake failed: {}", e)))?;
348
+
349
+    // Extract peer certificate fingerprint
350
+    let peer_fingerprint = tls_stream
351
+        .get_ref()
352
+        .1
353
+        .peer_certificates()
354
+        .and_then(|certs| certs.first())
355
+        .map(|cert| Fingerprint::from_der(cert.as_ref()));
356
+
357
+    if let Some(ref fp) = peer_fingerprint {
358
+        tracing::info!("Connected to {} (TLS), peer fingerprint: {}", addr, fp);
359
+    } else {
360
+        tracing::info!("Connected to {} (TLS)", addr);
361
+    }
362
+
363
+    Ok(FramedConnection::from_tls_client(tls_stream, addr, peer_fingerprint))
364
+}
365
+
184366
 #[derive(Debug, thiserror::Error)]
185367
 pub enum TransportError {
186368
     #[error("IO error: {0}")]
187369
     Io(#[from] io::Error),
188370
 
371
+    #[error("TLS error: {0}")]
372
+    Tls(String),
373
+
189374
     #[error("Failed to serialize message: {0}")]
190375
     Serialize(String),
191376