Rust · 8339 bytes Raw Blame History
1 //! Peer connection management
2 //!
3 //! Handles connections to other HyprKVM instances.
4
5 use std::collections::HashMap;
6 use std::net::SocketAddr;
7 use std::sync::Arc;
8 use std::time::Duration;
9
10 use tokio::sync::RwLock;
11 use tokio::time::timeout;
12
13 use hyprkvm_common::protocol::{
14 HelloAckPayload, HelloPayload, Message, PROTOCOL_VERSION,
15 };
16
17 use super::transport::{connect, FramedConnection, Server, TransportError};
18
19 /// Default timeout for handshake
20 const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);
21
22 /// A connected peer
23 pub struct Peer {
24 /// Connection to the peer
25 pub conn: FramedConnection,
26 /// Peer's machine name
27 pub name: String,
28 /// Peer's protocol version
29 pub protocol_version: u32,
30 /// Peer's capabilities
31 pub capabilities: Vec<String>,
32 }
33
34 impl Peer {
35 /// Perform client-side handshake
36 pub async fn handshake_client(
37 mut conn: FramedConnection,
38 our_name: &str,
39 our_capabilities: &[String],
40 ) -> Result<Self, PeerError> {
41 // Send Hello
42 let hello = Message::Hello(HelloPayload {
43 protocol_version: PROTOCOL_VERSION,
44 machine_name: our_name.to_string(),
45 capabilities: our_capabilities.to_vec(),
46 my_direction_for_you: None, // Direction not known in this context
47 });
48
49 conn.send(&hello).await?;
50
51 // Wait for HelloAck
52 let response = timeout(HANDSHAKE_TIMEOUT, conn.recv())
53 .await
54 .map_err(|_| PeerError::HandshakeTimeout)?
55 .map_err(PeerError::Transport)?
56 .ok_or(PeerError::ConnectionClosed)?;
57
58 match response {
59 Message::HelloAck(ack) => {
60 if !ack.accepted {
61 return Err(PeerError::HandshakeRejected(
62 ack.error.unwrap_or_else(|| "Unknown reason".to_string()),
63 ));
64 }
65
66 tracing::info!(
67 "Handshake successful with peer '{}' (protocol v{})",
68 ack.machine_name,
69 ack.protocol_version
70 );
71
72 Ok(Self {
73 conn,
74 name: ack.machine_name,
75 protocol_version: ack.protocol_version,
76 capabilities: vec![], // Server doesn't send capabilities in ack
77 })
78 }
79 _ => Err(PeerError::UnexpectedMessage),
80 }
81 }
82
83 /// Perform server-side handshake
84 pub async fn handshake_server(
85 mut conn: FramedConnection,
86 our_name: &str,
87 _our_capabilities: &[String],
88 ) -> Result<Self, PeerError> {
89 // Wait for Hello
90 let request = timeout(HANDSHAKE_TIMEOUT, conn.recv())
91 .await
92 .map_err(|_| PeerError::HandshakeTimeout)?
93 .map_err(PeerError::Transport)?
94 .ok_or(PeerError::ConnectionClosed)?;
95
96 match request {
97 Message::Hello(hello) => {
98 // Check protocol version compatibility
99 let accepted = hello.protocol_version == PROTOCOL_VERSION;
100 let error = if !accepted {
101 Some(format!(
102 "Protocol version mismatch: expected {}, got {}",
103 PROTOCOL_VERSION, hello.protocol_version
104 ))
105 } else {
106 None
107 };
108
109 // Send HelloAck
110 let ack = Message::HelloAck(HelloAckPayload {
111 accepted,
112 protocol_version: PROTOCOL_VERSION,
113 machine_name: our_name.to_string(),
114 error: error.clone(),
115 });
116
117 conn.send(&ack).await?;
118
119 if !accepted {
120 return Err(PeerError::HandshakeRejected(error.unwrap()));
121 }
122
123 tracing::info!(
124 "Handshake successful with peer '{}' (protocol v{})",
125 hello.machine_name,
126 hello.protocol_version
127 );
128
129 Ok(Self {
130 conn,
131 name: hello.machine_name,
132 protocol_version: hello.protocol_version,
133 capabilities: hello.capabilities,
134 })
135 }
136 _ => Err(PeerError::UnexpectedMessage),
137 }
138 }
139
140 /// Send a message to the peer
141 pub async fn send(&mut self, msg: &Message) -> Result<(), PeerError> {
142 self.conn.send(msg).await.map_err(PeerError::Transport)
143 }
144
145 /// Receive a message from the peer
146 pub async fn recv(&mut self) -> Result<Option<Message>, PeerError> {
147 self.conn.recv().await.map_err(PeerError::Transport)
148 }
149
150 /// Get the remote address
151 pub fn remote_addr(&self) -> SocketAddr {
152 self.conn.remote_addr()
153 }
154 }
155
156 /// Manages connections to multiple peers
157 pub struct PeerManager {
158 /// Our machine name
159 our_name: String,
160 /// Our capabilities
161 our_capabilities: Vec<String>,
162 /// Connected peers by name
163 peers: Arc<RwLock<HashMap<String, Peer>>>,
164 /// Server for incoming connections
165 server: Option<Server>,
166 }
167
168 impl PeerManager {
169 /// Create a new peer manager
170 pub fn new(our_name: String, our_capabilities: Vec<String>) -> Self {
171 Self {
172 our_name,
173 our_capabilities,
174 peers: Arc::new(RwLock::new(HashMap::new())),
175 server: None,
176 }
177 }
178
179 /// Start listening for incoming connections
180 pub async fn listen(&mut self, addr: SocketAddr) -> Result<SocketAddr, PeerError> {
181 let server = Server::bind(addr).await.map_err(PeerError::Transport)?;
182 let local_addr = server.local_addr();
183 self.server = Some(server);
184 Ok(local_addr)
185 }
186
187 /// Accept a new incoming connection (call in a loop)
188 pub async fn accept(&self) -> Result<Option<Peer>, PeerError> {
189 let server = match &self.server {
190 Some(s) => s,
191 None => return Ok(None),
192 };
193
194 let conn = server.accept().await.map_err(PeerError::Transport)?;
195 let peer = Peer::handshake_server(conn, &self.our_name, &self.our_capabilities).await?;
196
197 // Add to peers map
198 {
199 let mut peers = self.peers.write().await;
200 peers.insert(peer.name.clone(), peer);
201 }
202
203 // Return None to indicate we stored the peer internally
204 // Callers should use get_peer to access it
205 Ok(None)
206 }
207
208 /// Connect to a remote peer
209 pub async fn connect(&self, addr: SocketAddr) -> Result<String, PeerError> {
210 let conn = connect(addr).await.map_err(PeerError::Transport)?;
211 let peer = Peer::handshake_client(conn, &self.our_name, &self.our_capabilities).await?;
212 let name = peer.name.clone();
213
214 // Add to peers map
215 {
216 let mut peers = self.peers.write().await;
217 peers.insert(name.clone(), peer);
218 }
219
220 Ok(name)
221 }
222
223 /// Get a peer by name
224 pub async fn get_peer(&self, name: &str) -> Option<tokio::sync::RwLockReadGuard<'_, HashMap<String, Peer>>> {
225 let peers = self.peers.read().await;
226 if peers.contains_key(name) {
227 Some(peers)
228 } else {
229 None
230 }
231 }
232
233 /// Remove a peer
234 pub async fn remove_peer(&self, name: &str) -> Option<Peer> {
235 let mut peers = self.peers.write().await;
236 peers.remove(name)
237 }
238
239 /// Get names of all connected peers
240 pub async fn peer_names(&self) -> Vec<String> {
241 let peers = self.peers.read().await;
242 peers.keys().cloned().collect()
243 }
244
245 /// Check if a peer is connected
246 pub async fn is_connected(&self, name: &str) -> bool {
247 let peers = self.peers.read().await;
248 peers.contains_key(name)
249 }
250 }
251
252 #[derive(Debug, thiserror::Error)]
253 pub enum PeerError {
254 #[error("Transport error: {0}")]
255 Transport(#[from] TransportError),
256
257 #[error("Handshake timeout")]
258 HandshakeTimeout,
259
260 #[error("Handshake rejected: {0}")]
261 HandshakeRejected(String),
262
263 #[error("Connection closed")]
264 ConnectionClosed,
265
266 #[error("Unexpected message during handshake")]
267 UnexpectedMessage,
268
269 #[error("Peer not found: {0}")]
270 PeerNotFound(String),
271 }
272