Rust · 25237 bytes Raw Blame History
1 use std::fmt;
2
/// Version of the line protocol; `StatusResponse::healthy` reports it and it is written as the
/// `protocol=` field of a `status` response line.
pub const PROTOCOL_VERSION: u16 = 1;
/// Default runtime subdirectory name.
/// NOTE(review): not used by the code shown here — presumably joined onto the user's runtime
/// directory by the caller; confirm at the use site.
pub const DEFAULT_RUNTIME_SUBDIR: &str = "garwarp";
/// Default control socket file name.
/// NOTE(review): not used by the code shown here — presumably created inside
/// `DEFAULT_RUNTIME_SUBDIR`; confirm at the use site.
pub const DEFAULT_CONTROL_SOCKET: &str = "control.sock";
6
/// Coarse health state, carried in the `health=` field of a `status` response line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    Starting,
    Healthy,
    Degraded,
    Stopping,
}

impl HealthStatus {
    /// Every variant, in declaration order; drives the reverse lookup in `parse`.
    const ALL: [Self; 4] = [Self::Starting, Self::Healthy, Self::Degraded, Self::Stopping];

    /// Lowercase wire token for this state.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Starting => "starting",
            Self::Healthy => "healthy",
            Self::Degraded => "degraded",
            Self::Stopping => "stopping",
        }
    }

    /// Inverse of `as_str`: exact, case-sensitive token match; `None` for anything else.
    fn parse(input: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|status| status.as_str() == input)
    }
}
36
/// A control command, encoded as one whitespace-separated line by `ControlRequest::as_line`
/// and decoded by `ControlRequest::parse_line`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ControlRequest {
    /// Encoded as `status`.
    Status,
    /// Encoded as `stop`.
    Stop,
    /// Encoded as `list`.
    ListRequests,
    /// Encoded as `inspect id=<id>`.
    InspectRequest {
        id: String,
    },
    /// Encoded as `begin id=<id> sender=<sender> [app_id=<app_id>] [parent=<parent_window>]`;
    /// optional fields are omitted entirely when `None`.
    BeginRequest {
        id: String,
        sender: String,
        app_id: Option<String>,
        // Written and read under the `parent` key.
        parent_window: Option<String>,
    },
    /// Encoded as `transition id=<id> sender=<sender> state=<target> [app_id=<app_id>]`.
    TransitionRequest {
        id: String,
        sender: String,
        app_id: Option<String>,
        // Written and read under the `state` key using `RequestTransitionTarget::as_str`.
        target: RequestTransitionTarget,
    },
}
58
/// Target state named by the `state=` field of a `transition` request line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestTransitionTarget {
    AwaitingUser,
    Fulfilled,
    Cancelled,
    Failed,
}

impl RequestTransitionTarget {
    /// Every variant, in declaration order; drives the reverse lookup in `parse`.
    const ALL: [Self; 4] = [
        Self::AwaitingUser,
        Self::Fulfilled,
        Self::Cancelled,
        Self::Failed,
    ];

    /// Snake-case wire token for this target.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::AwaitingUser => "awaiting_user",
            Self::Fulfilled => "fulfilled",
            Self::Cancelled => "cancelled",
            Self::Failed => "failed",
        }
    }

    /// Inverse of `as_str`: exact, case-sensitive token match; `None` for anything else.
    fn parse(input: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|target| target.as_str() == input)
    }
}
88
89 impl ControlRequest {
90 #[must_use]
91 pub fn as_line(&self) -> String {
92 match self {
93 Self::Status => "status".to_string(),
94 Self::Stop => "stop".to_string(),
95 Self::ListRequests => "list".to_string(),
96 Self::InspectRequest { id } => format!("inspect id={id}"),
97 Self::BeginRequest {
98 id,
99 sender,
100 app_id,
101 parent_window,
102 } => {
103 let mut parts = vec![
104 "begin".to_string(),
105 format!("id={id}"),
106 format!("sender={sender}"),
107 ];
108 if let Some(app_id) = app_id {
109 parts.push(format!("app_id={app_id}"));
110 }
111 if let Some(parent_window) = parent_window {
112 parts.push(format!("parent={parent_window}"));
113 }
114 parts.join(" ")
115 }
116 Self::TransitionRequest {
117 id,
118 sender,
119 app_id,
120 target,
121 } => {
122 let mut parts = vec![
123 "transition".to_string(),
124 format!("id={id}"),
125 format!("sender={sender}"),
126 format!("state={}", target.as_str()),
127 ];
128 if let Some(app_id) = app_id {
129 parts.push(format!("app_id={app_id}"));
130 }
131 parts.join(" ")
132 }
133 }
134 }
135
136 #[must_use]
137 pub fn parse_line(input: &str) -> Option<Self> {
138 let trimmed = input.trim();
139 if trimmed == "status" {
140 return Some(Self::Status);
141 }
142 if trimmed == "stop" {
143 return Some(Self::Stop);
144 }
145 if trimmed == "list" {
146 return Some(Self::ListRequests);
147 }
148
149 let mut parts = trimmed.split_whitespace();
150 match parts.next() {
151 Some("inspect") => {
152 let fields = parse_fields(parts)?;
153 if !fields_only(&fields, &["id"]) {
154 return None;
155 }
156 let id = fields.get("id")?.clone();
157 Some(Self::InspectRequest { id })
158 }
159 Some("begin") => {
160 let fields = parse_fields(parts)?;
161 if !fields_only(&fields, &["id", "sender", "app_id", "parent"]) {
162 return None;
163 }
164 let id = fields.get("id")?.clone();
165 let sender = fields.get("sender")?.clone();
166 let app_id = fields.get("app_id").cloned();
167 let parent_window = fields.get("parent").cloned();
168 Some(Self::BeginRequest {
169 id,
170 sender,
171 app_id,
172 parent_window,
173 })
174 }
175 Some("transition") => {
176 let fields = parse_fields(parts)?;
177 if !fields_only(&fields, &["id", "sender", "state", "app_id"]) {
178 return None;
179 }
180 let id = fields.get("id")?.clone();
181 let sender = fields.get("sender")?.clone();
182 let app_id = fields.get("app_id").cloned();
183 let target = RequestTransitionTarget::parse(fields.get("state")?)?;
184 Some(Self::TransitionRequest {
185 id,
186 sender,
187 app_id,
188 target,
189 })
190 }
191 _ => None,
192 }
193 }
194 }
195
196 #[derive(Debug, Clone, PartialEq, Eq)]
197 pub struct StatusResponse {
198 pub protocol_version: u16,
199 pub health: HealthStatus,
200 pub in_flight_requests: usize,
201 pub total_requests: usize,
202 pub terminal_requests: usize,
203 }
204
205 impl StatusResponse {
206 #[must_use]
207 pub fn healthy() -> Self {
208 Self {
209 protocol_version: PROTOCOL_VERSION,
210 health: HealthStatus::Healthy,
211 in_flight_requests: 0,
212 total_requests: 0,
213 terminal_requests: 0,
214 }
215 }
216 }
217
/// A response, encoded by `ControlResponse::to_line` as one newline-terminated line and
/// decoded by `ControlResponse::parse_line`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ControlResponse {
    /// `status protocol=<u16> health=<token> in_flight=<n> total=<n> terminal=<n>`.
    Status(StatusResponse),
    /// `ack stopping`.
    AckStopping,
    /// `list ids=<id>,<id>,...`, or `list ids=-` when `ids` is empty.
    RequestList {
        // Ids must be non-empty and comma-free to round-trip; a lone `-` reads back as empty.
        ids: Vec<String>,
    },
    /// `ack request id=<id> state=<state>`.
    AckRequest {
        id: String,
        // Free-form token; `parse_line` stores it without validating it.
        state: String,
    },
    /// `snapshot id=<id> state=<state> sender=<sender> app_id=<app_id> parent=<parent_window>`;
    /// `None` optionals are written as `-`.
    RequestSnapshot {
        id: String,
        state: String,
        sender: String,
        app_id: Option<String>,
        // Written and read under the `parent` key.
        parent_window: Option<String>,
    },
    /// `error code=<u32> reason=<reason>`; `reason` must be a single whitespace-free token
    /// to round-trip.
    Error {
        code: u32,
        reason: String,
    },
}
241
242 impl ControlResponse {
243 #[must_use]
244 pub fn to_line(&self) -> String {
245 match self {
246 Self::Status(status) => format!(
247 "status protocol={} health={} in_flight={} total={} terminal={}\n",
248 status.protocol_version,
249 status.health.as_str(),
250 status.in_flight_requests,
251 status.total_requests,
252 status.terminal_requests
253 ),
254 Self::AckStopping => "ack stopping\n".to_string(),
255 Self::RequestList { ids } => {
256 let ids = if ids.is_empty() {
257 "-".to_string()
258 } else {
259 ids.join(",")
260 };
261 format!("list ids={ids}\n")
262 }
263 Self::AckRequest { id, state } => {
264 format!("ack request id={} state={}\n", id, state)
265 }
266 Self::RequestSnapshot {
267 id,
268 state,
269 sender,
270 app_id,
271 parent_window,
272 } => {
273 let app_id = app_id.as_deref().unwrap_or("-");
274 let parent_window = parent_window.as_deref().unwrap_or("-");
275 format!(
276 "snapshot id={} state={} sender={} app_id={} parent={}\n",
277 id, state, sender, app_id, parent_window
278 )
279 }
280 Self::Error { code, reason } => format!("error code={} reason={}\n", code, reason),
281 }
282 }
283
284 pub fn parse_line(input: &str) -> Result<Self, ParseError> {
285 let trimmed = input.trim();
286 let mut parts = trimmed.split_whitespace();
287
288 match parts.next() {
289 Some("status") => {
290 let mut protocol_version = None;
291 let mut health = None;
292 let mut in_flight_requests = None;
293 let mut total_requests = None;
294 let mut terminal_requests = None;
295
296 for part in parts {
297 let (key, value) = part
298 .split_once('=')
299 .ok_or(ParseError::InvalidField(part.to_string()))?;
300 match key {
301 "protocol" => {
302 let parsed = value
303 .parse::<u16>()
304 .map_err(|_| ParseError::InvalidField(part.to_string()))?;
305 if protocol_version.replace(parsed).is_some() {
306 return Err(ParseError::InvalidField(part.to_string()));
307 }
308 }
309 "health" => {
310 let parsed = HealthStatus::parse(value)
311 .ok_or(ParseError::InvalidField(part.to_string()))?;
312 if health.replace(parsed).is_some() {
313 return Err(ParseError::InvalidField(part.to_string()));
314 }
315 }
316 "in_flight" => {
317 let parsed = value
318 .parse::<usize>()
319 .map_err(|_| ParseError::InvalidField(part.to_string()))?;
320 if in_flight_requests.replace(parsed).is_some() {
321 return Err(ParseError::InvalidField(part.to_string()));
322 }
323 }
324 "total" => {
325 let parsed = value
326 .parse::<usize>()
327 .map_err(|_| ParseError::InvalidField(part.to_string()))?;
328 if total_requests.replace(parsed).is_some() {
329 return Err(ParseError::InvalidField(part.to_string()));
330 }
331 }
332 "terminal" => {
333 let parsed = value
334 .parse::<usize>()
335 .map_err(|_| ParseError::InvalidField(part.to_string()))?;
336 if terminal_requests.replace(parsed).is_some() {
337 return Err(ParseError::InvalidField(part.to_string()));
338 }
339 }
340 _ => return Err(ParseError::InvalidField(part.to_string())),
341 }
342 }
343
344 let status = StatusResponse {
345 protocol_version: protocol_version
346 .ok_or(ParseError::MissingField("protocol"))?,
347 health: health.ok_or(ParseError::MissingField("health"))?,
348 in_flight_requests: in_flight_requests
349 .ok_or(ParseError::MissingField("in_flight"))?,
350 total_requests: total_requests.ok_or(ParseError::MissingField("total"))?,
351 terminal_requests: terminal_requests
352 .ok_or(ParseError::MissingField("terminal"))?,
353 };
354 Ok(Self::Status(status))
355 }
356 Some("ack") => match parts.next() {
357 Some("stopping") => Ok(Self::AckStopping),
358 Some("request") => {
359 let mut id = None;
360 let mut state = None;
361 for part in parts {
362 let (key, value) = part
363 .split_once('=')
364 .ok_or(ParseError::InvalidField(part.to_string()))?;
365 match key {
366 "id" => {
367 if id.replace(value.to_string()).is_some() {
368 return Err(ParseError::InvalidField(part.to_string()));
369 }
370 }
371 "state" => {
372 if state.replace(value.to_string()).is_some() {
373 return Err(ParseError::InvalidField(part.to_string()));
374 }
375 }
376 _ => return Err(ParseError::InvalidField(part.to_string())),
377 }
378 }
379 Ok(Self::AckRequest {
380 id: id.ok_or(ParseError::MissingField("id"))?,
381 state: state.ok_or(ParseError::MissingField("state"))?,
382 })
383 }
384 Some(other) => Err(ParseError::UnknownToken(other.to_string())),
385 None => Err(ParseError::MissingField("ack")),
386 },
387 Some("list") => {
388 let mut ids = None;
389 for part in parts {
390 let (key, value) = part
391 .split_once('=')
392 .ok_or(ParseError::InvalidField(part.to_string()))?;
393 match key {
394 "ids" => {
395 let parsed = if value == "-" {
396 Vec::new()
397 } else {
398 let parsed =
399 value.split(',').map(str::to_string).collect::<Vec<_>>();
400 if parsed.iter().any(|id| id.is_empty()) {
401 return Err(ParseError::InvalidField(part.to_string()));
402 }
403 parsed
404 };
405 if ids.replace(parsed).is_some() {
406 return Err(ParseError::InvalidField(part.to_string()));
407 }
408 }
409 _ => return Err(ParseError::InvalidField(part.to_string())),
410 }
411 }
412 Ok(Self::RequestList {
413 ids: ids.ok_or(ParseError::MissingField("ids"))?,
414 })
415 }
416 Some("snapshot") => {
417 let mut id = None;
418 let mut state = None;
419 let mut sender = None;
420 let mut app_id = None;
421 let mut parent_window = None;
422 let mut saw_app_id = false;
423 let mut saw_parent_window = false;
424
425 for part in parts {
426 let (key, value) = part
427 .split_once('=')
428 .ok_or(ParseError::InvalidField(part.to_string()))?;
429 match key {
430 "id" => {
431 if id.replace(value.to_string()).is_some() {
432 return Err(ParseError::InvalidField(part.to_string()));
433 }
434 }
435 "state" => {
436 if state.replace(value.to_string()).is_some() {
437 return Err(ParseError::InvalidField(part.to_string()));
438 }
439 }
440 "sender" => {
441 if sender.replace(value.to_string()).is_some() {
442 return Err(ParseError::InvalidField(part.to_string()));
443 }
444 }
445 "app_id" => {
446 if saw_app_id {
447 return Err(ParseError::InvalidField(part.to_string()));
448 }
449 saw_app_id = true;
450 app_id = if value == "-" {
451 None
452 } else {
453 Some(value.to_string())
454 };
455 }
456 "parent" => {
457 if saw_parent_window {
458 return Err(ParseError::InvalidField(part.to_string()));
459 }
460 saw_parent_window = true;
461 parent_window = if value == "-" {
462 None
463 } else {
464 Some(value.to_string())
465 };
466 }
467 _ => return Err(ParseError::InvalidField(part.to_string())),
468 }
469 }
470
471 Ok(Self::RequestSnapshot {
472 id: id.ok_or(ParseError::MissingField("id"))?,
473 state: state.ok_or(ParseError::MissingField("state"))?,
474 sender: sender.ok_or(ParseError::MissingField("sender"))?,
475 app_id,
476 parent_window,
477 })
478 }
479 Some("error") => match parts.next() {
480 Some(first_field) => {
481 let mut code = None;
482 let mut reason = None;
483 let mut fields = vec![first_field];
484 fields.extend(parts);
485
486 for field in fields {
487 let (key, value) = field
488 .split_once('=')
489 .ok_or(ParseError::InvalidField(field.to_string()))?;
490 match key {
491 "code" => {
492 let parsed = value
493 .parse::<u32>()
494 .map_err(|_| ParseError::InvalidField(field.to_string()))?;
495 if code.replace(parsed).is_some() {
496 return Err(ParseError::InvalidField(field.to_string()));
497 }
498 }
499 "reason" => {
500 if reason.replace(value.to_string()).is_some() {
501 return Err(ParseError::InvalidField(field.to_string()));
502 }
503 }
504 _ => return Err(ParseError::InvalidField(field.to_string())),
505 }
506 }
507 Ok(Self::Error {
508 code: code.ok_or(ParseError::MissingField("code"))?,
509 reason: reason.ok_or(ParseError::MissingField("reason"))?,
510 })
511 }
512 None => Err(ParseError::MissingField("reason")),
513 },
514 Some(other) => Err(ParseError::UnknownToken(other.to_string())),
515 None => Err(ParseError::Empty),
516 }
517 }
518 }
519
/// Reason a response line was rejected by `ControlResponse::parse_line`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The line contained no tokens.
    Empty,
    /// Something required was absent; the payload names it (a field key, or `ack` when the
    /// `ack` sub-kind is missing).
    MissingField(&'static str),
    /// A token was malformed, unknown, repeated, or had an unparsable value; carries the token.
    InvalidField(String),
    /// The verb or `ack` sub-kind was not recognised; carries the token.
    UnknownToken(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, detail) = match self {
            Self::Empty => return f.write_str("empty input"),
            Self::MissingField(field) => ("missing field", *field),
            Self::InvalidField(field) => ("invalid field", field.as_str()),
            Self::UnknownToken(token) => ("unknown token", token.as_str()),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for ParseError {}
540
/// Collects `key=value` tokens into a map, splitting each token at its first `=`.
///
/// Returns `None` if any token lacks `=` or if a key appears more than once (repeats are
/// rejected rather than last-wins).
fn parse_fields<'a, I>(mut parts: I) -> Option<std::collections::HashMap<String, String>>
where
    I: Iterator<Item = &'a str>,
{
    parts.try_fold(std::collections::HashMap::new(), |mut fields, part| {
        let (key, value) = part.split_once('=')?;
        if fields.insert(key.to_string(), value.to_string()).is_some() {
            return None;
        }
        Some(fields)
    })
}
555
/// True when every key in `fields` is listed in `allowed` (vacuously true for an empty map).
fn fields_only(fields: &std::collections::HashMap<String, String>, allowed: &[&str]) -> bool {
    fields.keys().all(|key| allowed.contains(&key.as_str()))
}
561
// Round-trip and strictness tests for the control line protocol.
#[cfg(test)]
mod tests {
    use super::{
        ControlRequest, ControlResponse, HealthStatus, PROTOCOL_VERSION, RequestTransitionTarget,
        StatusResponse,
    };

    // Every request variant must survive `as_line` -> `parse_line` unchanged, including the
    // optional `app_id` / `parent` fields.
    #[test]
    fn request_parse_roundtrip() {
        for request in [
            ControlRequest::Status,
            ControlRequest::Stop,
            ControlRequest::ListRequests,
            ControlRequest::InspectRequest {
                id: "req-1".to_string(),
            },
            ControlRequest::BeginRequest {
                id: "req-1".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                parent_window: Some("x11:0x2a".to_string()),
            },
            ControlRequest::TransitionRequest {
                id: "req-1".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                target: RequestTransitionTarget::Cancelled,
            },
        ] {
            let line = request.as_line();
            let parsed = ControlRequest::parse_line(&line);
            assert_eq!(parsed, Some(request));
        }
    }

    // A status response with non-zero counters survives `to_line` -> `parse_line`; the
    // trailing newline emitted by `to_line` must not break parsing.
    #[test]
    fn response_status_roundtrip() {
        let response = ControlResponse::Status(StatusResponse {
            protocol_version: PROTOCOL_VERSION,
            health: HealthStatus::Healthy,
            in_flight_requests: 7,
            total_requests: 10,
            terminal_requests: 3,
        });
        let line = response.to_line();
        let parsed = ControlResponse::parse_line(&line).expect("response should parse");
        assert_eq!(parsed, response);
    }

    // The remaining response variants round-trip, including the empty id list (encoded as `-`).
    #[test]
    fn response_ack_roundtrip() {
        for response in [
            ControlResponse::AckStopping,
            ControlResponse::RequestList {
                ids: vec!["req-1".to_string(), "req-2".to_string()],
            },
            ControlResponse::RequestList { ids: Vec::new() },
            ControlResponse::AckRequest {
                id: "req-1".to_string(),
                state: "pending".to_string(),
            },
            ControlResponse::RequestSnapshot {
                id: "req-1".to_string(),
                state: "awaiting_user".to_string(),
                sender: ":1.2".to_string(),
                app_id: Some("org.test.App".to_string()),
                parent_window: Some("x11:0x2a".to_string()),
            },
            ControlResponse::Error {
                code: 2,
                reason: "invalid_request".to_string(),
            },
        ] {
            let line = response.to_line();
            let parsed = ControlResponse::parse_line(&line).expect("response should parse");
            assert_eq!(parsed, response);
        }
    }

    // `StatusResponse::healthy` pins the protocol version and zeroes every counter.
    #[test]
    fn healthy_response_uses_protocol_version() {
        let response = StatusResponse::healthy();
        assert_eq!(response.protocol_version, PROTOCOL_VERSION);
        assert_eq!(response.health, HealthStatus::Healthy);
        assert_eq!(response.in_flight_requests, 0);
        assert_eq!(response.total_requests, 0);
        assert_eq!(response.terminal_requests, 0);
    }

    // A non-numeric `protocol=` value is an error, not a default.
    #[test]
    fn malformed_status_is_rejected() {
        let parsed = ControlResponse::parse_line(
            "status protocol=one health=healthy in_flight=0 total=0 terminal=0",
        );
        assert!(parsed.is_err());
    }

    // Repeated keys in a request line reject the whole line rather than last-wins.
    #[test]
    fn request_parse_rejects_duplicate_fields() {
        let parsed = ControlRequest::parse_line("inspect id=req-1 id=req-2");
        assert_eq!(parsed, None);
    }

    // Keys outside a verb's allowed set reject the whole line.
    #[test]
    fn request_parse_rejects_unknown_fields() {
        let parsed = ControlRequest::parse_line("inspect id=req-1 bogus=1");
        assert_eq!(parsed, None);
    }

    // Every keyed response form rejects a repeated key, including a repeated optional key
    // whose first occurrence was the `-` placeholder.
    #[test]
    fn response_parse_rejects_duplicate_fields() {
        assert!(
            ControlResponse::parse_line(
                "status protocol=1 protocol=2 health=healthy in_flight=0 total=0 terminal=0",
            )
            .is_err()
        );
        assert!(
            ControlResponse::parse_line("ack request id=req-1 id=req-2 state=pending").is_err()
        );
        assert!(ControlResponse::parse_line("list ids=req-1 ids=req-2").is_err());
        assert!(
            ControlResponse::parse_line(
                "snapshot id=req-1 state=pending sender=:1.2 app_id=- app_id=org.test.App parent=-",
            )
            .is_err()
        );
        assert!(ControlResponse::parse_line("error code=2 reason=bad reason=worse").is_err());
    }
}
692