Rust · 8267 bytes Raw Blame History
1 //! Peer connection management
2 //!
3 //! Handles connections to other HyprKVM instances.
4
5 use std::collections::HashMap;
6 use std::net::SocketAddr;
7 use std::sync::Arc;
8 use std::time::Duration;
9
10 use tokio::sync::{mpsc, RwLock};
11 use tokio::time::timeout;
12
13 use hyprkvm_common::protocol::{
14 HelloAckPayload, HelloPayload, Message, PROTOCOL_VERSION,
15 };
16
17 use super::transport::{connect, FramedConnection, Server, TransportError};
18
19 /// Default timeout for handshake
20 const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);
21
22 /// A connected peer
23 pub struct Peer {
24 /// Connection to the peer
25 pub conn: FramedConnection,
26 /// Peer's machine name
27 pub name: String,
28 /// Peer's protocol version
29 pub protocol_version: u32,
30 /// Peer's capabilities
31 pub capabilities: Vec<String>,
32 }
33
34 impl Peer {
35 /// Perform client-side handshake
36 pub async fn handshake_client(
37 mut conn: FramedConnection,
38 our_name: &str,
39 our_capabilities: &[String],
40 ) -> Result<Self, PeerError> {
41 // Send Hello
42 let hello = Message::Hello(HelloPayload {
43 protocol_version: PROTOCOL_VERSION,
44 machine_name: our_name.to_string(),
45 capabilities: our_capabilities.to_vec(),
46 });
47
48 conn.send(&hello).await?;
49
50 // Wait for HelloAck
51 let response = timeout(HANDSHAKE_TIMEOUT, conn.recv())
52 .await
53 .map_err(|_| PeerError::HandshakeTimeout)?
54 .map_err(PeerError::Transport)?
55 .ok_or(PeerError::ConnectionClosed)?;
56
57 match response {
58 Message::HelloAck(ack) => {
59 if !ack.accepted {
60 return Err(PeerError::HandshakeRejected(
61 ack.error.unwrap_or_else(|| "Unknown reason".to_string()),
62 ));
63 }
64
65 tracing::info!(
66 "Handshake successful with peer '{}' (protocol v{})",
67 ack.machine_name,
68 ack.protocol_version
69 );
70
71 Ok(Self {
72 conn,
73 name: ack.machine_name,
74 protocol_version: ack.protocol_version,
75 capabilities: vec![], // Server doesn't send capabilities in ack
76 })
77 }
78 _ => Err(PeerError::UnexpectedMessage),
79 }
80 }
81
82 /// Perform server-side handshake
83 pub async fn handshake_server(
84 mut conn: FramedConnection,
85 our_name: &str,
86 our_capabilities: &[String],
87 ) -> Result<Self, PeerError> {
88 // Wait for Hello
89 let request = timeout(HANDSHAKE_TIMEOUT, conn.recv())
90 .await
91 .map_err(|_| PeerError::HandshakeTimeout)?
92 .map_err(PeerError::Transport)?
93 .ok_or(PeerError::ConnectionClosed)?;
94
95 match request {
96 Message::Hello(hello) => {
97 // Check protocol version compatibility
98 let accepted = hello.protocol_version == PROTOCOL_VERSION;
99 let error = if !accepted {
100 Some(format!(
101 "Protocol version mismatch: expected {}, got {}",
102 PROTOCOL_VERSION, hello.protocol_version
103 ))
104 } else {
105 None
106 };
107
108 // Send HelloAck
109 let ack = Message::HelloAck(HelloAckPayload {
110 accepted,
111 protocol_version: PROTOCOL_VERSION,
112 machine_name: our_name.to_string(),
113 error: error.clone(),
114 });
115
116 conn.send(&ack).await?;
117
118 if !accepted {
119 return Err(PeerError::HandshakeRejected(error.unwrap()));
120 }
121
122 tracing::info!(
123 "Handshake successful with peer '{}' (protocol v{})",
124 hello.machine_name,
125 hello.protocol_version
126 );
127
128 Ok(Self {
129 conn,
130 name: hello.machine_name,
131 protocol_version: hello.protocol_version,
132 capabilities: hello.capabilities,
133 })
134 }
135 _ => Err(PeerError::UnexpectedMessage),
136 }
137 }
138
139 /// Send a message to the peer
140 pub async fn send(&mut self, msg: &Message) -> Result<(), PeerError> {
141 self.conn.send(msg).await.map_err(PeerError::Transport)
142 }
143
144 /// Receive a message from the peer
145 pub async fn recv(&mut self) -> Result<Option<Message>, PeerError> {
146 self.conn.recv().await.map_err(PeerError::Transport)
147 }
148
149 /// Get the remote address
150 pub fn remote_addr(&self) -> SocketAddr {
151 self.conn.remote_addr()
152 }
153 }
154
155 /// Manages connections to multiple peers
156 pub struct PeerManager {
157 /// Our machine name
158 our_name: String,
159 /// Our capabilities
160 our_capabilities: Vec<String>,
161 /// Connected peers by name
162 peers: Arc<RwLock<HashMap<String, Peer>>>,
163 /// Server for incoming connections
164 server: Option<Server>,
165 }
166
167 impl PeerManager {
168 /// Create a new peer manager
169 pub fn new(our_name: String, our_capabilities: Vec<String>) -> Self {
170 Self {
171 our_name,
172 our_capabilities,
173 peers: Arc::new(RwLock::new(HashMap::new())),
174 server: None,
175 }
176 }
177
178 /// Start listening for incoming connections
179 pub async fn listen(&mut self, addr: SocketAddr) -> Result<SocketAddr, PeerError> {
180 let server = Server::bind(addr).await.map_err(PeerError::Transport)?;
181 let local_addr = server.local_addr();
182 self.server = Some(server);
183 Ok(local_addr)
184 }
185
186 /// Accept a new incoming connection (call in a loop)
187 pub async fn accept(&self) -> Result<Option<Peer>, PeerError> {
188 let server = match &self.server {
189 Some(s) => s,
190 None => return Ok(None),
191 };
192
193 let conn = server.accept().await.map_err(PeerError::Transport)?;
194 let peer = Peer::handshake_server(conn, &self.our_name, &self.our_capabilities).await?;
195
196 // Add to peers map
197 {
198 let mut peers = self.peers.write().await;
199 peers.insert(peer.name.clone(), peer);
200 }
201
202 // Return None to indicate we stored the peer internally
203 // Callers should use get_peer to access it
204 Ok(None)
205 }
206
207 /// Connect to a remote peer
208 pub async fn connect(&self, addr: SocketAddr) -> Result<String, PeerError> {
209 let conn = connect(addr).await.map_err(PeerError::Transport)?;
210 let peer = Peer::handshake_client(conn, &self.our_name, &self.our_capabilities).await?;
211 let name = peer.name.clone();
212
213 // Add to peers map
214 {
215 let mut peers = self.peers.write().await;
216 peers.insert(name.clone(), peer);
217 }
218
219 Ok(name)
220 }
221
222 /// Get a peer by name
223 pub async fn get_peer(&self, name: &str) -> Option<tokio::sync::RwLockReadGuard<'_, HashMap<String, Peer>>> {
224 let peers = self.peers.read().await;
225 if peers.contains_key(name) {
226 Some(peers)
227 } else {
228 None
229 }
230 }
231
232 /// Remove a peer
233 pub async fn remove_peer(&self, name: &str) -> Option<Peer> {
234 let mut peers = self.peers.write().await;
235 peers.remove(name)
236 }
237
238 /// Get names of all connected peers
239 pub async fn peer_names(&self) -> Vec<String> {
240 let peers = self.peers.read().await;
241 peers.keys().cloned().collect()
242 }
243
244 /// Check if a peer is connected
245 pub async fn is_connected(&self, name: &str) -> bool {
246 let peers = self.peers.read().await;
247 peers.contains_key(name)
248 }
249 }
250
251 #[derive(Debug, thiserror::Error)]
252 pub enum PeerError {
253 #[error("Transport error: {0}")]
254 Transport(#[from] TransportError),
255
256 #[error("Handshake timeout")]
257 HandshakeTimeout,
258
259 #[error("Handshake rejected: {0}")]
260 HandshakeRejected(String),
261
262 #[error("Connection closed")]
263 ConnectionClosed,
264
265 #[error("Unexpected message during handshake")]
266 UnexpectedMessage,
267
268 #[error("Peer not found: {0}")]
269 PeerNotFound(String),
270 }
271