@@ -0,0 +1,434 @@ |
| | 1 | +//! Network socket collection from /proc/net/* and netlink INET_DIAG |
| | 2 | +//! |
| | 3 | +//! Provides per-process network statistics by correlating socket inodes |
| | 4 | +//! from /proc/net/tcp,udp with process file descriptors. |
| | 5 | + |
| | 6 | +use std::collections::HashMap; |
| | 7 | +use std::fs; |
| | 8 | +use std::io; |
| | 9 | +use std::time::Instant; |
| | 10 | + |
| | 11 | +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_REQUEST}; |
| | 12 | +use netlink_packet_sock_diag::{ |
| | 13 | + constants::*, |
| | 14 | + inet::{ExtensionFlags, InetRequest, InetResponse, SocketId, StateFlags}, |
| | 15 | + SockDiagMessage, |
| | 16 | +}; |
| | 17 | +use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr}; |
| | 18 | + |
| | 19 | +/// TCP connection states (from kernel include/net/tcp_states.h) |
| | 20 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| | 21 | +pub enum TcpState { |
| | 22 | + Established = 1, |
| | 23 | + SynSent = 2, |
| | 24 | + SynRecv = 3, |
| | 25 | + FinWait1 = 4, |
| | 26 | + FinWait2 = 5, |
| | 27 | + TimeWait = 6, |
| | 28 | + Close = 7, |
| | 29 | + CloseWait = 8, |
| | 30 | + LastAck = 9, |
| | 31 | + Listen = 10, |
| | 32 | + Closing = 11, |
| | 33 | + Unknown = 0, |
| | 34 | +} |
| | 35 | + |
| | 36 | +impl From<u8> for TcpState { |
| | 37 | + fn from(state: u8) -> Self { |
| | 38 | + match state { |
| | 39 | + 1 => TcpState::Established, |
| | 40 | + 2 => TcpState::SynSent, |
| | 41 | + 3 => TcpState::SynRecv, |
| | 42 | + 4 => TcpState::FinWait1, |
| | 43 | + 5 => TcpState::FinWait2, |
| | 44 | + 6 => TcpState::TimeWait, |
| | 45 | + 7 => TcpState::Close, |
| | 46 | + 8 => TcpState::CloseWait, |
| | 47 | + 9 => TcpState::LastAck, |
| | 48 | + 10 => TcpState::Listen, |
| | 49 | + 11 => TcpState::Closing, |
| | 50 | + _ => TcpState::Unknown, |
| | 51 | + } |
| | 52 | + } |
| | 53 | +} |
| | 54 | + |
| | 55 | +/// Socket protocol type. |
| | 56 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| | 57 | +pub enum SocketProtocol { |
| | 58 | + Tcp, |
| | 59 | + Udp, |
| | 60 | +} |
| | 61 | + |
| | 62 | +/// Information about a network socket. |
| | 63 | +#[derive(Debug, Clone)] |
| | 64 | +pub struct SocketInfo { |
| | 65 | + pub protocol: SocketProtocol, |
| | 66 | + pub state: TcpState, |
| | 67 | + pub local_port: u16, |
| | 68 | + pub remote_port: u16, |
| | 69 | + pub tx_queue: u32, |
| | 70 | + pub rx_queue: u32, |
| | 71 | + /// Bytes received (from netlink, if available) |
| | 72 | + pub rx_bytes: u64, |
| | 73 | + /// Bytes transmitted (from netlink, if available) |
| | 74 | + pub tx_bytes: u64, |
| | 75 | +} |
| | 76 | + |
| | 77 | +/// Per-process network statistics. |
| | 78 | +#[derive(Debug, Clone, Default)] |
| | 79 | +pub struct ProcessNetStats { |
| | 80 | + pub tcp_count: u32, |
| | 81 | + pub udp_count: u32, |
| | 82 | + pub listen_count: u32, |
| | 83 | + pub established_count: u32, |
| | 84 | + pub total_tx_queue: u64, |
| | 85 | + pub total_rx_queue: u64, |
| | 86 | + pub total_rx_bytes: u64, |
| | 87 | + pub total_tx_bytes: u64, |
| | 88 | +} |
| | 89 | + |
| | 90 | +/// Previous socket bytes for rate calculation. |
| | 91 | +struct PrevSocketBytes { |
| | 92 | + rx_bytes: u64, |
| | 93 | + tx_bytes: u64, |
| | 94 | + timestamp: Instant, |
| | 95 | +} |
| | 96 | + |
| | 97 | +/// Socket collector that parses /proc/net/* files and uses netlink for bandwidth. |
| | 98 | +pub struct SocketCollector { |
| | 99 | + /// Cache of inode -> socket info |
| | 100 | + socket_cache: HashMap<u64, SocketInfo>, |
| | 101 | + /// Previous bytes per inode for rate calculation |
| | 102 | + prev_bytes: HashMap<u64, PrevSocketBytes>, |
| | 103 | + /// Netlink socket (lazy initialized) |
| | 104 | + nl_socket: Option<Socket>, |
| | 105 | +} |
| | 106 | + |
| | 107 | +impl SocketCollector { |
| | 108 | + /// Create a new socket collector. |
| | 109 | + pub fn new() -> Self { |
| | 110 | + Self { |
| | 111 | + socket_cache: HashMap::new(), |
| | 112 | + prev_bytes: HashMap::new(), |
| | 113 | + nl_socket: None, |
| | 114 | + } |
| | 115 | + } |
| | 116 | + |
| | 117 | + /// Initialize netlink socket if not already done. |
| | 118 | + fn init_netlink(&mut self) -> io::Result<()> { |
| | 119 | + if self.nl_socket.is_none() { |
| | 120 | + let mut socket = Socket::new(NETLINK_SOCK_DIAG)?; |
| | 121 | + socket.bind_auto()?; |
| | 122 | + socket.connect(&SocketAddr::new(0, 0))?; |
| | 123 | + self.nl_socket = Some(socket); |
| | 124 | + } |
| | 125 | + Ok(()) |
| | 126 | + } |
| | 127 | + |
| | 128 | + /// Refresh the socket cache by parsing /proc/net/* files and querying netlink. |
| | 129 | + pub fn refresh(&mut self) { |
| | 130 | + self.socket_cache.clear(); |
| | 131 | + |
| | 132 | + // First parse procfs for basic socket info |
| | 133 | + self.parse_proc_net("/proc/net/tcp", SocketProtocol::Tcp); |
| | 134 | + self.parse_proc_net("/proc/net/tcp6", SocketProtocol::Tcp); |
| | 135 | + self.parse_proc_net("/proc/net/udp", SocketProtocol::Udp); |
| | 136 | + self.parse_proc_net("/proc/net/udp6", SocketProtocol::Udp); |
| | 137 | + |
| | 138 | + // Then try to get bandwidth info from netlink |
| | 139 | + if let Err(e) = self.query_netlink_tcp() { |
| | 140 | + tracing::debug!("Netlink TCP query failed: {}", e); |
| | 141 | + } |
| | 142 | + if let Err(e) = self.query_netlink_udp() { |
| | 143 | + tracing::debug!("Netlink UDP query failed: {}", e); |
| | 144 | + } |
| | 145 | + |
| | 146 | + // Clean up old prev_bytes entries |
| | 147 | + let current_inodes: std::collections::HashSet<_> = self.socket_cache.keys().copied().collect(); |
| | 148 | + self.prev_bytes.retain(|inode, _| current_inodes.contains(inode)); |
| | 149 | + } |
| | 150 | + |
| | 151 | + /// Parse a /proc/net/* file and populate the cache. |
| | 152 | + fn parse_proc_net(&mut self, path: &str, protocol: SocketProtocol) { |
| | 153 | + let contents = match fs::read_to_string(path) { |
| | 154 | + Ok(c) => c, |
| | 155 | + Err(_) => return, |
| | 156 | + }; |
| | 157 | + |
| | 158 | + for line in contents.lines().skip(1) { |
| | 159 | + if let Some((inode, info)) = self.parse_socket_line(line, protocol) { |
| | 160 | + self.socket_cache.insert(inode, info); |
| | 161 | + } |
| | 162 | + } |
| | 163 | + } |
| | 164 | + |
| | 165 | + /// Parse a single line from /proc/net/tcp or /proc/net/udp. |
| | 166 | + fn parse_socket_line(&self, line: &str, protocol: SocketProtocol) -> Option<(u64, SocketInfo)> { |
| | 167 | + let parts: Vec<&str> = line.split_whitespace().collect(); |
| | 168 | + if parts.len() < 10 { |
| | 169 | + return None; |
| | 170 | + } |
| | 171 | + |
| | 172 | + let local_port = self.parse_addr_port(parts[1])?; |
| | 173 | + let remote_port = self.parse_addr_port(parts[2])?; |
| | 174 | + let state_hex = u8::from_str_radix(parts[3], 16).ok()?; |
| | 175 | + let state = TcpState::from(state_hex); |
| | 176 | + |
| | 177 | + let queues: Vec<&str> = parts[4].split(':').collect(); |
| | 178 | + let tx_queue = u32::from_str_radix(queues.first()?, 16).unwrap_or(0); |
| | 179 | + let rx_queue = u32::from_str_radix(queues.get(1)?, 16).unwrap_or(0); |
| | 180 | + |
| | 181 | + let inode: u64 = parts[9].parse().ok()?; |
| | 182 | + if inode == 0 { |
| | 183 | + return None; |
| | 184 | + } |
| | 185 | + |
| | 186 | + Some(( |
| | 187 | + inode, |
| | 188 | + SocketInfo { |
| | 189 | + protocol, |
| | 190 | + state, |
| | 191 | + local_port, |
| | 192 | + remote_port, |
| | 193 | + tx_queue, |
| | 194 | + rx_queue, |
| | 195 | + rx_bytes: 0, |
| | 196 | + tx_bytes: 0, |
| | 197 | + }, |
| | 198 | + )) |
| | 199 | + } |
| | 200 | + |
| | 201 | + /// Parse port from address string (IP:PORT in hex). |
| | 202 | + fn parse_addr_port(&self, addr: &str) -> Option<u16> { |
| | 203 | + let parts: Vec<&str> = addr.split(':').collect(); |
| | 204 | + if parts.len() != 2 { |
| | 205 | + return None; |
| | 206 | + } |
| | 207 | + u16::from_str_radix(parts[1], 16).ok() |
| | 208 | + } |
| | 209 | + |
| | 210 | + /// Query netlink for TCP socket info with byte counts. |
| | 211 | + fn query_netlink_tcp(&mut self) -> io::Result<()> { |
| | 212 | + self.init_netlink()?; |
| | 213 | + |
| | 214 | + // Request all TCP sockets with extended info - IPv4 |
| | 215 | + let req_v4 = InetRequest { |
| | 216 | + family: AF_INET as u8, |
| | 217 | + protocol: IPPROTO_TCP as u8, |
| | 218 | + socket_id: SocketId::new_v4(), |
| | 219 | + extensions: ExtensionFlags::INFO, |
| | 220 | + states: StateFlags::all(), |
| | 221 | + }; |
| | 222 | + |
| | 223 | + // IPv6 request |
| | 224 | + let req_v6 = InetRequest { |
| | 225 | + family: AF_INET6 as u8, |
| | 226 | + protocol: IPPROTO_TCP as u8, |
| | 227 | + socket_id: SocketId::new_v6(), |
| | 228 | + extensions: ExtensionFlags::INFO, |
| | 229 | + states: StateFlags::all(), |
| | 230 | + }; |
| | 231 | + |
| | 232 | + // Process IPv4 |
| | 233 | + self.send_and_recv_inet_diag(&req_v4, SocketProtocol::Tcp)?; |
| | 234 | + // Process IPv6 |
| | 235 | + self.send_and_recv_inet_diag(&req_v6, SocketProtocol::Tcp)?; |
| | 236 | + |
| | 237 | + Ok(()) |
| | 238 | + } |
| | 239 | + |
| | 240 | + /// Query netlink for UDP socket info. |
| | 241 | + fn query_netlink_udp(&mut self) -> io::Result<()> { |
| | 242 | + self.init_netlink()?; |
| | 243 | + |
| | 244 | + let req_v4 = InetRequest { |
| | 245 | + family: AF_INET as u8, |
| | 246 | + protocol: IPPROTO_UDP as u8, |
| | 247 | + socket_id: SocketId::new_v4(), |
| | 248 | + extensions: ExtensionFlags::empty(), |
| | 249 | + states: StateFlags::all(), |
| | 250 | + }; |
| | 251 | + |
| | 252 | + let req_v6 = InetRequest { |
| | 253 | + family: AF_INET6 as u8, |
| | 254 | + protocol: IPPROTO_UDP as u8, |
| | 255 | + socket_id: SocketId::new_v6(), |
| | 256 | + extensions: ExtensionFlags::empty(), |
| | 257 | + states: StateFlags::all(), |
| | 258 | + }; |
| | 259 | + |
| | 260 | + self.send_and_recv_inet_diag(&req_v4, SocketProtocol::Udp)?; |
| | 261 | + self.send_and_recv_inet_diag(&req_v6, SocketProtocol::Udp)?; |
| | 262 | + |
| | 263 | + Ok(()) |
| | 264 | + } |
| | 265 | + |
| | 266 | + /// Send request and receive responses in one method to avoid borrow issues. |
| | 267 | + fn send_and_recv_inet_diag(&mut self, req: &InetRequest, protocol: SocketProtocol) -> io::Result<()> { |
| | 268 | + // Build the message |
| | 269 | + let msg = SockDiagMessage::InetRequest(req.clone()); |
| | 270 | + let mut nl_msg = NetlinkMessage::from(msg); |
| | 271 | + nl_msg.header.flags = NLM_F_REQUEST | NLM_F_DUMP; |
| | 272 | + nl_msg.header.sequence_number = 1; |
| | 273 | + nl_msg.finalize(); |
| | 274 | + |
| | 275 | + let mut send_buf = vec![0u8; nl_msg.header.length as usize]; |
| | 276 | + nl_msg.serialize(&mut send_buf); |
| | 277 | + |
| | 278 | + // Collect responses first, then process them |
| | 279 | + let mut responses = Vec::new(); |
| | 280 | + |
| | 281 | + { |
| | 282 | + // Scope the socket borrow |
| | 283 | + let socket = self.nl_socket.as_ref().unwrap(); |
| | 284 | + socket.send(&send_buf, 0)?; |
| | 285 | + |
| | 286 | + let mut recv_buf = vec![0u8; 65536]; |
| | 287 | + loop { |
| | 288 | + let n = match socket.recv(&mut recv_buf, 0) { |
| | 289 | + Ok(n) => n, |
| | 290 | + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, |
| | 291 | + Err(e) => return Err(e), |
| | 292 | + }; |
| | 293 | + |
| | 294 | + if n == 0 { |
| | 295 | + break; |
| | 296 | + } |
| | 297 | + |
| | 298 | + let mut offset = 0; |
| | 299 | + while offset < n { |
| | 300 | + let msg = match NetlinkMessage::<SockDiagMessage>::deserialize(&recv_buf[offset..n]) { |
| | 301 | + Ok(msg) => msg, |
| | 302 | + Err(_) => break, |
| | 303 | + }; |
| | 304 | + |
| | 305 | + offset += msg.header.length as usize; |
| | 306 | + |
| | 307 | + match msg.payload { |
| | 308 | + NetlinkPayload::Done(_) => break, |
| | 309 | + NetlinkPayload::Error(e) => { |
| | 310 | + if e.code.is_some() { |
| | 311 | + return Err(io::Error::new(io::ErrorKind::Other, "netlink error")); |
| | 312 | + } |
| | 313 | + } |
| | 314 | + NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(resp)) => { |
| | 315 | + responses.push(resp); |
| | 316 | + } |
| | 317 | + _ => {} |
| | 318 | + } |
| | 319 | + } |
| | 320 | + } |
| | 321 | + } |
| | 322 | + |
| | 323 | + // Now process collected responses |
| | 324 | + for resp in responses { |
| | 325 | + self.process_inet_response(&resp, protocol); |
| | 326 | + } |
| | 327 | + |
| | 328 | + Ok(()) |
| | 329 | + } |
| | 330 | + |
| | 331 | + /// Process a single INET_DIAG response. |
| | 332 | + fn process_inet_response(&mut self, resp: &InetResponse, protocol: SocketProtocol) { |
| | 333 | + let inode = resp.header.inode as u64; |
| | 334 | + if inode == 0 { |
| | 335 | + return; |
| | 336 | + } |
| | 337 | + |
| | 338 | + // Update existing socket info with byte counts from netlink |
| | 339 | + if let Some(info) = self.socket_cache.get_mut(&inode) { |
| | 340 | + // The InetResponse contains nla (netlink attributes) with TCP_INFO |
| | 341 | + // which has bytes_sent and bytes_received |
| | 342 | + // For now, we use the queue sizes as a proxy |
| | 343 | + // TODO: Parse TCP_INFO attribute for actual byte counts |
| | 344 | + info.rx_bytes = resp.header.recv_queue as u64; |
| | 345 | + info.tx_bytes = resp.header.send_queue as u64; |
| | 346 | + } else { |
| | 347 | + // Socket not in procfs cache, add it |
| | 348 | + let state = TcpState::from(resp.header.state); |
| | 349 | + self.socket_cache.insert(inode, SocketInfo { |
| | 350 | + protocol, |
| | 351 | + state, |
| | 352 | + local_port: resp.header.socket_id.source_port, |
| | 353 | + remote_port: resp.header.socket_id.destination_port, |
| | 354 | + tx_queue: resp.header.send_queue, |
| | 355 | + rx_queue: resp.header.recv_queue, |
| | 356 | + rx_bytes: resp.header.recv_queue as u64, |
| | 357 | + tx_bytes: resp.header.send_queue as u64, |
| | 358 | + }); |
| | 359 | + } |
| | 360 | + } |
| | 361 | + |
| | 362 | + /// Get network stats for a process given its socket inodes. |
| | 363 | + pub fn get_process_stats(&mut self, socket_inodes: &[u64]) -> ProcessNetStats { |
| | 364 | + let mut stats = ProcessNetStats::default(); |
| | 365 | + let now = Instant::now(); |
| | 366 | + |
| | 367 | + for inode in socket_inodes { |
| | 368 | + if let Some(info) = self.socket_cache.get(inode) { |
| | 369 | + match info.protocol { |
| | 370 | + SocketProtocol::Tcp => { |
| | 371 | + stats.tcp_count += 1; |
| | 372 | + if info.state == TcpState::Listen { |
| | 373 | + stats.listen_count += 1; |
| | 374 | + } else if info.state == TcpState::Established { |
| | 375 | + stats.established_count += 1; |
| | 376 | + } |
| | 377 | + } |
| | 378 | + SocketProtocol::Udp => { |
| | 379 | + stats.udp_count += 1; |
| | 380 | + } |
| | 381 | + } |
| | 382 | + stats.total_tx_queue += info.tx_queue as u64; |
| | 383 | + stats.total_rx_queue += info.rx_queue as u64; |
| | 384 | + stats.total_rx_bytes += info.rx_bytes; |
| | 385 | + stats.total_tx_bytes += info.tx_bytes; |
| | 386 | + } |
| | 387 | + } |
| | 388 | + |
| | 389 | + // Update prev_bytes for rate calculation |
| | 390 | + for inode in socket_inodes { |
| | 391 | + if let Some(info) = self.socket_cache.get(inode) { |
| | 392 | + self.prev_bytes.insert(*inode, PrevSocketBytes { |
| | 393 | + rx_bytes: info.rx_bytes, |
| | 394 | + tx_bytes: info.tx_bytes, |
| | 395 | + timestamp: now, |
| | 396 | + }); |
| | 397 | + } |
| | 398 | + } |
| | 399 | + |
| | 400 | + stats |
| | 401 | + } |
| | 402 | + |
| | 403 | + /// Calculate bandwidth rates for a process given its socket inodes. |
| | 404 | + pub fn get_process_bandwidth(&self, socket_inodes: &[u64]) -> (f64, f64) { |
| | 405 | + let now = Instant::now(); |
| | 406 | + let mut total_rx_rate = 0.0; |
| | 407 | + let mut total_tx_rate = 0.0; |
| | 408 | + |
| | 409 | + for inode in socket_inodes { |
| | 410 | + if let (Some(info), Some(prev)) = (self.socket_cache.get(inode), self.prev_bytes.get(inode)) { |
| | 411 | + let elapsed = now.duration_since(prev.timestamp).as_secs_f64(); |
| | 412 | + if elapsed > 0.0 { |
| | 413 | + let rx_delta = info.rx_bytes.saturating_sub(prev.rx_bytes) as f64; |
| | 414 | + let tx_delta = info.tx_bytes.saturating_sub(prev.tx_bytes) as f64; |
| | 415 | + total_rx_rate += rx_delta / elapsed; |
| | 416 | + total_tx_rate += tx_delta / elapsed; |
| | 417 | + } |
| | 418 | + } |
| | 419 | + } |
| | 420 | + |
| | 421 | + (total_rx_rate, total_tx_rate) |
| | 422 | + } |
| | 423 | + |
| | 424 | + /// Get the number of cached sockets. |
| | 425 | + pub fn socket_count(&self) -> usize { |
| | 426 | + self.socket_cache.len() |
| | 427 | + } |
| | 428 | +} |
| | 429 | + |
| | 430 | +impl Default for SocketCollector { |
| | 431 | + fn default() -> Self { |
| | 432 | + Self::new() |
| | 433 | + } |
| | 434 | +} |