Rust · 15045 bytes Raw Blame History
1 //! Network socket collection from /proc/net/* and netlink INET_DIAG
2 //!
3 //! Provides per-process network statistics by correlating socket inodes
4 //! from /proc/net/tcp,udp with process file descriptors.
5
6 use std::collections::HashMap;
7 use std::fs;
8 use std::io;
9 use std::time::Instant;
10
11 use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_REQUEST};
12 use netlink_packet_sock_diag::{
13 constants::*,
14 inet::{ExtensionFlags, InetRequest, InetResponse, SocketId, StateFlags},
15 SockDiagMessage,
16 };
17 use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr};
18
19 /// TCP connection states (from kernel include/net/tcp_states.h)
20 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
21 pub enum TcpState {
22 Established = 1,
23 SynSent = 2,
24 SynRecv = 3,
25 FinWait1 = 4,
26 FinWait2 = 5,
27 TimeWait = 6,
28 Close = 7,
29 CloseWait = 8,
30 LastAck = 9,
31 Listen = 10,
32 Closing = 11,
33 Unknown = 0,
34 }
35
36 impl From<u8> for TcpState {
37 fn from(state: u8) -> Self {
38 match state {
39 1 => TcpState::Established,
40 2 => TcpState::SynSent,
41 3 => TcpState::SynRecv,
42 4 => TcpState::FinWait1,
43 5 => TcpState::FinWait2,
44 6 => TcpState::TimeWait,
45 7 => TcpState::Close,
46 8 => TcpState::CloseWait,
47 9 => TcpState::LastAck,
48 10 => TcpState::Listen,
49 11 => TcpState::Closing,
50 _ => TcpState::Unknown,
51 }
52 }
53 }
54
55 /// Socket protocol type.
56 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
57 pub enum SocketProtocol {
58 Tcp,
59 Udp,
60 }
61
62 /// Information about a network socket.
63 #[derive(Debug, Clone)]
64 pub struct SocketInfo {
65 pub protocol: SocketProtocol,
66 pub state: TcpState,
67 pub local_port: u16,
68 pub remote_port: u16,
69 pub tx_queue: u32,
70 pub rx_queue: u32,
71 /// Bytes received (from netlink, if available)
72 pub rx_bytes: u64,
73 /// Bytes transmitted (from netlink, if available)
74 pub tx_bytes: u64,
75 }
76
77 /// Per-process network statistics.
78 #[derive(Debug, Clone, Default)]
79 pub struct ProcessNetStats {
80 pub tcp_count: u32,
81 pub udp_count: u32,
82 pub listen_count: u32,
83 pub established_count: u32,
84 pub total_tx_queue: u64,
85 pub total_rx_queue: u64,
86 pub total_rx_bytes: u64,
87 pub total_tx_bytes: u64,
88 }
89
90 /// Previous socket bytes for rate calculation.
91 struct PrevSocketBytes {
92 rx_bytes: u64,
93 tx_bytes: u64,
94 timestamp: Instant,
95 }
96
97 /// Socket collector that parses /proc/net/* files and uses netlink for bandwidth.
98 pub struct SocketCollector {
99 /// Cache of inode -> socket info
100 socket_cache: HashMap<u64, SocketInfo>,
101 /// Previous bytes per inode for rate calculation
102 prev_bytes: HashMap<u64, PrevSocketBytes>,
103 /// Netlink socket (lazy initialized)
104 nl_socket: Option<Socket>,
105 }
106
107 impl SocketCollector {
108 /// Create a new socket collector.
109 pub fn new() -> Self {
110 Self {
111 socket_cache: HashMap::new(),
112 prev_bytes: HashMap::new(),
113 nl_socket: None,
114 }
115 }
116
117 /// Initialize netlink socket if not already done.
118 fn init_netlink(&mut self) -> io::Result<()> {
119 if self.nl_socket.is_none() {
120 let mut socket = Socket::new(NETLINK_SOCK_DIAG)?;
121 socket.bind_auto()?;
122 socket.connect(&SocketAddr::new(0, 0))?;
123 // Set non-blocking to prevent hanging on recv
124 socket.set_non_blocking(true)?;
125 self.nl_socket = Some(socket);
126 }
127 Ok(())
128 }
129
130 /// Refresh the socket cache by parsing /proc/net/* files and querying netlink.
131 pub fn refresh(&mut self) {
132 self.socket_cache.clear();
133
134 // First parse procfs for basic socket info
135 self.parse_proc_net("/proc/net/tcp", SocketProtocol::Tcp);
136 self.parse_proc_net("/proc/net/tcp6", SocketProtocol::Tcp);
137 self.parse_proc_net("/proc/net/udp", SocketProtocol::Udp);
138 self.parse_proc_net("/proc/net/udp6", SocketProtocol::Udp);
139
140 // Then try to get bandwidth info from netlink
141 if let Err(e) = self.query_netlink_tcp() {
142 tracing::debug!("Netlink TCP query failed: {}", e);
143 }
144 if let Err(e) = self.query_netlink_udp() {
145 tracing::debug!("Netlink UDP query failed: {}", e);
146 }
147
148 // Clean up old prev_bytes entries
149 let current_inodes: std::collections::HashSet<_> = self.socket_cache.keys().copied().collect();
150 self.prev_bytes.retain(|inode, _| current_inodes.contains(inode));
151 }
152
153 /// Parse a /proc/net/* file and populate the cache.
154 fn parse_proc_net(&mut self, path: &str, protocol: SocketProtocol) {
155 let contents = match fs::read_to_string(path) {
156 Ok(c) => c,
157 Err(_) => return,
158 };
159
160 for line in contents.lines().skip(1) {
161 if let Some((inode, info)) = self.parse_socket_line(line, protocol) {
162 self.socket_cache.insert(inode, info);
163 }
164 }
165 }
166
167 /// Parse a single line from /proc/net/tcp or /proc/net/udp.
168 fn parse_socket_line(&self, line: &str, protocol: SocketProtocol) -> Option<(u64, SocketInfo)> {
169 let parts: Vec<&str> = line.split_whitespace().collect();
170 if parts.len() < 10 {
171 return None;
172 }
173
174 let local_port = self.parse_addr_port(parts[1])?;
175 let remote_port = self.parse_addr_port(parts[2])?;
176 let state_hex = u8::from_str_radix(parts[3], 16).ok()?;
177 let state = TcpState::from(state_hex);
178
179 let queues: Vec<&str> = parts[4].split(':').collect();
180 let tx_queue = u32::from_str_radix(queues.first()?, 16).unwrap_or(0);
181 let rx_queue = u32::from_str_radix(queues.get(1)?, 16).unwrap_or(0);
182
183 let inode: u64 = parts[9].parse().ok()?;
184 if inode == 0 {
185 return None;
186 }
187
188 Some((
189 inode,
190 SocketInfo {
191 protocol,
192 state,
193 local_port,
194 remote_port,
195 tx_queue,
196 rx_queue,
197 rx_bytes: 0,
198 tx_bytes: 0,
199 },
200 ))
201 }
202
203 /// Parse port from address string (IP:PORT in hex).
204 fn parse_addr_port(&self, addr: &str) -> Option<u16> {
205 let parts: Vec<&str> = addr.split(':').collect();
206 if parts.len() != 2 {
207 return None;
208 }
209 u16::from_str_radix(parts[1], 16).ok()
210 }
211
212 /// Query netlink for TCP socket info with byte counts.
213 fn query_netlink_tcp(&mut self) -> io::Result<()> {
214 self.init_netlink()?;
215
216 // Request all TCP sockets with extended info - IPv4
217 let req_v4 = InetRequest {
218 family: AF_INET as u8,
219 protocol: IPPROTO_TCP as u8,
220 socket_id: SocketId::new_v4(),
221 extensions: ExtensionFlags::INFO,
222 states: StateFlags::all(),
223 };
224
225 // IPv6 request
226 let req_v6 = InetRequest {
227 family: AF_INET6 as u8,
228 protocol: IPPROTO_TCP as u8,
229 socket_id: SocketId::new_v6(),
230 extensions: ExtensionFlags::INFO,
231 states: StateFlags::all(),
232 };
233
234 // Process IPv4
235 self.send_and_recv_inet_diag(&req_v4, SocketProtocol::Tcp)?;
236 // Process IPv6
237 self.send_and_recv_inet_diag(&req_v6, SocketProtocol::Tcp)?;
238
239 Ok(())
240 }
241
242 /// Query netlink for UDP socket info.
243 fn query_netlink_udp(&mut self) -> io::Result<()> {
244 self.init_netlink()?;
245
246 let req_v4 = InetRequest {
247 family: AF_INET as u8,
248 protocol: IPPROTO_UDP as u8,
249 socket_id: SocketId::new_v4(),
250 extensions: ExtensionFlags::empty(),
251 states: StateFlags::all(),
252 };
253
254 let req_v6 = InetRequest {
255 family: AF_INET6 as u8,
256 protocol: IPPROTO_UDP as u8,
257 socket_id: SocketId::new_v6(),
258 extensions: ExtensionFlags::empty(),
259 states: StateFlags::all(),
260 };
261
262 self.send_and_recv_inet_diag(&req_v4, SocketProtocol::Udp)?;
263 self.send_and_recv_inet_diag(&req_v6, SocketProtocol::Udp)?;
264
265 Ok(())
266 }
267
268 /// Send request and receive responses in one method to avoid borrow issues.
269 fn send_and_recv_inet_diag(&mut self, req: &InetRequest, protocol: SocketProtocol) -> io::Result<()> {
270 // Build the message
271 let msg = SockDiagMessage::InetRequest(req.clone());
272 let mut nl_msg = NetlinkMessage::from(msg);
273 nl_msg.header.flags = NLM_F_REQUEST | NLM_F_DUMP;
274 nl_msg.header.sequence_number = 1;
275 nl_msg.finalize();
276
277 let mut send_buf = vec![0u8; nl_msg.header.length as usize];
278 nl_msg.serialize(&mut send_buf);
279
280 // Collect responses first, then process them
281 let mut responses = Vec::new();
282
283 {
284 // Scope the socket borrow
285 let socket = self.nl_socket.as_ref().unwrap();
286 socket.send(&send_buf, 0)?;
287
288 let mut recv_buf = vec![0u8; 65536];
289 let mut done = false;
290 let mut retries = 0;
291 const MAX_RETRIES: u32 = 100; // 100ms max wait
292
293 while !done && retries < MAX_RETRIES {
294 let n = match socket.recv(&mut recv_buf, 0) {
295 Ok(n) => n,
296 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
297 // Wait a bit for data to arrive
298 retries += 1;
299 std::thread::sleep(std::time::Duration::from_millis(1));
300 continue;
301 }
302 Err(e) => return Err(e),
303 };
304
305 if n == 0 {
306 break;
307 }
308
309 let mut offset = 0;
310 while offset < n {
311 let msg = match NetlinkMessage::<SockDiagMessage>::deserialize(&recv_buf[offset..n]) {
312 Ok(msg) => msg,
313 Err(_) => break,
314 };
315
316 offset += msg.header.length as usize;
317
318 match msg.payload {
319 NetlinkPayload::Done(_) => {
320 done = true;
321 break;
322 }
323 NetlinkPayload::Error(e) => {
324 if e.code.is_some() {
325 return Err(io::Error::new(io::ErrorKind::Other, "netlink error"));
326 }
327 }
328 NetlinkPayload::InnerMessage(SockDiagMessage::InetResponse(resp)) => {
329 responses.push(resp);
330 }
331 _ => {}
332 }
333 }
334 }
335 }
336
337 // Now process collected responses
338 for resp in responses {
339 self.process_inet_response(&resp, protocol);
340 }
341
342 Ok(())
343 }
344
345 /// Process a single INET_DIAG response.
346 fn process_inet_response(&mut self, resp: &InetResponse, protocol: SocketProtocol) {
347 let inode = resp.header.inode as u64;
348 if inode == 0 {
349 return;
350 }
351
352 // Update existing socket info with byte counts from netlink
353 if let Some(info) = self.socket_cache.get_mut(&inode) {
354 // The InetResponse contains nla (netlink attributes) with TCP_INFO
355 // which has bytes_sent and bytes_received
356 // For now, we use the queue sizes as a proxy
357 // TODO: Parse TCP_INFO attribute for actual byte counts
358 info.rx_bytes = resp.header.recv_queue as u64;
359 info.tx_bytes = resp.header.send_queue as u64;
360 } else {
361 // Socket not in procfs cache, add it
362 let state = TcpState::from(resp.header.state);
363 self.socket_cache.insert(inode, SocketInfo {
364 protocol,
365 state,
366 local_port: resp.header.socket_id.source_port,
367 remote_port: resp.header.socket_id.destination_port,
368 tx_queue: resp.header.send_queue,
369 rx_queue: resp.header.recv_queue,
370 rx_bytes: resp.header.recv_queue as u64,
371 tx_bytes: resp.header.send_queue as u64,
372 });
373 }
374 }
375
376 /// Get network stats for a process given its socket inodes.
377 pub fn get_process_stats(&mut self, socket_inodes: &[u64]) -> ProcessNetStats {
378 let mut stats = ProcessNetStats::default();
379 let now = Instant::now();
380
381 for inode in socket_inodes {
382 if let Some(info) = self.socket_cache.get(inode) {
383 match info.protocol {
384 SocketProtocol::Tcp => {
385 stats.tcp_count += 1;
386 if info.state == TcpState::Listen {
387 stats.listen_count += 1;
388 } else if info.state == TcpState::Established {
389 stats.established_count += 1;
390 }
391 }
392 SocketProtocol::Udp => {
393 stats.udp_count += 1;
394 }
395 }
396 stats.total_tx_queue += info.tx_queue as u64;
397 stats.total_rx_queue += info.rx_queue as u64;
398 stats.total_rx_bytes += info.rx_bytes;
399 stats.total_tx_bytes += info.tx_bytes;
400 }
401 }
402
403 // Update prev_bytes for rate calculation
404 for inode in socket_inodes {
405 if let Some(info) = self.socket_cache.get(inode) {
406 self.prev_bytes.insert(*inode, PrevSocketBytes {
407 rx_bytes: info.rx_bytes,
408 tx_bytes: info.tx_bytes,
409 timestamp: now,
410 });
411 }
412 }
413
414 stats
415 }
416
417 /// Calculate bandwidth rates for a process given its socket inodes.
418 pub fn get_process_bandwidth(&self, socket_inodes: &[u64]) -> (f64, f64) {
419 let now = Instant::now();
420 let mut total_rx_rate = 0.0;
421 let mut total_tx_rate = 0.0;
422
423 for inode in socket_inodes {
424 if let (Some(info), Some(prev)) = (self.socket_cache.get(inode), self.prev_bytes.get(inode)) {
425 let elapsed = now.duration_since(prev.timestamp).as_secs_f64();
426 if elapsed > 0.0 {
427 let rx_delta = info.rx_bytes.saturating_sub(prev.rx_bytes) as f64;
428 let tx_delta = info.tx_bytes.saturating_sub(prev.tx_bytes) as f64;
429 total_rx_rate += rx_delta / elapsed;
430 total_tx_rate += tx_delta / elapsed;
431 }
432 }
433 }
434
435 (total_rx_rate, total_tx_rate)
436 }
437
438 /// Get the number of cached sockets.
439 pub fn socket_count(&self) -> usize {
440 self.socket_cache.len()
441 }
442 }
443
444 impl Default for SocketCollector {
445 fn default() -> Self {
446 Self::new()
447 }
448 }
449