Rust · 13626 bytes Raw Blame History
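//! TCP/TLS transport layer for HyprKVM
//!
//! Provides message framing and transport over TCP with optional TLS encryption.
//! Each message is sent as one frame: a 4-byte big-endian payload length followed
//! by the JSON-encoded message, with payloads capped at 1 MiB.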
4
5 #![allow(dead_code)]
6
7 use std::io;
8 use std::net::SocketAddr;
9 use std::path::Path;
10 use std::pin::Pin;
11 use std::task::{Context, Poll};
12
13 use bytes::{Buf, BufMut, BytesMut};
14 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
15 use tokio::net::{TcpListener, TcpStream};
16 use tokio_rustls::client::TlsStream as ClientTlsStream;
17 use tokio_rustls::server::TlsStream as ServerTlsStream;
18 use tokio_rustls::TlsAcceptor;
19
20 use hyprkvm_common::protocol::Message;
21
22 use super::tls::{self, Fingerprint};
23
/// Largest JSON payload, in bytes, accepted in either direction (1 MiB).
const MAX_MESSAGE_SIZE: u32 = 1 << 20;

/// Width of the frame header: a single big-endian `u32` payload length.
const FRAME_HEADER_SIZE: usize = std::mem::size_of::<u32>();
29
30 /// Stream type - either plain TCP or TLS-wrapped
31 pub enum Stream {
32 Plain(TcpStream),
33 TlsClient(ClientTlsStream<TcpStream>),
34 TlsServer(ServerTlsStream<TcpStream>),
35 }
36
37 impl AsyncRead for Stream {
38 fn poll_read(
39 self: Pin<&mut Self>,
40 cx: &mut Context<'_>,
41 buf: &mut ReadBuf<'_>,
42 ) -> Poll<io::Result<()>> {
43 match self.get_mut() {
44 Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
45 Stream::TlsClient(s) => Pin::new(s).poll_read(cx, buf),
46 Stream::TlsServer(s) => Pin::new(s).poll_read(cx, buf),
47 }
48 }
49 }
50
51 impl AsyncWrite for Stream {
52 fn poll_write(
53 self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 buf: &[u8],
56 ) -> Poll<io::Result<usize>> {
57 match self.get_mut() {
58 Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
59 Stream::TlsClient(s) => Pin::new(s).poll_write(cx, buf),
60 Stream::TlsServer(s) => Pin::new(s).poll_write(cx, buf),
61 }
62 }
63
64 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
65 match self.get_mut() {
66 Stream::Plain(s) => Pin::new(s).poll_flush(cx),
67 Stream::TlsClient(s) => Pin::new(s).poll_flush(cx),
68 Stream::TlsServer(s) => Pin::new(s).poll_flush(cx),
69 }
70 }
71
72 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
73 match self.get_mut() {
74 Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
75 Stream::TlsClient(s) => Pin::new(s).poll_shutdown(cx),
76 Stream::TlsServer(s) => Pin::new(s).poll_shutdown(cx),
77 }
78 }
79 }
80
81 /// A framed connection that can send and receive Messages
82 pub struct FramedConnection {
83 stream: Stream,
84 read_buf: BytesMut,
85 write_buf: BytesMut,
86 remote_addr: SocketAddr,
87 /// Peer certificate fingerprint (if TLS)
88 peer_fingerprint: Option<Fingerprint>,
89 }
90
91 impl FramedConnection {
92 /// Create a new framed connection from a plain TcpStream
93 pub fn new(stream: TcpStream) -> io::Result<Self> {
94 let remote_addr = stream.peer_addr()?;
95 Ok(Self {
96 stream: Stream::Plain(stream),
97 read_buf: BytesMut::with_capacity(8192),
98 write_buf: BytesMut::with_capacity(8192),
99 remote_addr,
100 peer_fingerprint: None,
101 })
102 }
103
104 /// Create a new framed connection from a client TLS stream
105 pub fn from_tls_client(
106 stream: ClientTlsStream<TcpStream>,
107 remote_addr: SocketAddr,
108 peer_fingerprint: Option<Fingerprint>,
109 ) -> Self {
110 Self {
111 stream: Stream::TlsClient(stream),
112 read_buf: BytesMut::with_capacity(8192),
113 write_buf: BytesMut::with_capacity(8192),
114 remote_addr,
115 peer_fingerprint,
116 }
117 }
118
119 /// Create a new framed connection from a server TLS stream
120 pub fn from_tls_server(
121 stream: ServerTlsStream<TcpStream>,
122 remote_addr: SocketAddr,
123 ) -> Self {
124 Self {
125 stream: Stream::TlsServer(stream),
126 read_buf: BytesMut::with_capacity(8192),
127 write_buf: BytesMut::with_capacity(8192),
128 remote_addr,
129 peer_fingerprint: None,
130 }
131 }
132
133 /// Get the remote address
134 pub fn remote_addr(&self) -> SocketAddr {
135 self.remote_addr
136 }
137
138 /// Get the peer's certificate fingerprint (if TLS connection)
139 pub fn peer_fingerprint(&self) -> Option<&Fingerprint> {
140 self.peer_fingerprint.as_ref()
141 }
142
143 /// Check if this is a TLS connection
144 pub fn is_tls(&self) -> bool {
145 !matches!(self.stream, Stream::Plain(_))
146 }
147
148 /// Send a message
149 pub async fn send(&mut self, msg: &Message) -> Result<(), TransportError> {
150 // Serialize the message
151 let json = serde_json::to_vec(msg)
152 .map_err(|e| TransportError::Serialize(e.to_string()))?;
153
154 if json.len() > MAX_MESSAGE_SIZE as usize {
155 return Err(TransportError::MessageTooLarge(json.len()));
156 }
157
158 // Write length prefix + data
159 self.write_buf.clear();
160 self.write_buf.put_u32(json.len() as u32);
161 self.write_buf.extend_from_slice(&json);
162
163 self.stream
164 .write_all(&self.write_buf)
165 .await
166 .map_err(TransportError::Io)?;
167
168 self.stream.flush().await.map_err(TransportError::Io)?;
169
170 tracing::trace!("Sent message: {:?}", msg);
171 Ok(())
172 }
173
174 /// Receive a message (blocking until one is available or connection closes)
175 pub async fn recv(&mut self) -> Result<Option<Message>, TransportError> {
176 loop {
177 // Try to parse a complete frame from the buffer
178 if let Some(msg) = self.try_parse_frame()? {
179 return Ok(Some(msg));
180 }
181
182 // Read more data
183 let n = self
184 .stream
185 .read_buf(&mut self.read_buf)
186 .await
187 .map_err(TransportError::Io)?;
188
189 if n == 0 {
190 // Connection closed
191 if self.read_buf.is_empty() {
192 return Ok(None);
193 } else {
194 return Err(TransportError::ConnectionReset);
195 }
196 }
197 }
198 }
199
200 /// Try to parse a complete frame from the read buffer
201 fn try_parse_frame(&mut self) -> Result<Option<Message>, TransportError> {
202 if self.read_buf.len() < FRAME_HEADER_SIZE {
203 return Ok(None);
204 }
205
206 // Peek at the length
207 let len = u32::from_be_bytes([
208 self.read_buf[0],
209 self.read_buf[1],
210 self.read_buf[2],
211 self.read_buf[3],
212 ]) as usize;
213
214 if len > MAX_MESSAGE_SIZE as usize {
215 return Err(TransportError::MessageTooLarge(len));
216 }
217
218 let total_frame_size = FRAME_HEADER_SIZE + len;
219 if self.read_buf.len() < total_frame_size {
220 return Ok(None);
221 }
222
223 // Consume the frame
224 self.read_buf.advance(FRAME_HEADER_SIZE);
225 let json_data = self.read_buf.split_to(len);
226
227 // Deserialize
228 let msg: Message = serde_json::from_slice(&json_data)
229 .map_err(|e| TransportError::Deserialize(e.to_string()))?;
230
231 tracing::trace!("Received message: {:?}", msg);
232 Ok(Some(msg))
233 }
234
235 /// Shutdown the connection gracefully
236 pub async fn shutdown(&mut self) -> io::Result<()> {
237 self.stream.shutdown().await
238 }
239 }
240
241 /// TLS-enabled TCP server that accepts connections
242 pub struct Server {
243 listener: TcpListener,
244 local_addr: SocketAddr,
245 tls_acceptor: Option<TlsAcceptor>,
246 }
247
248 impl Server {
249 /// Bind to an address and start listening (plain TCP)
250 pub async fn bind(addr: SocketAddr) -> Result<Self, TransportError> {
251 let listener = TcpListener::bind(addr).await.map_err(TransportError::Io)?;
252 let local_addr = listener.local_addr().map_err(TransportError::Io)?;
253
254 tracing::info!("Server listening on {} (plain TCP)", local_addr);
255 Ok(Self {
256 listener,
257 local_addr,
258 tls_acceptor: None,
259 })
260 }
261
262 /// Bind to an address with TLS
263 pub async fn bind_tls(
264 addr: SocketAddr,
265 cert_path: &Path,
266 key_path: &Path,
267 ) -> Result<Self, TransportError> {
268 let listener = TcpListener::bind(addr).await.map_err(TransportError::Io)?;
269 let local_addr = listener.local_addr().map_err(TransportError::Io)?;
270
271 let acceptor = tls::create_tls_acceptor(cert_path, key_path)
272 .map_err(|e| TransportError::Tls(e.to_string()))?;
273
274 tracing::info!("Server listening on {} (TLS)", local_addr);
275 Ok(Self {
276 listener,
277 local_addr,
278 tls_acceptor: Some(acceptor),
279 })
280 }
281
282 /// Get the local address
283 pub fn local_addr(&self) -> SocketAddr {
284 self.local_addr
285 }
286
287 /// Check if TLS is enabled
288 pub fn is_tls(&self) -> bool {
289 self.tls_acceptor.is_some()
290 }
291
292 /// Accept a new connection
293 pub async fn accept(&self) -> Result<FramedConnection, TransportError> {
294 let (stream, addr) = self.listener.accept().await.map_err(TransportError::Io)?;
295 // Disable Nagle's algorithm for low-latency input forwarding
296 stream.set_nodelay(true).map_err(TransportError::Io)?;
297
298 if let Some(ref acceptor) = self.tls_acceptor {
299 tracing::debug!("Performing TLS handshake with {}", addr);
300 let tls_stream = acceptor
301 .accept(stream)
302 .await
303 .map_err(|e| TransportError::Tls(format!("TLS handshake failed: {}", e)))?;
304
305 tracing::info!("Accepted TLS connection from {}", addr);
306 Ok(FramedConnection::from_tls_server(tls_stream, addr))
307 } else {
308 tracing::info!("Accepted connection from {}", addr);
309 FramedConnection::new(stream).map_err(TransportError::Io)
310 }
311 }
312 }
313
314 /// Connect to a remote server (plain TCP)
315 pub async fn connect(addr: SocketAddr) -> Result<FramedConnection, TransportError> {
316 tracing::info!("Connecting to {} (plain TCP)", addr);
317 let stream = TcpStream::connect(addr).await.map_err(TransportError::Io)?;
318 // Disable Nagle's algorithm for low-latency input forwarding
319 stream.set_nodelay(true).map_err(TransportError::Io)?;
320 tracing::info!("Connected to {}", addr);
321 FramedConnection::new(stream).map_err(TransportError::Io)
322 }
323
324 /// Connect to a remote server with TLS
325 pub async fn connect_tls(
326 addr: SocketAddr,
327 server_name: &str,
328 expected_fingerprint: Option<&str>,
329 tofu_enabled: bool,
330 ) -> Result<FramedConnection, TransportError> {
331 tracing::info!("Connecting to {} (TLS, server_name={})", addr, server_name);
332
333 let stream = TcpStream::connect(addr).await.map_err(TransportError::Io)?;
334 stream.set_nodelay(true).map_err(TransportError::Io)?;
335
336 let connector = tls::create_tls_connector(expected_fingerprint, tofu_enabled)
337 .map_err(|e| TransportError::Tls(e.to_string()))?;
338
339 let server_name = rustls::pki_types::ServerName::try_from(server_name.to_string())
340 .map_err(|e| TransportError::Tls(format!("Invalid server name: {}", e)))?;
341
342 tracing::debug!("Performing TLS handshake with {}", addr);
343 let tls_stream = connector
344 .connect(server_name, stream)
345 .await
346 .map_err(|e| TransportError::Tls(format!("TLS handshake failed: {}", e)))?;
347
348 // Extract peer certificate fingerprint
349 let peer_fingerprint = tls_stream
350 .get_ref()
351 .1
352 .peer_certificates()
353 .and_then(|certs| certs.first())
354 .map(|cert| Fingerprint::from_der(cert.as_ref()));
355
356 if let Some(ref fp) = peer_fingerprint {
357 tracing::info!("Connected to {} (TLS), peer fingerprint: {}", addr, fp);
358 } else {
359 tracing::info!("Connected to {} (TLS)", addr);
360 }
361
362 Ok(FramedConnection::from_tls_client(tls_stream, addr, peer_fingerprint))
363 }
364
365 #[derive(Debug, thiserror::Error)]
366 pub enum TransportError {
367 #[error("IO error: {0}")]
368 Io(#[from] io::Error),
369
370 #[error("TLS error: {0}")]
371 Tls(String),
372
373 #[error("Failed to serialize message: {0}")]
374 Serialize(String),
375
376 #[error("Failed to deserialize message: {0}")]
377 Deserialize(String),
378
379 #[error("Message too large: {0} bytes")]
380 MessageTooLarge(usize),
381
382 #[error("Connection reset while reading")]
383 ConnectionReset,
384 }
385
#[cfg(test)]
mod tests {
    use super::*;
    use hyprkvm_common::protocol::{HelloPayload, PROTOCOL_VERSION};
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpStream;
    use tokio::task::JoinHandle;

    /// Build a `Hello` message carrying `machine_name`.
    fn hello(machine_name: String) -> Message {
        Message::Hello(HelloPayload {
            protocol_version: PROTOCOL_VERSION,
            machine_name,
            capabilities: vec![],
            my_direction_for_you: None,
        })
    }

    /// Bind a plain-TCP server on an ephemeral port and spawn a task that
    /// accepts one connection and returns the result of its first `recv()`.
    async fn spawn_single_recv() -> (
        SocketAddr,
        JoinHandle<Result<Option<Message>, TransportError>>,
    ) {
        let server = Server::bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let addr = server.local_addr();
        let handle = tokio::spawn(async move {
            let mut conn = server.accept().await.unwrap();
            conn.recv().await
        });
        (addr, handle)
    }

    #[tokio::test]
    async fn test_roundtrip() {
        let server = Server::bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let addr = server.local_addr();

        let server_handle = tokio::spawn(async move {
            let mut conn = server.accept().await.unwrap();
            let msg = conn.recv().await.unwrap().unwrap();
            conn.send(&msg).await.unwrap();
            conn.shutdown().await.unwrap();
        });

        let mut client = connect(addr).await.unwrap();
        let msg = hello("test".to_string());

        client.send(&msg).await.unwrap();
        let echo = client.recv().await.unwrap().unwrap();

        if let (Message::Hello(sent), Message::Hello(received)) = (&msg, &echo) {
            assert_eq!(sent.protocol_version, received.protocol_version);
            assert_eq!(sent.machine_name, received.machine_name);
        } else {
            panic!("Wrong message type");
        }

        server_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_clean_close_yields_none() {
        let (addr, server_handle) = spawn_single_recv().await;

        // Half-close the write side: the server sees EOF on a frame boundary.
        let mut client = connect(addr).await.unwrap();
        client.shutdown().await.unwrap();

        let result = server_handle.await.unwrap();
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_truncated_frame_is_connection_reset() {
        let (addr, server_handle) = spawn_single_recv().await;

        let mut raw = TcpStream::connect(addr).await.unwrap();
        // Header announces a 10-byte payload, but only 2 bytes follow before EOF.
        raw.write_all(&[0, 0, 0, 10, b'{', b'"']).await.unwrap();
        raw.shutdown().await.unwrap();

        let result = server_handle.await.unwrap();
        assert!(matches!(result, Err(TransportError::ConnectionReset)));
    }

    #[tokio::test]
    async fn test_oversized_frame_header_is_rejected() {
        let (addr, server_handle) = spawn_single_recv().await;

        let mut raw = TcpStream::connect(addr).await.unwrap();
        raw.write_all(&(MAX_MESSAGE_SIZE + 1).to_be_bytes()).await.unwrap();

        let result = server_handle.await.unwrap();
        assert!(matches!(
            result,
            Err(TransportError::MessageTooLarge(n)) if n == MAX_MESSAGE_SIZE as usize + 1
        ));

        // Keep the raw socket open until the server has finished reading.
        drop(raw);
    }

    #[tokio::test]
    async fn test_send_rejects_oversized_message() {
        // The listener stays alive so the kernel completes the TCP handshake.
        let server = Server::bind("127.0.0.1:0".parse().unwrap()).await.unwrap();
        let mut client = connect(server.local_addr()).await.unwrap();

        // The name alone fills the limit, so the JSON envelope pushes it over.
        let msg = hello("x".repeat(MAX_MESSAGE_SIZE as usize));
        let err = client.send(&msg).await.unwrap_err();
        assert!(matches!(
            err,
            TransportError::MessageTooLarge(n) if n > MAX_MESSAGE_SIZE as usize
        ));
    }
}
423 }
424