Rust · 21732 bytes Raw Blame History
1 use std::fmt;
2
/// Version number of the control protocol.
///
/// `StatusResponse::healthy` stamps this value into `protocol_version`, and it is
/// serialized as the `protocol=` field of a status response line.
pub const PROTOCOL_VERSION: u16 = 1;

/// Default name of the runtime subdirectory.
///
/// NOTE(review): this file only defines the name; which base directory it is
/// joined to is decided by callers outside this file.
pub const DEFAULT_RUNTIME_SUBDIR: &str = "garwarp";

/// Default file name of the control socket.
///
/// NOTE(review): presumably created inside the runtime subdirectory — confirm
/// against the code that binds the socket.
pub const DEFAULT_CONTROL_SOCKET: &str = "control.sock";
6
/// Health state carried in the `health=` field of a status response.
///
/// Each variant has exactly one lowercase wire token; see [`HealthStatus::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// Wire token `starting`.
    Starting,
    /// Wire token `healthy`.
    Healthy,
    /// Wire token `degraded`.
    Degraded,
    /// Wire token `stopping`.
    Stopping,
}

impl HealthStatus {
    /// Every variant, in declaration order.
    ///
    /// `parse` searches this table, so a variant missing from it can be
    /// serialized but never parsed back — keep it in sync with the enum.
    const ALL: [Self; 4] = [
        Self::Starting,
        Self::Healthy,
        Self::Degraded,
        Self::Stopping,
    ];

    /// Returns the wire token for this status.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Starting => "starting",
            Self::Healthy => "healthy",
            Self::Degraded => "degraded",
            Self::Stopping => "stopping",
        }
    }

    /// Parses a wire token back into a status.
    ///
    /// The match is exact (case-sensitive, no trimming); anything that is not
    /// one of the tokens produced by `as_str` yields `None`.
    fn parse(input: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == input)
    }
}
36
37 #[derive(Debug, Clone, PartialEq, Eq)]
38 pub enum ControlRequest {
39 Status,
40 Stop,
41 ListRequests,
42 InspectRequest {
43 id: String,
44 },
45 BeginRequest {
46 id: String,
47 sender: String,
48 app_id: Option<String>,
49 parent_window: Option<String>,
50 },
51 TransitionRequest {
52 id: String,
53 sender: String,
54 app_id: Option<String>,
55 target: RequestTransitionTarget,
56 },
57 }
58
/// State named by the `state=` field of a `transition` request.
///
/// Each variant has exactly one lowercase wire token; see
/// [`RequestTransitionTarget::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestTransitionTarget {
    /// Wire token `awaiting_user`.
    AwaitingUser,
    /// Wire token `fulfilled`.
    Fulfilled,
    /// Wire token `cancelled`.
    Cancelled,
    /// Wire token `failed`.
    Failed,
}

impl RequestTransitionTarget {
    /// Every variant, in declaration order.
    ///
    /// `parse` searches this table, so a variant missing from it can be
    /// serialized but never parsed back — keep it in sync with the enum.
    const ALL: [Self; 4] = [
        Self::AwaitingUser,
        Self::Fulfilled,
        Self::Cancelled,
        Self::Failed,
    ];

    /// Returns the wire token for this target state.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::AwaitingUser => "awaiting_user",
            Self::Fulfilled => "fulfilled",
            Self::Cancelled => "cancelled",
            Self::Failed => "failed",
        }
    }

    /// Parses a wire token back into a target state.
    ///
    /// The match is exact (case-sensitive, no trimming); any other input
    /// yields `None`.
    fn parse(input: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == input)
    }
}
88
89 impl ControlRequest {
90 #[must_use]
91 pub fn as_line(&self) -> String {
92 match self {
93 Self::Status => "status".to_string(),
94 Self::Stop => "stop".to_string(),
95 Self::ListRequests => "list".to_string(),
96 Self::InspectRequest { id } => format!("inspect id={id}"),
97 Self::BeginRequest {
98 id,
99 sender,
100 app_id,
101 parent_window,
102 } => {
103 let mut parts = vec![
104 "begin".to_string(),
105 format!("id={id}"),
106 format!("sender={sender}"),
107 ];
108 if let Some(app_id) = app_id {
109 parts.push(format!("app_id={app_id}"));
110 }
111 if let Some(parent_window) = parent_window {
112 parts.push(format!("parent={parent_window}"));
113 }
114 parts.join(" ")
115 }
116 Self::TransitionRequest {
117 id,
118 sender,
119 app_id,
120 target,
121 } => {
122 let mut parts = vec![
123 "transition".to_string(),
124 format!("id={id}"),
125 format!("sender={sender}"),
126 format!("state={}", target.as_str()),
127 ];
128 if let Some(app_id) = app_id {
129 parts.push(format!("app_id={app_id}"));
130 }
131 parts.join(" ")
132 }
133 }
134 }
135
136 #[must_use]
137 pub fn parse_line(input: &str) -> Option<Self> {
138 let trimmed = input.trim();
139 if trimmed == "status" {
140 return Some(Self::Status);
141 }
142 if trimmed == "stop" {
143 return Some(Self::Stop);
144 }
145 if trimmed == "list" {
146 return Some(Self::ListRequests);
147 }
148
149 let mut parts = trimmed.split_whitespace();
150 match parts.next() {
151 Some("inspect") => {
152 let fields = parse_fields(parts)?;
153 if !fields_only(&fields, &["id"]) {
154 return None;
155 }
156 let id = fields.get("id")?.clone();
157 Some(Self::InspectRequest { id })
158 }
159 Some("begin") => {
160 let fields = parse_fields(parts)?;
161 if !fields_only(&fields, &["id", "sender", "app_id", "parent"]) {
162 return None;
163 }
164 let id = fields.get("id")?.clone();
165 let sender = fields.get("sender")?.clone();
166 let app_id = fields.get("app_id").cloned();
167 let parent_window = fields.get("parent").cloned();
168 Some(Self::BeginRequest {
169 id,
170 sender,
171 app_id,
172 parent_window,
173 })
174 }
175 Some("transition") => {
176 let fields = parse_fields(parts)?;
177 if !fields_only(&fields, &["id", "sender", "state", "app_id"]) {
178 return None;
179 }
180 let id = fields.get("id")?.clone();
181 let sender = fields.get("sender")?.clone();
182 let app_id = fields.get("app_id").cloned();
183 let target = RequestTransitionTarget::parse(fields.get("state")?)?;
184 Some(Self::TransitionRequest {
185 id,
186 sender,
187 app_id,
188 target,
189 })
190 }
191 _ => None,
192 }
193 }
194 }
195
196 #[derive(Debug, Clone, PartialEq, Eq)]
197 pub struct StatusResponse {
198 pub protocol_version: u16,
199 pub health: HealthStatus,
200 pub in_flight_requests: usize,
201 pub total_requests: usize,
202 pub terminal_requests: usize,
203 }
204
205 impl StatusResponse {
206 #[must_use]
207 pub fn healthy() -> Self {
208 Self {
209 protocol_version: PROTOCOL_VERSION,
210 health: HealthStatus::Healthy,
211 in_flight_requests: 0,
212 total_requests: 0,
213 terminal_requests: 0,
214 }
215 }
216 }
217
218 #[derive(Debug, Clone, PartialEq, Eq)]
219 pub enum ControlResponse {
220 Status(StatusResponse),
221 AckStopping,
222 RequestList {
223 ids: Vec<String>,
224 },
225 AckRequest {
226 id: String,
227 state: String,
228 },
229 RequestSnapshot {
230 id: String,
231 state: String,
232 sender: String,
233 app_id: Option<String>,
234 parent_window: Option<String>,
235 },
236 Error {
237 code: u32,
238 reason: String,
239 },
240 }
241
242 impl ControlResponse {
243 #[must_use]
244 pub fn to_line(&self) -> String {
245 match self {
246 Self::Status(status) => format!(
247 "status protocol={} health={} in_flight={} total={} terminal={}\n",
248 status.protocol_version,
249 status.health.as_str(),
250 status.in_flight_requests,
251 status.total_requests,
252 status.terminal_requests
253 ),
254 Self::AckStopping => "ack stopping\n".to_string(),
255 Self::RequestList { ids } => {
256 let ids = if ids.is_empty() {
257 "-".to_string()
258 } else {
259 ids.join(",")
260 };
261 format!("list ids={ids}\n")
262 }
263 Self::AckRequest { id, state } => {
264 format!("ack request id={} state={}\n", id, state)
265 }
266 Self::RequestSnapshot {
267 id,
268 state,
269 sender,
270 app_id,
271 parent_window,
272 } => {
273 let app_id = app_id.as_deref().unwrap_or("-");
274 let parent_window = parent_window.as_deref().unwrap_or("-");
275 format!(
276 "snapshot id={} state={} sender={} app_id={} parent={}\n",
277 id, state, sender, app_id, parent_window
278 )
279 }
280 Self::Error { code, reason } => format!("error code={} reason={}\n", code, reason),
281 }
282 }
283
284 pub fn parse_line(input: &str) -> Result<Self, ParseError> {
285 let trimmed = input.trim();
286 let mut parts = trimmed.split_whitespace();
287
288 match parts.next() {
289 Some("status") => {
290 let mut protocol_version = None;
291 let mut health = None;
292 let mut in_flight_requests = None;
293 let mut total_requests = None;
294 let mut terminal_requests = None;
295
296 for part in parts {
297 let (key, value) = part
298 .split_once('=')
299 .ok_or(ParseError::InvalidField(part.to_string()))?;
300 match key {
301 "protocol" => {
302 protocol_version = Some(
303 value
304 .parse::<u16>()
305 .map_err(|_| ParseError::InvalidField(part.to_string()))?,
306 );
307 }
308 "health" => {
309 health = HealthStatus::parse(value);
310 if health.is_none() {
311 return Err(ParseError::InvalidField(part.to_string()));
312 }
313 }
314 "in_flight" => {
315 in_flight_requests = Some(
316 value
317 .parse::<usize>()
318 .map_err(|_| ParseError::InvalidField(part.to_string()))?,
319 );
320 }
321 "total" => {
322 total_requests = Some(
323 value
324 .parse::<usize>()
325 .map_err(|_| ParseError::InvalidField(part.to_string()))?,
326 );
327 }
328 "terminal" => {
329 terminal_requests = Some(
330 value
331 .parse::<usize>()
332 .map_err(|_| ParseError::InvalidField(part.to_string()))?,
333 );
334 }
335 _ => return Err(ParseError::InvalidField(part.to_string())),
336 }
337 }
338
339 let status = StatusResponse {
340 protocol_version: protocol_version
341 .ok_or(ParseError::MissingField("protocol"))?,
342 health: health.ok_or(ParseError::MissingField("health"))?,
343 in_flight_requests: in_flight_requests
344 .ok_or(ParseError::MissingField("in_flight"))?,
345 total_requests: total_requests.ok_or(ParseError::MissingField("total"))?,
346 terminal_requests: terminal_requests
347 .ok_or(ParseError::MissingField("terminal"))?,
348 };
349 Ok(Self::Status(status))
350 }
351 Some("ack") => match parts.next() {
352 Some("stopping") => Ok(Self::AckStopping),
353 Some("request") => {
354 let mut id = None;
355 let mut state = None;
356 for part in parts {
357 let (key, value) = part
358 .split_once('=')
359 .ok_or(ParseError::InvalidField(part.to_string()))?;
360 match key {
361 "id" => id = Some(value.to_string()),
362 "state" => state = Some(value.to_string()),
363 _ => return Err(ParseError::InvalidField(part.to_string())),
364 }
365 }
366 Ok(Self::AckRequest {
367 id: id.ok_or(ParseError::MissingField("id"))?,
368 state: state.ok_or(ParseError::MissingField("state"))?,
369 })
370 }
371 Some(other) => Err(ParseError::UnknownToken(other.to_string())),
372 None => Err(ParseError::MissingField("ack")),
373 },
374 Some("list") => {
375 let mut ids = None;
376 for part in parts {
377 let (key, value) = part
378 .split_once('=')
379 .ok_or(ParseError::InvalidField(part.to_string()))?;
380 match key {
381 "ids" => {
382 if value == "-" {
383 ids = Some(Vec::new());
384 } else {
385 let parsed =
386 value.split(',').map(str::to_string).collect::<Vec<_>>();
387 if parsed.iter().any(|id| id.is_empty()) {
388 return Err(ParseError::InvalidField(part.to_string()));
389 }
390 ids = Some(parsed);
391 }
392 }
393 _ => return Err(ParseError::InvalidField(part.to_string())),
394 }
395 }
396 Ok(Self::RequestList {
397 ids: ids.ok_or(ParseError::MissingField("ids"))?,
398 })
399 }
400 Some("snapshot") => {
401 let mut id = None;
402 let mut state = None;
403 let mut sender = None;
404 let mut app_id = None;
405 let mut parent_window = None;
406
407 for part in parts {
408 let (key, value) = part
409 .split_once('=')
410 .ok_or(ParseError::InvalidField(part.to_string()))?;
411 match key {
412 "id" => id = Some(value.to_string()),
413 "state" => state = Some(value.to_string()),
414 "sender" => sender = Some(value.to_string()),
415 "app_id" => {
416 if value != "-" {
417 app_id = Some(value.to_string());
418 }
419 }
420 "parent" => {
421 if value != "-" {
422 parent_window = Some(value.to_string());
423 }
424 }
425 _ => return Err(ParseError::InvalidField(part.to_string())),
426 }
427 }
428
429 Ok(Self::RequestSnapshot {
430 id: id.ok_or(ParseError::MissingField("id"))?,
431 state: state.ok_or(ParseError::MissingField("state"))?,
432 sender: sender.ok_or(ParseError::MissingField("sender"))?,
433 app_id,
434 parent_window,
435 })
436 }
437 Some("error") => match parts.next() {
438 Some(first_field) => {
439 let mut code = None;
440 let mut reason = None;
441 let mut fields = vec![first_field];
442 fields.extend(parts);
443
444 for field in fields {
445 let (key, value) = field
446 .split_once('=')
447 .ok_or(ParseError::InvalidField(field.to_string()))?;
448 match key {
449 "code" => {
450 code =
451 Some(value.parse::<u32>().map_err(|_| {
452 ParseError::InvalidField(field.to_string())
453 })?);
454 }
455 "reason" => reason = Some(value.to_string()),
456 _ => return Err(ParseError::InvalidField(field.to_string())),
457 }
458 }
459 Ok(Self::Error {
460 code: code.ok_or(ParseError::MissingField("code"))?,
461 reason: reason.ok_or(ParseError::MissingField("reason"))?,
462 })
463 }
464 None => Err(ParseError::MissingField("reason")),
465 },
466 Some(other) => Err(ParseError::UnknownToken(other.to_string())),
467 None => Err(ParseError::Empty),
468 }
469 }
470 }
471
/// Reason a response line could not be parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The line contained no tokens.
    Empty,
    /// A required field was absent; carries the field's wire key.
    MissingField(&'static str),
    /// A field was malformed or not allowed; carries the offending token.
    InvalidField(String),
    /// A verb or sub-verb was not recognised; carries the offending token.
    UnknownToken(String),
}

impl fmt::Display for ParseError {
    /// Writes `empty input`, or `<label>: <detail>` for the other variants.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, detail): (&str, &str) = match self {
            Self::Empty => return f.write_str("empty input"),
            Self::MissingField(name) => ("missing field", *name),
            Self::InvalidField(raw) => ("invalid field", raw.as_str()),
            Self::UnknownToken(token) => ("unknown token", token.as_str()),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for ParseError {}
492
/// Collects `key=value` tokens into a map.
///
/// Each token is split at its first `=`, so a value may itself contain `=`.
/// Returns `None` as soon as a token has no `=` or repeats an earlier key.
/// An empty iterator yields an empty map.
fn parse_fields<'a, I>(mut parts: I) -> Option<std::collections::HashMap<String, String>>
where
    I: Iterator<Item = &'a str>,
{
    parts.try_fold(std::collections::HashMap::new(), |mut seen, token| {
        let (name, value) = token.split_once('=')?;
        // `insert` returns the previous value when the key was already present,
        // which is exactly the duplicate case that must be rejected.
        match seen.insert(name.to_owned(), value.to_owned()) {
            None => Some(seen),
            Some(_) => None,
        }
    })
}
507
/// Returns `true` when every key in `fields` appears in `allowed`.
///
/// An empty map is always accepted; `allowed` may list keys that are absent.
fn fields_only(fields: &std::collections::HashMap<String, String>, allowed: &[&str]) -> bool {
    for name in fields.keys() {
        if !allowed.contains(&name.as_str()) {
            return false;
        }
    }
    true
}
513
#[cfg(test)]
mod tests {
    use super::{
        ControlRequest, ControlResponse, HealthStatus, PROTOCOL_VERSION, RequestTransitionTarget,
        StatusResponse,
    };

    /// Asserts that `response` is unchanged by `to_line` followed by `parse_line`.
    fn assert_response_roundtrip(response: ControlResponse) {
        let reparsed =
            ControlResponse::parse_line(&response.to_line()).expect("response should parse");
        assert_eq!(reparsed, response);
    }

    /// Every request variant must survive `as_line` followed by `parse_line`.
    #[test]
    fn request_parse_roundtrip() {
        let requests = vec![
            ControlRequest::Status,
            ControlRequest::Stop,
            ControlRequest::ListRequests,
            ControlRequest::InspectRequest {
                id: "req-1".to_string(),
            },
            ControlRequest::BeginRequest {
                id: "req-1".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                parent_window: Some("x11:0x2a".to_string()),
            },
            ControlRequest::TransitionRequest {
                id: "req-1".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                target: RequestTransitionTarget::Cancelled,
            },
        ];

        for request in requests {
            let reparsed = ControlRequest::parse_line(&request.as_line());
            assert_eq!(reparsed, Some(request));
        }
    }

    /// A status response with non-zero counters must round-trip.
    #[test]
    fn response_status_roundtrip() {
        assert_response_roundtrip(ControlResponse::Status(StatusResponse {
            protocol_version: PROTOCOL_VERSION,
            health: HealthStatus::Healthy,
            in_flight_requests: 7,
            total_requests: 10,
            terminal_requests: 3,
        }));
    }

    /// Every non-status response variant must round-trip, including the empty list.
    #[test]
    fn response_ack_roundtrip() {
        let responses = vec![
            ControlResponse::AckStopping,
            ControlResponse::RequestList {
                ids: vec!["req-1".to_string(), "req-2".to_string()],
            },
            ControlResponse::RequestList { ids: Vec::new() },
            ControlResponse::AckRequest {
                id: "req-1".to_string(),
                state: "pending".to_string(),
            },
            ControlResponse::RequestSnapshot {
                id: "req-1".to_string(),
                state: "awaiting_user".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                parent_window: Some("x11:0x2a".to_string()),
            },
            ControlResponse::Error {
                code: 2,
                reason: "invalid_request".to_string(),
            },
        ];

        for response in responses {
            assert_response_roundtrip(response);
        }
    }

    /// `StatusResponse::healthy` reports the current protocol version and zeroed counters.
    #[test]
    fn healthy_response_uses_protocol_version() {
        let StatusResponse {
            protocol_version,
            health,
            in_flight_requests,
            total_requests,
            terminal_requests,
        } = StatusResponse::healthy();

        assert_eq!(protocol_version, PROTOCOL_VERSION);
        assert_eq!(health, HealthStatus::Healthy);
        assert_eq!(in_flight_requests, 0);
        assert_eq!(total_requests, 0);
        assert_eq!(terminal_requests, 0);
    }

    /// A non-numeric `protocol=` value must be rejected.
    #[test]
    fn malformed_status_is_rejected() {
        let line = "status protocol=one health=healthy in_flight=0 total=0 terminal=0";
        assert!(ControlResponse::parse_line(line).is_err());
    }

    /// A repeated key in a request line must be rejected.
    #[test]
    fn request_parse_rejects_duplicate_fields() {
        assert_eq!(ControlRequest::parse_line("inspect id=req-1 id=req-2"), None);
    }

    /// A key that the verb does not allow must be rejected.
    #[test]
    fn request_parse_rejects_unknown_fields() {
        assert_eq!(ControlRequest::parse_line("inspect id=req-1 bogus=1"), None);
    }
}
623