Rust · 19438 bytes Raw Blame History
1 //! Expression parser using recursive descent
2 //!
3 //! Grammar:
4 //! ```text
5 //! expr = equation
6 //! equation = additive ("=" additive)?
7 //! additive = multiplicative (("+"|"-") multiplicative)*
8 //! multiplicative = power (("*"|"/") power)*
9 //! power = unary ("^" power)?
10 //! unary = "-" unary | postfix
11 //! postfix = primary ("!")*
12 //! primary = NUMBER | SYMBOL | "(" expr ")" | function | "[" matrix "]"
13 //! function = SYMBOL "(" args? ")"
14 //! args = expr ("," expr)*
15 //! matrix = row (";" row)*
16 //! row = expr ("," expr)*
17 //! ```
18
19 use crate::error::{CasError, Result};
20 use crate::expr::{Expr, Symbol};
21
/// Token types
///
/// One variant per lexical element the lexer can produce. Punctuation
/// variants carry no payload; only numbers and symbols keep their text/value.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    /// Numeric literal, stored as an `f64`.
    Number(f64),
    /// Identifier: a run of alphanumeric characters and underscores.
    Symbol(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `!`
    Bang,
    /// `=`
    Eq,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `[`
    LBracket,
    /// `]`
    RBracket,
    /// `,`
    Comma,
    /// `;`
    Semicolon,
    /// End of input.
    Eof,
}
42
/// Tokenizer for mathematical expressions
pub struct Lexer<'a> {
    /// Source text being tokenized.
    input: &'a str,
    /// Byte offset into `input` of the next unread character.
    pos: usize,
}
48
49 impl<'a> Lexer<'a> {
50 pub fn new(input: &'a str) -> Self {
51 Self { input, pos: 0 }
52 }
53
54 fn peek_char(&self) -> Option<char> {
55 self.input[self.pos..].chars().next()
56 }
57
58 fn advance(&mut self) -> Option<char> {
59 let c = self.peek_char()?;
60 self.pos += c.len_utf8();
61 Some(c)
62 }
63
64 fn skip_whitespace(&mut self) {
65 while let Some(c) = self.peek_char() {
66 if c.is_whitespace() {
67 self.advance();
68 } else {
69 break;
70 }
71 }
72 }
73
74 fn read_number(&mut self) -> Token {
75 let start = self.pos;
76 let mut has_dot = false;
77 let mut has_exp = false;
78
79 while let Some(c) = self.peek_char() {
80 if c.is_ascii_digit() {
81 self.advance();
82 } else if c == '.' && !has_dot && !has_exp {
83 has_dot = true;
84 self.advance();
85 } else if (c == 'e' || c == 'E') && !has_exp {
86 has_exp = true;
87 self.advance();
88 // Handle optional sign after exponent
89 if let Some(sign) = self.peek_char() {
90 if sign == '+' || sign == '-' {
91 self.advance();
92 }
93 }
94 } else {
95 break;
96 }
97 }
98
99 let num_str = &self.input[start..self.pos];
100 let value: f64 = num_str.parse().unwrap_or(f64::NAN);
101 Token::Number(value)
102 }
103
104 fn read_symbol(&mut self) -> Token {
105 let start = self.pos;
106
107 while let Some(c) = self.peek_char() {
108 if c.is_alphanumeric() || c == '_' {
109 self.advance();
110 } else {
111 break;
112 }
113 }
114
115 let name = self.input[start..self.pos].to_string();
116 Token::Symbol(name)
117 }
118
119 pub fn next_token(&mut self) -> Result<Token> {
120 self.skip_whitespace();
121
122 let Some(c) = self.peek_char() else {
123 return Ok(Token::Eof);
124 };
125
126 let token = match c {
127 '+' => {
128 self.advance();
129 Token::Plus
130 }
131 '-' => {
132 self.advance();
133 Token::Minus
134 }
135 '*' => {
136 self.advance();
137 Token::Star
138 }
139 '/' => {
140 self.advance();
141 Token::Slash
142 }
143 '^' => {
144 self.advance();
145 Token::Caret
146 }
147 '!' => {
148 self.advance();
149 Token::Bang
150 }
151 '=' => {
152 self.advance();
153 Token::Eq
154 }
155 '(' => {
156 self.advance();
157 Token::LParen
158 }
159 ')' => {
160 self.advance();
161 Token::RParen
162 }
163 '[' => {
164 self.advance();
165 Token::LBracket
166 }
167 ']' => {
168 self.advance();
169 Token::RBracket
170 }
171 ',' => {
172 self.advance();
173 Token::Comma
174 }
175 ';' => {
176 self.advance();
177 Token::Semicolon
178 }
179 _ if c.is_ascii_digit() || c == '.' => self.read_number(),
180 _ if c.is_alphabetic() || c == '_' => self.read_symbol(),
181 _ => {
182 return Err(CasError::Parse {
183 position: self.pos,
184 message: format!("unexpected character: '{c}'"),
185 });
186 }
187 };
188
189 Ok(token)
190 }
191
192 pub fn position(&self) -> usize {
193 self.pos
194 }
195 }
196
/// Recursive descent parser
pub struct Parser<'a> {
    /// Token source for the input being parsed.
    lexer: Lexer<'a>,
    /// One-token lookahead: the next token that has not been consumed yet.
    current: Token,
}
202
203 impl<'a> Parser<'a> {
204 pub fn new(input: &'a str) -> Result<Self> {
205 let mut lexer = Lexer::new(input);
206 let current = lexer.next_token()?;
207 Ok(Self { lexer, current })
208 }
209
210 fn advance(&mut self) -> Result<Token> {
211 let prev = std::mem::replace(&mut self.current, self.lexer.next_token()?);
212 Ok(prev)
213 }
214
215 fn expect(&mut self, expected: Token) -> Result<()> {
216 if self.current == expected {
217 self.advance()?;
218 Ok(())
219 } else {
220 Err(CasError::Parse {
221 position: self.lexer.position(),
222 message: format!("expected {expected:?}, found {:?}", self.current),
223 })
224 }
225 }
226
227 /// Parse a complete expression
228 pub fn parse(&mut self) -> Result<Expr> {
229 let expr = self.parse_equation()?;
230 if self.current != Token::Eof {
231 return Err(CasError::Parse {
232 position: self.lexer.position(),
233 message: format!("unexpected token: {:?}", self.current),
234 });
235 }
236 Ok(expr)
237 }
238
239 fn parse_equation(&mut self) -> Result<Expr> {
240 let lhs = self.parse_additive()?;
241 if self.current == Token::Eq {
242 self.advance()?;
243 let rhs = self.parse_additive()?;
244 Ok(Expr::Equation(Box::new(lhs), Box::new(rhs)))
245 } else {
246 Ok(lhs)
247 }
248 }
249
250 fn parse_additive(&mut self) -> Result<Expr> {
251 let mut terms = vec![self.parse_multiplicative()?];
252
253 loop {
254 match &self.current {
255 Token::Plus => {
256 self.advance()?;
257 terms.push(self.parse_multiplicative()?);
258 }
259 Token::Minus => {
260 self.advance()?;
261 terms.push(Expr::neg(self.parse_multiplicative()?));
262 }
263 _ => break,
264 }
265 }
266
267 Ok(Expr::add(terms))
268 }
269
270 fn parse_multiplicative(&mut self) -> Result<Expr> {
271 let mut factors = vec![self.parse_power()?];
272
273 loop {
274 match &self.current {
275 Token::Star => {
276 self.advance()?;
277 factors.push(self.parse_power()?);
278 }
279 Token::Slash => {
280 self.advance()?;
281 let divisor = self.parse_power()?;
282 factors.push(Expr::pow(divisor, Expr::integer(-1)));
283 }
284 // Implicit multiplication: 2x, xy, x(y+1)
285 Token::Symbol(_) | Token::LParen | Token::Number(_) => {
286 // Check if this could be implicit multiplication
287 if let Some(last) = factors.last() {
288 if matches!(
289 last,
290 Expr::Symbol(_) | Expr::Integer(_) | Expr::Float(_) | Expr::Func(_, _)
291 ) {
292 factors.push(self.parse_power()?);
293 continue;
294 }
295 }
296 break;
297 }
298 _ => break,
299 }
300 }
301
302 Ok(Expr::mul(factors))
303 }
304
305 fn parse_power(&mut self) -> Result<Expr> {
306 let base = self.parse_unary()?;
307 if self.current == Token::Caret {
308 self.advance()?;
309 let exp = self.parse_power()?; // Right associative
310 Ok(Expr::pow(base, exp))
311 } else {
312 Ok(base)
313 }
314 }
315
316 fn parse_unary(&mut self) -> Result<Expr> {
317 if self.current == Token::Minus {
318 self.advance()?;
319 let expr = self.parse_unary()?;
320 Ok(Expr::neg(expr))
321 } else {
322 self.parse_postfix()
323 }
324 }
325
326 fn parse_postfix(&mut self) -> Result<Expr> {
327 let mut expr = self.parse_primary()?;
328
329 while self.current == Token::Bang {
330 self.advance()?;
331 expr = Expr::func("factorial", vec![expr]);
332 }
333
334 Ok(expr)
335 }
336
337 fn parse_primary(&mut self) -> Result<Expr> {
338 match &self.current {
339 Token::Number(n) => {
340 let n = *n;
341 self.advance()?;
342 // Convert to integer if possible
343 if n.fract() == 0.0 && n.abs() < i64::MAX as f64 {
344 Ok(Expr::integer(n as i64))
345 } else {
346 Ok(Expr::float(n))
347 }
348 }
349 Token::Symbol(name) => {
350 let name = name.clone();
351 self.advance()?;
352
353 // Check for function call
354 if self.current == Token::LParen {
355 self.advance()?;
356 let args = self.parse_args()?;
357 self.expect(Token::RParen)?;
358 Ok(self.handle_special_function(&name, args))
359 } else {
360 // Constants
361 match name.as_str() {
362 "pi" | "PI" => Ok(Expr::symbol("pi")),
363 "e" | "E" => Ok(Expr::symbol("e")),
364 "i" => Ok(Expr::Complex(0.0, 1.0)),
365 "inf" | "infinity" => Ok(Expr::Infinity(crate::expr::Sign::Positive)),
366 _ => Ok(Expr::symbol(name)),
367 }
368 }
369 }
370 Token::LParen => {
371 self.advance()?;
372 let expr = self.parse_equation()?;
373 self.expect(Token::RParen)?;
374 Ok(expr)
375 }
376 Token::LBracket => {
377 self.advance()?;
378 let matrix = self.parse_matrix()?;
379 self.expect(Token::RBracket)?;
380 Ok(matrix)
381 }
382 _ => Err(CasError::Parse {
383 position: self.lexer.position(),
384 message: format!("unexpected token: {:?}", self.current),
385 }),
386 }
387 }
388
389 fn parse_args(&mut self) -> Result<Vec<Expr>> {
390 if self.current == Token::RParen {
391 return Ok(vec![]);
392 }
393
394 let mut args = vec![self.parse_equation()?];
395
396 while self.current == Token::Comma {
397 self.advance()?;
398 args.push(self.parse_equation()?);
399 }
400
401 Ok(args)
402 }
403
404 fn parse_matrix(&mut self) -> Result<Expr> {
405 let mut rows = vec![self.parse_row()?];
406
407 while self.current == Token::Semicolon {
408 self.advance()?;
409 rows.push(self.parse_row()?);
410 }
411
412 // Single row = vector, multiple rows = matrix
413 if rows.len() == 1 {
414 Ok(Expr::Vector(rows.into_iter().next().unwrap()))
415 } else {
416 Ok(Expr::Matrix(rows))
417 }
418 }
419
420 fn parse_row(&mut self) -> Result<Vec<Expr>> {
421 let mut elems = vec![self.parse_equation()?];
422
423 while self.current == Token::Comma {
424 self.advance()?;
425 elems.push(self.parse_equation()?);
426 }
427
428 Ok(elems)
429 }
430
431 /// Handle special functions that need custom AST nodes
432 fn handle_special_function(&self, name: &str, mut args: Vec<Expr>) -> Expr {
433 match name {
434 "diff" | "derivative" => {
435 if args.len() >= 2 {
436 let expr = args.remove(0);
437 let var = match args.remove(0) {
438 Expr::Symbol(s) => s,
439 _ => Symbol::new("x"),
440 };
441 let order = if args.is_empty() {
442 1
443 } else {
444 match &args[0] {
445 Expr::Integer(n) => *n as u32,
446 _ => 1,
447 }
448 };
449 Expr::Derivative {
450 expr: Box::new(expr),
451 var,
452 order,
453 }
454 } else if args.len() == 1 {
455 Expr::Derivative {
456 expr: Box::new(args.remove(0)),
457 var: Symbol::new("x"),
458 order: 1,
459 }
460 } else {
461 Expr::func(name, args)
462 }
463 }
464 "integrate" | "int" => {
465 if args.len() >= 2 {
466 let expr = args.remove(0);
467 let var = match args.remove(0) {
468 Expr::Symbol(s) => s,
469 _ => Symbol::new("x"),
470 };
471 let (lower, upper) = if args.len() >= 2 {
472 (
473 Some(Box::new(args.remove(0))),
474 Some(Box::new(args.remove(0))),
475 )
476 } else {
477 (None, None)
478 };
479 Expr::Integral {
480 expr: Box::new(expr),
481 var,
482 lower,
483 upper,
484 }
485 } else {
486 Expr::func(name, args)
487 }
488 }
489 "lim" | "limit" => {
490 if args.len() >= 3 {
491 let expr = args.remove(0);
492 let var = match args.remove(0) {
493 Expr::Symbol(s) => s,
494 _ => Symbol::new("x"),
495 };
496 let point = args.remove(0);
497 Expr::Limit {
498 expr: Box::new(expr),
499 var,
500 point: Box::new(point),
501 direction: None,
502 }
503 } else {
504 Expr::func(name, args)
505 }
506 }
507 "sum" => {
508 if args.len() >= 4 {
509 let expr = args.remove(0);
510 let var = match args.remove(0) {
511 Expr::Symbol(s) => s,
512 _ => Symbol::new("n"),
513 };
514 let lower = args.remove(0);
515 let upper = args.remove(0);
516 Expr::Sum {
517 expr: Box::new(expr),
518 var,
519 lower: Box::new(lower),
520 upper: Box::new(upper),
521 }
522 } else {
523 Expr::func(name, args)
524 }
525 }
526 "product" | "prod" => {
527 if args.len() >= 4 {
528 let expr = args.remove(0);
529 let var = match args.remove(0) {
530 Expr::Symbol(s) => s,
531 _ => Symbol::new("n"),
532 };
533 let lower = args.remove(0);
534 let upper = args.remove(0);
535 Expr::Product {
536 expr: Box::new(expr),
537 var,
538 lower: Box::new(lower),
539 upper: Box::new(upper),
540 }
541 } else {
542 Expr::func(name, args)
543 }
544 }
545 _ => Expr::func(name, args),
546 }
547 }
548 }
549
550 /// Parse a string into an expression
551 pub fn parse(input: &str) -> Result<Expr> {
552 Parser::new(input)?.parse()
553 }
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking if parsing fails.
    fn parsed(src: &str) -> Expr {
        parse(src).unwrap()
    }

    #[test]
    fn test_number() {
        assert_eq!(parsed("42"), Expr::integer(42));
    }

    #[test]
    fn test_float() {
        assert_eq!(parsed("3.14"), Expr::float(3.14));
    }

    #[test]
    fn test_addition() {
        let expected = Expr::add(vec![Expr::integer(1), Expr::integer(2), Expr::integer(3)]);
        assert_eq!(parsed("1 + 2 + 3"), expected);
    }

    #[test]
    fn test_subtraction() {
        let expected = Expr::add(vec![Expr::integer(5), Expr::neg(Expr::integer(3))]);
        assert_eq!(parsed("5 - 3"), expected);
    }

    #[test]
    fn test_multiplication() {
        let expected = Expr::mul(vec![Expr::integer(2), Expr::integer(3)]);
        assert_eq!(parsed("2 * 3"), expected);
    }

    #[test]
    fn test_division() {
        // Division is represented as multiplication by the reciprocal.
        let reciprocal = Expr::pow(Expr::integer(2), Expr::integer(-1));
        let expected = Expr::mul(vec![Expr::integer(6), reciprocal]);
        assert_eq!(parsed("6 / 2"), expected);
    }

    #[test]
    fn test_power() {
        let expected = Expr::pow(Expr::integer(2), Expr::integer(3));
        assert_eq!(parsed("2^3"), expected);
    }

    #[test]
    fn test_implicit_multiplication() {
        let expected = Expr::mul(vec![Expr::integer(2), Expr::symbol("x")]);
        assert_eq!(parsed("2x"), expected);
    }

    #[test]
    fn test_function() {
        let expected = Expr::func("sin", vec![Expr::symbol("x")]);
        assert_eq!(parsed("sin(x)"), expected);
    }

    #[test]
    fn test_nested() {
        let argument = Expr::add(vec![Expr::symbol("x"), Expr::integer(1)]);
        let expected = Expr::func("sin", vec![argument]);
        assert_eq!(parsed("sin(x + 1)"), expected);
    }

    #[test]
    fn test_equation() {
        assert!(matches!(parsed("x + 1 = 5"), Expr::Equation(_, _)));
    }

    #[test]
    fn test_derivative() {
        assert!(matches!(parsed("diff(x^2, x)"), Expr::Derivative { .. }));
    }

    #[test]
    fn test_integral() {
        // Without bounds the integral is indefinite.
        assert!(matches!(
            parsed("integrate(x^2, x)"),
            Expr::Integral { lower: None, .. }
        ));
    }

    #[test]
    fn test_definite_integral() {
        assert!(matches!(
            parsed("integrate(x^2, x, 0, 1)"),
            Expr::Integral { lower: Some(_), .. }
        ));
    }

    #[test]
    fn test_vector() {
        assert!(matches!(parsed("[1, 2, 3]"), Expr::Vector(_)));
    }

    #[test]
    fn test_matrix() {
        assert!(matches!(parsed("[1, 2; 3, 4]"), Expr::Matrix(_)));
    }

    #[test]
    fn test_factorial() {
        let expected = Expr::func("factorial", vec![Expr::integer(5)]);
        assert_eq!(parsed("5!"), expected);
    }
}
679