tenseleyflow/hyprkvm / 34378cc

Browse files

feat: integrate TLS with backwards compatibility

- Add tls.enabled config (default: false) for opt-in encryption
- Add per-neighbor tls override for mixed TLS/non-TLS environments
- Generate certificates on startup when TLS enabled
- Use TLS for server binding and outgoing connections based on config
- Print certificate fingerprint at startup for TOFU verification
- Clean up unused import warnings in network module
Authored by mfwolffe <wolffemf@dukes.jmu.edu>
SHA
34378cc01370f9935c07fdfb3ffa94c76235c008
Parents
ce63a95
Tree
c874474

5 changed files

StatusFile+-
M hyprkvm-daemon/src/config/mod.rs 16 1
M hyprkvm-daemon/src/main.rs 69 5
M hyprkvm-daemon/src/network/mod.rs 2 0
M hyprkvm-daemon/src/network/tls.rs 2 0
M hyprkvm-daemon/src/network/transport.rs 1 2
hyprkvm-daemon/src/config/mod.rsmodified
@@ -138,6 +138,10 @@ impl Default for NetworkConfig {
138
 
138
 
139
 #[derive(Debug, Clone, Serialize, Deserialize)]
139
 #[derive(Debug, Clone, Serialize, Deserialize)]
140
 pub struct TlsConfig {
140
 pub struct TlsConfig {
141
+    /// Enable TLS encryption (default: false for backwards compatibility)
142
+    #[serde(default)]
143
+    pub enabled: bool,
144
+
141
     /// Path to certificate file
145
     /// Path to certificate file
142
     #[serde(default = "default_cert_path")]
146
     #[serde(default = "default_cert_path")]
143
     pub cert_path: String,
147
     pub cert_path: String,
@@ -145,6 +149,10 @@ pub struct TlsConfig {
145
     /// Path to private key file
149
     /// Path to private key file
146
     #[serde(default = "default_key_path")]
150
     #[serde(default = "default_key_path")]
147
     pub key_path: String,
151
     pub key_path: String,
152
+
153
+    /// Enable TOFU (Trust On First Use) for unknown peers
154
+    #[serde(default = "default_true")]
155
+    pub tofu: bool,
148
 }
156
 }
149
 
157
 
150
 fn default_cert_path() -> String {
158
 fn default_cert_path() -> String {
@@ -166,8 +174,10 @@ fn default_key_path() -> String {
166
 impl Default for TlsConfig {
174
 impl Default for TlsConfig {
167
     fn default() -> Self {
175
     fn default() -> Self {
168
         Self {
176
         Self {
177
+            enabled: false, // Default to false for backwards compatibility
169
             cert_path: default_cert_path(),
178
             cert_path: default_cert_path(),
170
             key_path: default_key_path(),
179
             key_path: default_key_path(),
180
+            tofu: true,
171
         }
181
         }
172
     }
182
     }
173
 }
183
 }
@@ -209,8 +219,13 @@ pub struct NeighborConfig {
209
     /// Address (ip:port)
219
     /// Address (ip:port)
210
     pub address: SocketAddr,
220
     pub address: SocketAddr,
211
 
221
 
212
-    /// Pre-trusted certificate fingerprint (optional)
222
+    /// Pre-trusted certificate fingerprint (optional, for TLS pinning)
223
+    #[serde(default, skip_serializing_if = "Option::is_none")]
213
     pub fingerprint: Option<String>,
224
     pub fingerprint: Option<String>,
225
+
226
+    /// Override TLS setting for this neighbor (uses global setting if None)
227
+    #[serde(default, skip_serializing_if = "Option::is_none")]
228
+    pub tls: Option<bool>,
214
 }
229
 }
215
 
230
 
216
 #[derive(Debug, Clone, Serialize, Deserialize)]
231
 #[derive(Debug, Clone, Serialize, Deserialize)]
hyprkvm-daemon/src/main.rsmodified
@@ -275,10 +275,41 @@ async fn run_daemon(config_path: &std::path::Path) -> anyhow::Result<()> {
275
     let peers: Arc<RwLock<HashMap<Direction, network::FramedConnection>>> =
275
     let peers: Arc<RwLock<HashMap<Direction, network::FramedConnection>>> =
276
         Arc::new(RwLock::new(HashMap::new()));
276
         Arc::new(RwLock::new(HashMap::new()));
277
 
277
 
278
+    // TLS setup (if enabled)
279
+    let tls_enabled = config.network.tls.enabled;
280
+    if tls_enabled {
281
+        info!("TLS is enabled, ensuring certificates exist...");
282
+        let cert_path = std::path::Path::new(&config.network.tls.cert_path);
283
+        let key_path = std::path::Path::new(&config.network.tls.key_path);
284
+
285
+        if let Err(e) = network::ensure_certificate(cert_path, key_path, &config.machines.self_name) {
286
+            anyhow::bail!("Failed to setup TLS certificates: {}", e);
287
+        }
288
+
289
+        // Print certificate fingerprint for users to share
290
+        match network::get_cert_fingerprint(cert_path) {
291
+            Ok(fp) => {
292
+                info!("Certificate fingerprint: {}", fp);
293
+                info!("Share this fingerprint with peers for secure verification");
294
+            }
295
+            Err(e) => {
296
+                tracing::warn!("Could not read certificate fingerprint: {}", e);
297
+            }
298
+        }
299
+    } else {
300
+        info!("TLS is disabled (plain TCP mode for backwards compatibility)");
301
+    }
302
+
278
     // Start network server
303
     // Start network server
279
     let listen_addr: SocketAddr = format!("0.0.0.0:{}", config.network.listen_port).parse()?;
304
     let listen_addr: SocketAddr = format!("0.0.0.0:{}", config.network.listen_port).parse()?;
280
-    let server = network::Server::bind(listen_addr).await?;
305
+    let server = if tls_enabled {
281
-    info!("Listening for connections on {}", server.local_addr());
306
+        let cert_path = std::path::Path::new(&config.network.tls.cert_path);
307
+        let key_path = std::path::Path::new(&config.network.tls.key_path);
308
+        network::Server::bind_tls(listen_addr, cert_path, key_path).await?
309
+    } else {
310
+        network::Server::bind(listen_addr).await?
311
+    };
312
+    info!("Listening for connections on {} (TLS: {})", server.local_addr(), tls_enabled);
282
 
313
 
283
     // Spawn task to accept incoming connections
314
     // Spawn task to accept incoming connections
284
     let machine_name = config.machines.self_name.clone();
315
     let machine_name = config.machines.self_name.clone();
@@ -355,6 +386,13 @@ async fn run_daemon(config_path: &std::path::Path) -> anyhow::Result<()> {
355
         let direction = neighbor.direction;
386
         let direction = neighbor.direction;
356
         let peers_clone = peers.clone();
387
         let peers_clone = peers.clone();
357
         let machine_name = config.machines.self_name.clone();
388
         let machine_name = config.machines.self_name.clone();
389
+        let neighbor_name = neighbor.name.clone();
390
+
391
+        // Determine if TLS should be used for this neighbor
392
+        // Per-neighbor override takes precedence over global setting
393
+        let use_tls = neighbor.tls.unwrap_or(tls_enabled);
394
+        let fingerprint = neighbor.fingerprint.clone();
395
+        let tofu_enabled = config.network.tls.tofu;
358
 
396
 
359
         tokio::spawn(async move {
397
         tokio::spawn(async move {
360
             loop {
398
             loop {
@@ -369,8 +407,21 @@ async fn run_daemon(config_path: &std::path::Path) -> anyhow::Result<()> {
369
                     }
407
                     }
370
                 }
408
                 }
371
 
409
 
372
-                tracing::debug!("Connecting to {} at {}...", direction, addr);
410
+                tracing::debug!("Connecting to {} at {} (TLS: {})...", direction, addr, use_tls);
373
-                match network::connect(addr).await {
411
+
412
+                // Connect with or without TLS
413
+                let conn_result = if use_tls {
414
+                    network::connect_tls(
415
+                        addr,
416
+                        &neighbor_name,
417
+                        fingerprint.as_deref(),
418
+                        tofu_enabled,
419
+                    ).await
420
+                } else {
421
+                    network::connect(addr).await
422
+                };
423
+
424
+                match conn_result {
374
                     Ok(mut conn) => {
425
                     Ok(mut conn) => {
375
                         // Send Hello
426
                         // Send Hello
376
                         let hello = Message::Hello(HelloPayload {
427
                         let hello = Message::Hello(HelloPayload {
@@ -1695,8 +1746,21 @@ async fn run_daemon(config_path: &std::path::Path) -> anyhow::Result<()> {
1695
                                 let peers_clone = peers.clone();
1746
                                 let peers_clone = peers.clone();
1696
                                 let machine_name = config.machines.self_name.clone();
1747
                                 let machine_name = config.machines.self_name.clone();
1697
                                 let neighbor_name = n.name.clone();
1748
                                 let neighbor_name = n.name.clone();
1749
+
1750
+                                // Determine TLS settings for this neighbor
1751
+                                let use_tls = n.tls.unwrap_or(tls_enabled);
1752
+                                let fingerprint = n.fingerprint.clone();
1753
+                                let tofu = config.network.tls.tofu;
1754
+
1698
                                 tokio::spawn(async move {
1755
                                 tokio::spawn(async move {
1699
-                                    match network::connect(addr).await {
1756
+                                    // Connect with or without TLS
1757
+                                    let conn_result = if use_tls {
1758
+                                        network::connect_tls(addr, &neighbor_name, fingerprint.as_deref(), tofu).await
1759
+                                    } else {
1760
+                                        network::connect(addr).await
1761
+                                    };
1762
+
1763
+                                    match conn_result {
1700
                                         Ok(mut conn) => {
1764
                                         Ok(mut conn) => {
1701
                                             // Send Hello
1765
                                             // Send Hello
1702
                                             let hello = Message::Hello(HelloPayload {
1766
                                             let hello = Message::Hello(HelloPayload {
hyprkvm-daemon/src/network/mod.rsmodified
@@ -2,6 +2,8 @@
2
 //!
2
 //!
3
 //! Handles peer-to-peer connections between HyprKVM instances.
3
 //! Handles peer-to-peer connections between HyprKVM instances.
4
 
4
 
5
+#![allow(unused_imports)]
6
+
5
 #[allow(dead_code)]
7
 #[allow(dead_code)]
6
 pub mod known_hosts;
8
 pub mod known_hosts;
7
 #[allow(dead_code)]
9
 #[allow(dead_code)]
hyprkvm-daemon/src/network/tls.rsmodified
@@ -234,6 +234,7 @@ pub struct HyprkvmCertVerifier {
234
 
234
 
235
 impl HyprkvmCertVerifier {
235
 impl HyprkvmCertVerifier {
236
     /// Create verifier with pinned fingerprint
236
     /// Create verifier with pinned fingerprint
237
+    #[allow(dead_code)]
237
     pub fn with_fingerprint(fingerprint: Fingerprint) -> Self {
238
     pub fn with_fingerprint(fingerprint: Fingerprint) -> Self {
238
         Self {
239
         Self {
239
             expected_fingerprint: Some(fingerprint),
240
             expected_fingerprint: Some(fingerprint),
@@ -242,6 +243,7 @@ impl HyprkvmCertVerifier {
242
     }
243
     }
243
 
244
 
244
     /// Create verifier with TOFU enabled
245
     /// Create verifier with TOFU enabled
246
+    #[allow(dead_code)]
245
     pub fn with_tofu() -> Self {
247
     pub fn with_tofu() -> Self {
246
         Self {
248
         Self {
247
             expected_fingerprint: None,
249
             expected_fingerprint: None,
hyprkvm-daemon/src/network/transport.rsmodified
@@ -8,7 +8,6 @@ use std::io;
8
 use std::net::SocketAddr;
8
 use std::net::SocketAddr;
9
 use std::path::Path;
9
 use std::path::Path;
10
 use std::pin::Pin;
10
 use std::pin::Pin;
11
-use std::sync::Arc;
12
 use std::task::{Context, Poll};
11
 use std::task::{Context, Poll};
13
 
12
 
14
 use bytes::{Buf, BufMut, BytesMut};
13
 use bytes::{Buf, BufMut, BytesMut};
@@ -16,7 +15,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
16
 use tokio::net::{TcpListener, TcpStream};
15
 use tokio::net::{TcpListener, TcpStream};
17
 use tokio_rustls::client::TlsStream as ClientTlsStream;
16
 use tokio_rustls::client::TlsStream as ClientTlsStream;
18
 use tokio_rustls::server::TlsStream as ServerTlsStream;
17
 use tokio_rustls::server::TlsStream as ServerTlsStream;
19
-use tokio_rustls::{TlsAcceptor, TlsConnector};
18
+use tokio_rustls::TlsAcceptor;
20
 
19
 
21
 use hyprkvm_common::protocol::Message;
20
 use hyprkvm_common::protocol::Message;
22
 
21