Rust · 32453 bytes Raw Blame History
1 //! Conversion between MathBox and CAS Expr types
2 //!
3 //! Enables bidirectional conversion:
4 //! - MathBox -> Expr for evaluation
5 //! - Expr -> MathBox for result display
6
7 use crate::mathbox::{LimitDirection as MathLimitDirection, MathBox, Operator};
8 use garcalc_cas::expr::{Expr, LimitDirection, Rational, Sign, Symbol};
9 use thiserror::Error;
10
11 /// Errors that can occur during conversion
12 #[derive(Debug, Error)]
13 pub enum ConvertError {
14 #[error("Empty slot in expression")]
15 EmptySlot,
16 #[error("Invalid number: {0}")]
17 InvalidNumber(String),
18 #[error("Unsupported expression type")]
19 Unsupported,
20 #[error("Missing required field: {0}")]
21 MissingField(String),
22 }
23
24 /// Convert a MathBox to a CAS Expr
25 pub fn to_expr(mathbox: &MathBox) -> Result<Expr, ConvertError> {
26 match mathbox {
27 MathBox::Number(s) => parse_number(s),
28 MathBox::Symbol(s) => Ok(Expr::Symbol(Symbol::new(s.clone()))),
29 MathBox::Operator(_) => Err(ConvertError::Unsupported),
30 MathBox::Slot => Err(ConvertError::EmptySlot),
31
32 MathBox::Fraction { num, den } => {
33 let num_expr = to_expr(num)?;
34 let den_expr = to_expr(den)?;
35 Ok(Expr::Mul(vec![
36 num_expr,
37 Expr::Pow(Box::new(den_expr), Box::new(Expr::Integer(-1))),
38 ]))
39 }
40
41 MathBox::Power { base, exp } => {
42 let base_expr = to_expr(base)?;
43 let exp_expr = to_expr(exp)?;
44 Ok(Expr::Pow(Box::new(base_expr), Box::new(exp_expr)))
45 }
46
47 MathBox::Subscript { base, sub } => {
48 // Subscripted variables become a combined symbol
49 if let (MathBox::Symbol(b), MathBox::Number(s) | MathBox::Symbol(s)) =
50 (base.as_ref(), sub.as_ref())
51 {
52 Ok(Expr::Symbol(Symbol::new(format!("{}_{}", b, s))))
53 } else {
54 let base_expr = to_expr(base)?;
55 Ok(base_expr)
56 }
57 }
58
59 MathBox::Root { index, radicand } => {
60 let radicand_expr = to_expr(radicand)?;
61 let exp = if let Some(idx) = index {
62 let idx_expr = to_expr(idx)?;
63 Expr::Mul(vec![
64 Expr::Integer(1),
65 Expr::Pow(Box::new(idx_expr), Box::new(Expr::Integer(-1))),
66 ])
67 } else {
68 // Square root: 1/2
69 Expr::Rational(Rational::new(1, 2))
70 };
71 Ok(Expr::Pow(Box::new(radicand_expr), Box::new(exp)))
72 }
73
74 MathBox::Func { name, args } => {
75 let arg_exprs: Result<Vec<Expr>, ConvertError> = args.iter().map(to_expr).collect();
76 Ok(Expr::Func(name.clone(), arg_exprs?))
77 }
78
79 MathBox::Abs(inner) => {
80 let inner_expr = to_expr(inner)?;
81 Ok(Expr::Func("abs".to_string(), vec![inner_expr]))
82 }
83
84 MathBox::Parens(inner) => to_expr(inner),
85
86 MathBox::Integral {
87 lower,
88 upper,
89 body,
90 var,
91 } => {
92 let body_expr = to_expr(body)?;
93 let var_sym = Symbol::new(extract_symbol_str(var)?);
94
95 let lower_expr = match lower {
96 Some(lo) => Some(Box::new(to_expr(lo)?)),
97 None => None,
98 };
99 let upper_expr = match upper {
100 Some(hi) => Some(Box::new(to_expr(hi)?)),
101 None => None,
102 };
103
104 Ok(Expr::Integral {
105 expr: Box::new(body_expr),
106 var: var_sym,
107 lower: lower_expr,
108 upper: upper_expr,
109 })
110 }
111
112 MathBox::Derivative { order, var, body } => {
113 let body_expr = to_expr(body)?;
114 let var_sym = Symbol::new(extract_symbol_str(var)?);
115 Ok(Expr::Derivative {
116 expr: Box::new(body_expr),
117 var: var_sym,
118 order: *order,
119 })
120 }
121
122 MathBox::Limit {
123 var,
124 to,
125 direction,
126 body,
127 } => {
128 let body_expr = to_expr(body)?;
129 let var_sym = Symbol::new(extract_symbol_str(var)?);
130 let point_expr = to_expr(to)?;
131
132 let dir = direction.map(|d| match d {
133 MathLimitDirection::FromLeft => LimitDirection::Left,
134 MathLimitDirection::FromRight => LimitDirection::Right,
135 });
136
137 Ok(Expr::Limit {
138 expr: Box::new(body_expr),
139 var: var_sym,
140 point: Box::new(point_expr),
141 direction: dir,
142 })
143 }
144
145 MathBox::Sum {
146 var,
147 lower,
148 upper,
149 body,
150 } => {
151 let var_sym = Symbol::new(extract_symbol_str(var)?);
152 let lower_expr = to_expr(lower)?;
153 let upper_expr = to_expr(upper)?;
154 let body_expr = to_expr(body)?;
155
156 Ok(Expr::Sum {
157 expr: Box::new(body_expr),
158 var: var_sym,
159 lower: Box::new(lower_expr),
160 upper: Box::new(upper_expr),
161 })
162 }
163
164 MathBox::Product {
165 var,
166 lower,
167 upper,
168 body,
169 } => {
170 let var_sym = Symbol::new(extract_symbol_str(var)?);
171 let lower_expr = to_expr(lower)?;
172 let upper_expr = to_expr(upper)?;
173 let body_expr = to_expr(body)?;
174
175 Ok(Expr::Product {
176 expr: Box::new(body_expr),
177 var: var_sym,
178 lower: Box::new(lower_expr),
179 upper: Box::new(upper_expr),
180 })
181 }
182
183 MathBox::Matrix { rows } => {
184 let expr_rows: Result<Vec<Vec<Expr>>, ConvertError> = rows
185 .iter()
186 .map(|row| row.iter().map(to_expr).collect())
187 .collect();
188 Ok(Expr::Matrix(expr_rows?))
189 }
190
191 MathBox::Row(items) => {
192 // Parse a row as an expression (needs operator precedence)
193 convert_row(items)
194 }
195 }
196 }
197
198 /// Convert a CAS Expr to a MathBox for display
199 pub fn from_expr(expr: &Expr) -> MathBox {
200 match expr {
201 Expr::Integer(n) => MathBox::Number(n.to_string()),
202
203 Expr::Rational(r) => {
204 if r.den == 1 {
205 MathBox::Number(r.num.to_string())
206 } else {
207 MathBox::Fraction {
208 num: Box::new(MathBox::Number(r.num.to_string())),
209 den: Box::new(MathBox::Number(r.den.to_string())),
210 }
211 }
212 }
213
214 Expr::Float(f) => MathBox::Number(format_float(*f)),
215
216 Expr::Complex(re, im) => {
217 if *re == 0.0 {
218 MathBox::Row(vec![
219 MathBox::Number(format_float(*im)),
220 MathBox::Symbol("i".to_string()),
221 ])
222 } else if *im >= 0.0 {
223 MathBox::Row(vec![
224 MathBox::Number(format_float(*re)),
225 MathBox::Operator(Operator::Add),
226 MathBox::Number(format_float(*im)),
227 MathBox::Symbol("i".to_string()),
228 ])
229 } else {
230 MathBox::Row(vec![
231 MathBox::Number(format_float(*re)),
232 MathBox::Operator(Operator::Sub),
233 MathBox::Number(format_float(im.abs())),
234 MathBox::Symbol("i".to_string()),
235 ])
236 }
237 }
238
239 Expr::Symbol(s) => {
240 // Convert common symbols to proper display
241 match s.as_str() {
242 "pi" => MathBox::Symbol("π".to_string()),
243 "e" => MathBox::Symbol("e".to_string()),
244 "inf" | "infinity" => MathBox::Symbol("∞".to_string()),
245 _ => MathBox::Symbol(s.0.clone()),
246 }
247 }
248
249 Expr::Add(terms) => {
250 let mut items = Vec::new();
251 for (i, term) in terms.iter().enumerate() {
252 if i > 0 {
253 // Check if term is negative
254 if is_negative(term) {
255 items.push(MathBox::Operator(Operator::Sub));
256 items.push(from_expr(&negate(term)));
257 } else {
258 items.push(MathBox::Operator(Operator::Add));
259 items.push(from_expr(term));
260 }
261 } else {
262 items.push(from_expr(term));
263 }
264 }
265 if items.len() == 1 {
266 items.pop().unwrap()
267 } else {
268 MathBox::Row(items)
269 }
270 }
271
272 Expr::Mul(factors) => {
273 // Check for division pattern (x * y^-1)
274 let (numerator, denominator): (Vec<_>, Vec<_>) =
275 factors.iter().partition(|f| !is_reciprocal(f));
276
277 let mut numerator_factors: Vec<Expr> = numerator.into_iter().cloned().collect();
278 numerator_factors.retain(|f| !is_multiplicative_identity(f));
279
280 if !denominator.is_empty() {
281 let den_factors: Vec<Expr> = denominator
282 .iter()
283 .map(|f| extract_base_of_reciprocal(f))
284 .filter(|f| !is_multiplicative_identity(f))
285 .collect();
286
287 let num_box = if numerator_factors.is_empty() {
288 MathBox::Number("1".to_string())
289 } else if numerator_factors.len() == 1 {
290 render_mul_factor(&numerator_factors[0])
291 } else {
292 from_mul_factors(&numerator_factors)
293 };
294
295 if den_factors.is_empty() {
296 return num_box;
297 }
298
299 let den_box = if den_factors.len() == 1 {
300 render_mul_factor(&den_factors[0])
301 } else {
302 from_mul_factors(&den_factors)
303 };
304
305 return MathBox::Fraction {
306 num: Box::new(num_box),
307 den: Box::new(den_box),
308 };
309 }
310
311 // Normal multiplication
312 if numerator_factors.is_empty() {
313 MathBox::Number("1".to_string())
314 } else if numerator_factors.len() == 1 {
315 render_mul_factor(&numerator_factors[0])
316 } else {
317 from_mul_factors(&numerator_factors)
318 }
319 }
320
321 Expr::Pow(base, exp) => {
322 // Check for square root / nth root
323 if let Expr::Rational(r) = exp.as_ref() {
324 if r.num == 1 && r.den == 2 {
325 return MathBox::Root {
326 index: None,
327 radicand: Box::new(from_expr(base)),
328 };
329 } else if r.num == 1 && r.den > 2 {
330 return MathBox::Root {
331 index: Some(Box::new(MathBox::Number(r.den.to_string()))),
332 radicand: Box::new(from_expr(base)),
333 };
334 }
335 }
336
337 MathBox::Power {
338 base: Box::new(from_expr(base)),
339 exp: Box::new(from_expr(exp)),
340 }
341 }
342
343 Expr::Neg(inner) => MathBox::Row(vec![MathBox::Operator(Operator::Sub), from_expr(inner)]),
344
345 Expr::Func(name, args) => {
346 if name == "factorial" && args.len() == 1 {
347 let arg_expr = &args[0];
348 let arg_box = from_expr(arg_expr);
349 let formatted_arg = if factorial_arg_needs_parens(arg_expr) {
350 MathBox::Parens(Box::new(arg_box))
351 } else {
352 arg_box
353 };
354
355 return MathBox::Row(vec![formatted_arg, MathBox::Symbol("!".to_string())]);
356 }
357
358 let arg_boxes: Vec<MathBox> = args.iter().map(from_expr).collect();
359
360 // Special functions get special rendering
361 match name.as_str() {
362 "abs" if args.len() == 1 => {
363 MathBox::Abs(Box::new(arg_boxes.into_iter().next().unwrap()))
364 }
365 _ => MathBox::Func {
366 name: name.clone(),
367 args: arg_boxes,
368 },
369 }
370 }
371
372 Expr::Derivative { expr, var, order } => MathBox::Derivative {
373 order: *order,
374 var: Box::new(MathBox::Symbol(var.0.clone())),
375 body: Box::new(from_expr(expr)),
376 },
377
378 Expr::Integral {
379 expr,
380 var,
381 lower,
382 upper,
383 } => MathBox::Integral {
384 lower: lower.as_ref().map(|lo| Box::new(from_expr(lo))),
385 upper: upper.as_ref().map(|hi| Box::new(from_expr(hi))),
386 body: Box::new(from_expr(expr)),
387 var: Box::new(MathBox::Symbol(var.0.clone())),
388 },
389
390 Expr::Limit {
391 expr,
392 var,
393 point,
394 direction,
395 } => MathBox::Limit {
396 var: Box::new(MathBox::Symbol(var.0.clone())),
397 to: Box::new(from_expr(point)),
398 direction: direction.map(|d| match d {
399 LimitDirection::Left => MathLimitDirection::FromLeft,
400 LimitDirection::Right => MathLimitDirection::FromRight,
401 }),
402 body: Box::new(from_expr(expr)),
403 },
404
405 Expr::Sum {
406 expr,
407 var,
408 lower,
409 upper,
410 } => MathBox::Sum {
411 var: Box::new(MathBox::Symbol(var.0.clone())),
412 lower: Box::new(from_expr(lower)),
413 upper: Box::new(from_expr(upper)),
414 body: Box::new(from_expr(expr)),
415 },
416
417 Expr::Product {
418 expr,
419 var,
420 lower,
421 upper,
422 } => MathBox::Product {
423 var: Box::new(MathBox::Symbol(var.0.clone())),
424 lower: Box::new(from_expr(lower)),
425 upper: Box::new(from_expr(upper)),
426 body: Box::new(from_expr(expr)),
427 },
428
429 Expr::Matrix(rows) => {
430 let box_rows: Vec<Vec<MathBox>> = rows
431 .iter()
432 .map(|row| row.iter().map(from_expr).collect())
433 .collect();
434 MathBox::Matrix { rows: box_rows }
435 }
436
437 Expr::Equation(lhs, rhs) => MathBox::Row(vec![
438 from_expr(lhs),
439 MathBox::Operator(Operator::Eq),
440 from_expr(rhs),
441 ]),
442
443 Expr::Undefined => MathBox::Symbol("undefined".to_string()),
444 Expr::Infinity(sign) => match sign {
445 Sign::Positive => MathBox::Symbol("∞".to_string()),
446 Sign::Negative => MathBox::Row(vec![
447 MathBox::Operator(Operator::Sub),
448 MathBox::Symbol("∞".to_string()),
449 ]),
450 },
451
452 Expr::Vector(elems) => {
453 // Display vector as a row matrix
454 let box_row: Vec<MathBox> = elems.iter().map(from_expr).collect();
455 MathBox::Matrix {
456 rows: vec![box_row],
457 }
458 }
459
460 Expr::Inequality { lhs, op, rhs } => {
461 use garcalc_cas::expr::InequalityOp;
462 let op_box = match op {
463 InequalityOp::Lt => MathBox::Operator(Operator::Lt),
464 InequalityOp::Le => MathBox::Operator(Operator::Le),
465 InequalityOp::Gt => MathBox::Operator(Operator::Gt),
466 InequalityOp::Ge => MathBox::Operator(Operator::Ge),
467 InequalityOp::Ne => MathBox::Operator(Operator::Ne),
468 };
469 MathBox::Row(vec![from_expr(lhs), op_box, from_expr(rhs)])
470 }
471 }
472 }
473
474 // Helper functions
475
476 fn parse_number(s: &str) -> Result<Expr, ConvertError> {
477 if s.contains('.') {
478 s.parse::<f64>()
479 .map(Expr::Float)
480 .map_err(|_| ConvertError::InvalidNumber(s.to_string()))
481 } else {
482 s.parse::<i64>()
483 .map(Expr::Integer)
484 .map_err(|_| ConvertError::InvalidNumber(s.to_string()))
485 }
486 }
487
488 fn extract_symbol_str(mathbox: &MathBox) -> Result<String, ConvertError> {
489 match mathbox {
490 MathBox::Symbol(s) => Ok(s.clone()),
491 MathBox::Slot => Err(ConvertError::EmptySlot),
492 _ => Err(ConvertError::MissingField("variable".to_string())),
493 }
494 }
495
496 fn convert_row(items: &[MathBox]) -> Result<Expr, ConvertError> {
497 if items.is_empty() {
498 return Err(ConvertError::EmptySlot);
499 }
500 if items.len() == 1 {
501 return to_expr(&items[0]);
502 }
503
504 // Parse a flat row with implicit multiplication and basic precedence.
505 let mut operands: Vec<Expr> = Vec::new();
506 let mut operators: Vec<Operator> = Vec::new();
507 let mut expecting_operand = true;
508
509 let mut i = 0;
510 while i < items.len() {
511 let mut push_operand = |expr: Expr| {
512 if !expecting_operand {
513 operators.push(Operator::Mul);
514 }
515 operands.push(expr);
516 expecting_operand = false;
517 };
518
519 match (&items[i], items.get(i + 1)) {
520 (MathBox::Operator(op), _) => {
521 match op {
522 Operator::Add if expecting_operand => {
523 // Unary plus: no-op.
524 }
525 Operator::Sub if expecting_operand => {
526 // Unary minus: rewrite as 0 - ...
527 operands.push(Expr::Integer(0));
528 operators.push(Operator::Sub);
529 }
530 _ => {
531 operators.push(*op);
532 }
533 }
534 expecting_operand = true;
535 }
536 (MathBox::Symbol(s), _) if s == "!" => {
537 let Some(last) = operands.pop() else {
538 return Err(ConvertError::MissingField(
539 "factorial operand before '!'".to_string(),
540 ));
541 };
542 operands.push(Expr::Func("factorial".to_string(), vec![last]));
543 expecting_operand = false;
544 }
545 (MathBox::Symbol(name), Some(MathBox::Parens(inner)))
546 if is_function_name(name.as_str()) =>
547 {
548 let args = parse_paren_args(inner)?;
549 push_operand(Expr::Func(name.clone(), args));
550 i += 1; // Consume the following parens node.
551 }
552 (other, _) => {
553 push_operand(to_expr(other)?);
554 }
555 }
556 i += 1;
557 }
558
559 if operands.is_empty() {
560 return Err(ConvertError::EmptySlot);
561 }
562
563 // If there are adjacent operands but no explicit operators, treat as multiplication.
564 if operators.is_empty() {
565 return if operands.len() == 1 {
566 Ok(operands.pop().unwrap())
567 } else {
568 Ok(Expr::Mul(operands))
569 };
570 }
571
572 // Maintain operand/operator alignment by appending implicit multiplications
573 // when needed due malformed rows.
574 while operators.len() + 1 < operands.len() {
575 operators.push(Operator::Mul);
576 }
577
578 if operators.len() + 1 != operands.len() {
579 return Err(ConvertError::Unsupported);
580 }
581
582 // First pass: *, /
583 let mut idx = 0;
584 while idx < operators.len() {
585 match operators[idx] {
586 Operator::Mul | Operator::Div => {
587 let lhs = operands[idx].clone();
588 let rhs = operands[idx + 1].clone();
589 let merged = match operators[idx] {
590 Operator::Mul => Expr::Mul(vec![lhs, rhs]),
591 Operator::Div => Expr::Mul(vec![
592 lhs,
593 Expr::Pow(Box::new(rhs), Box::new(Expr::Integer(-1))),
594 ]),
595 _ => unreachable!(),
596 };
597 operands[idx] = merged;
598 operands.remove(idx + 1);
599 operators.remove(idx);
600 }
601 _ => idx += 1,
602 }
603 }
604
605 // Second pass: +, -, =
606 let mut result = operands[0].clone();
607 for (i, op) in operators.iter().enumerate() {
608 let rhs = operands[i + 1].clone();
609 result = match op {
610 Operator::Add => Expr::Add(vec![result, rhs]),
611 Operator::Sub => Expr::Add(vec![result, Expr::Neg(Box::new(rhs))]),
612 Operator::Eq => Expr::Equation(Box::new(result), Box::new(rhs)),
613 Operator::Lt => Expr::Inequality {
614 lhs: Box::new(result),
615 op: garcalc_cas::expr::InequalityOp::Lt,
616 rhs: Box::new(rhs),
617 },
618 Operator::Gt => Expr::Inequality {
619 lhs: Box::new(result),
620 op: garcalc_cas::expr::InequalityOp::Gt,
621 rhs: Box::new(rhs),
622 },
623 Operator::Le => Expr::Inequality {
624 lhs: Box::new(result),
625 op: garcalc_cas::expr::InequalityOp::Le,
626 rhs: Box::new(rhs),
627 },
628 Operator::Ge => Expr::Inequality {
629 lhs: Box::new(result),
630 op: garcalc_cas::expr::InequalityOp::Ge,
631 rhs: Box::new(rhs),
632 },
633 Operator::Ne => Expr::Inequality {
634 lhs: Box::new(result),
635 op: garcalc_cas::expr::InequalityOp::Ne,
636 rhs: Box::new(rhs),
637 },
638 Operator::Mul | Operator::Div | Operator::Comma => {
639 return Err(ConvertError::Unsupported);
640 }
641 };
642 }
643
644 Ok(result)
645 }
646
647 fn parse_paren_args(inner: &MathBox) -> Result<Vec<Expr>, ConvertError> {
648 match inner {
649 MathBox::Row(items) => {
650 let mut args = Vec::new();
651 let mut current = Vec::new();
652
653 for item in items {
654 if matches!(item, MathBox::Operator(Operator::Comma)) {
655 if current.is_empty() {
656 return Err(ConvertError::MissingField(
657 "function argument before comma".to_string(),
658 ));
659 }
660 args.push(to_expr(&MathBox::Row(std::mem::take(&mut current)))?);
661 } else {
662 current.push(item.clone());
663 }
664 }
665
666 if current.is_empty() && args.is_empty() {
667 return Err(ConvertError::EmptySlot);
668 }
669 if !current.is_empty() {
670 args.push(to_expr(&MathBox::Row(current))?);
671 }
672
673 Ok(args)
674 }
675 _ => Ok(vec![to_expr(inner)?]),
676 }
677 }
678
/// Report whether `name` is a function the CAS understands.
///
/// Inside a row, `name(...)` is parsed as a function call when this returns true,
/// and as the implicit product `name * (...)` otherwise.
fn is_function_name(name: &str) -> bool {
    const KNOWN_FUNCTIONS: &[&str] = &[
        // Trigonometric and inverse trigonometric
        "sin", "cos", "tan", "cot", "sec", "csc", "asin", "acos", "atan",
        // Hyperbolic
        "sinh", "cosh", "tanh", "asinh", "acosh", "atanh",
        // Logarithms, exponentials and roots
        "ln", "log", "log10", "log2", "exp", "sqrt", "cbrt",
        // Rounding, sign and special functions
        "abs", "floor", "ceil", "round", "trunc", "sign", "gamma", "factorial",
        // Calculus and algebra commands
        "diff", "derivative", "integrate", "integral", "limit", "lim", "solve", "sum",
        "product", "prod", "simplify", "expand", "factor", "substitute", "subs",
        // Extremes and number theory
        "min", "max", "gcd", "lcm",
        // Linear algebra
        "det", "determinant", "inv", "inverse", "transpose", "trace", "matmul", "identity",
    ];
    KNOWN_FUNCTIONS.contains(&name)
}
741
/// Format a float for display.
///
/// Integral values below 1e15 print without a fractional part (`42`). Everything
/// else is rounded to 10 decimal places, and trailing zeros and a trailing `.` are
/// stripped (`0.1 + 0.2` prints as `0.3`).
///
/// Any value that rounds to zero prints as `0`. This covers `-0.0` and tiny
/// negative floating-point noise such as `sin(2π) ≈ -2.4e-16`, which previously
/// printed as `-0`.
fn format_float(f: f64) -> String {
    let text = if f == f.trunc() && f.abs() < 1e15 {
        format!("{:.0}", f)
    } else {
        format!("{:.10}", f)
            .trim_end_matches('0')
            .trim_end_matches('.')
            .to_string()
    };
    if text == "-0" {
        "0".to_string()
    } else {
        text
    }
}
752
753 fn is_negative(expr: &Expr) -> bool {
754 match expr {
755 Expr::Integer(n) => *n < 0,
756 Expr::Float(f) => *f < 0.0,
757 Expr::Neg(_) => true,
758 Expr::Mul(factors) => factors.first().map(is_negative).unwrap_or(false),
759 _ => false,
760 }
761 }
762
763 fn negate(expr: &Expr) -> Expr {
764 match expr {
765 Expr::Integer(n) => Expr::Integer(-n),
766 Expr::Float(f) => Expr::Float(-f),
767 Expr::Neg(inner) => inner.as_ref().clone(),
768 _ => Expr::Neg(Box::new(expr.clone())),
769 }
770 }
771
772 fn is_reciprocal(expr: &Expr) -> bool {
773 matches!(expr, Expr::Pow(_, exp) if matches!(exp.as_ref(), Expr::Integer(-1) | Expr::Neg(_)))
774 }
775
776 fn extract_base_of_reciprocal(expr: &Expr) -> Expr {
777 if let Expr::Pow(base, _) = expr {
778 base.as_ref().clone()
779 } else {
780 expr.clone()
781 }
782 }
783
784 fn is_multiplicative_identity(expr: &Expr) -> bool {
785 match expr {
786 Expr::Integer(1) => true,
787 Expr::Rational(r) if r.den != 0 && r.num == r.den => true,
788 Expr::Float(f) if *f == 1.0 => true,
789 _ => false,
790 }
791 }
792
793 fn mul_factor_needs_parens(expr: &Expr) -> bool {
794 matches!(
795 expr,
796 Expr::Add(_) | Expr::Equation(_, _) | Expr::Inequality { .. }
797 )
798 }
799
800 fn render_mul_factor(expr: &Expr) -> MathBox {
801 let rendered = from_expr(expr);
802 if mul_factor_needs_parens(expr) {
803 MathBox::Parens(Box::new(rendered))
804 } else {
805 rendered
806 }
807 }
808
809 fn from_mul_factors(factors: &[Expr]) -> MathBox {
810 let mut items = Vec::new();
811 for (i, factor) in factors.iter().enumerate() {
812 if i > 0 {
813 items.push(MathBox::Operator(Operator::Mul));
814 }
815 items.push(render_mul_factor(factor));
816 }
817 if items.len() == 1 {
818 items.pop().unwrap()
819 } else {
820 MathBox::Row(items)
821 }
822 }
823
824 fn factorial_arg_needs_parens(arg: &Expr) -> bool {
825 !matches!(
826 arg,
827 Expr::Integer(_) | Expr::Rational(_) | Expr::Float(_) | Expr::Symbol(_) | Expr::Func(_, _)
828 )
829 }
830
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a number box.
    fn num(text: &str) -> MathBox {
        MathBox::Number(text.to_string())
    }

    /// Shorthand for a symbol box.
    fn sym(name: &str) -> MathBox {
        MathBox::Symbol(name.to_string())
    }

    /// Shorthand for a CAS symbol expression.
    fn var(name: &str) -> Expr {
        Expr::Symbol(Symbol::new(name))
    }

    /// `expr^-1`, the CAS form of a reciprocal.
    fn recip(expr: Expr) -> Expr {
        Expr::Pow(Box::new(expr), Box::new(Expr::Integer(-1)))
    }

    /// Unwrap a row box, failing the test with `message` for anything else.
    fn expect_row(mb: MathBox, message: &str) -> Vec<MathBox> {
        match mb {
            MathBox::Row(items) => items,
            _ => panic!("{}", message),
        }
    }

    /// Unwrap a fraction box into `(numerator, denominator)`.
    fn expect_fraction(mb: MathBox) -> (MathBox, MathBox) {
        match mb {
            MathBox::Fraction {
                num: top,
                den: bottom,
            } => (*top, *bottom),
            _ => panic!("expected fraction"),
        }
    }

    #[test]
    fn test_number_conversion() {
        let expr = to_expr(&num("42")).unwrap();
        assert!(matches!(expr, Expr::Integer(42)));
        assert!(matches!(from_expr(&expr), MathBox::Number(s) if s == "42"));
    }

    #[test]
    fn test_fraction_conversion() {
        let half = MathBox::Fraction {
            num: Box::new(num("1")),
            den: Box::new(num("2")),
        };
        // A fraction is represented as 1 * 2^-1.
        assert!(matches!(to_expr(&half).unwrap(), Expr::Mul(_)));
    }

    #[test]
    fn test_rational_to_fraction() {
        let three_quarters = Expr::Rational(Rational::new(3, 4));
        assert!(matches!(
            from_expr(&three_quarters),
            MathBox::Fraction { .. }
        ));
    }

    #[test]
    fn test_factorial_renders_as_postfix_bang() {
        let factorial = Expr::Func("factorial".to_string(), vec![var("n")]);
        let items = expect_row(
            from_expr(&factorial),
            "expected row for factorial rendering",
        );
        assert_eq!(items.len(), 2);
        assert!(matches!(&items[0], MathBox::Symbol(s) if s == "n"));
        assert!(matches!(&items[1], MathBox::Symbol(s) if s == "!"));
    }

    #[test]
    fn test_factorial_wraps_complex_arg_in_parens() {
        let factorial = Expr::Func(
            "factorial".to_string(),
            vec![Expr::Add(vec![var("n"), Expr::Integer(1)])],
        );
        let items = expect_row(
            from_expr(&factorial),
            "expected row for factorial rendering",
        );
        assert!(matches!(&items[0], MathBox::Parens(_)));
        assert!(matches!(&items[1], MathBox::Symbol(s) if s == "!"));
    }

    #[test]
    fn test_row_postfix_factorial_converts_to_expr() {
        let row = MathBox::Row(vec![sym("n"), sym("!")]);
        assert_eq!(
            to_expr(&row).unwrap(),
            Expr::Func("factorial".to_string(), vec![var("n")])
        );
    }

    #[test]
    fn test_row_implicit_multiplication_converts_to_mul() {
        let row = MathBox::Row(vec![num("2"), sym("x")]);
        assert_eq!(
            to_expr(&row).unwrap(),
            Expr::Mul(vec![Expr::Integer(2), var("x")])
        );
    }

    #[test]
    fn test_row_symbol_parens_known_function_converts_to_func() {
        let row = MathBox::Row(vec![sym("sin"), MathBox::Parens(Box::new(sym("x")))]);
        assert_eq!(
            to_expr(&row).unwrap(),
            Expr::Func("sin".to_string(), vec![var("x")])
        );
    }

    #[test]
    fn test_row_symbol_parens_unknown_name_stays_multiplication() {
        let row = MathBox::Row(vec![sym("f"), MathBox::Parens(Box::new(sym("x")))]);
        assert_eq!(
            to_expr(&row).unwrap(),
            Expr::Mul(vec![var("f"), var("x")])
        );
    }

    #[test]
    fn test_from_expr_fraction_drops_unity_factor_in_numerator() {
        let expr = Expr::Mul(vec![Expr::Integer(1), var("n"), recip(Expr::Integer(2))]);
        let (numerator, denominator) = expect_fraction(from_expr(&expr));
        assert!(!matches!(&numerator, MathBox::Number(s) if s == "1"));
        assert!(matches!(&denominator, MathBox::Number(s) if s == "2"));
    }

    #[test]
    fn test_from_expr_mul_wraps_additive_factor_with_parens() {
        let expr = Expr::Mul(vec![
            var("n"),
            Expr::Add(vec![Expr::Integer(1), var("n")]),
        ]);
        let items = expect_row(from_expr(&expr), "expected multiplication row");
        assert_eq!(items.len(), 3);
        assert!(matches!(&items[0], MathBox::Symbol(s) if s == "n"));
        assert!(matches!(&items[1], MathBox::Operator(Operator::Mul)));
        assert!(matches!(&items[2], MathBox::Parens(_)));
    }

    #[test]
    fn test_from_expr_fraction_numerator_hides_one_and_wraps_additive_factor() {
        let expr = Expr::Mul(vec![
            Expr::Integer(1),
            var("n"),
            Expr::Add(vec![Expr::Integer(1), var("n")]),
            recip(Expr::Integer(2)),
        ]);
        let (numerator, denominator) = expect_fraction(from_expr(&expr));
        assert!(matches!(&denominator, MathBox::Number(s) if s == "2"));

        let items = expect_row(numerator, "expected row in fraction numerator");
        assert_eq!(items.len(), 3);
        assert!(matches!(&items[0], MathBox::Symbol(s) if s == "n"));
        assert!(matches!(&items[1], MathBox::Operator(Operator::Mul)));
        assert!(matches!(&items[2], MathBox::Parens(_)));
    }
}
1023