Rust · 68045 bytes Raw Blame History
1 //! Expression evaluation
2 //!
3 //! Supports both numeric and symbolic evaluation modes.
4
5 use std::collections::BTreeSet;
6 use std::collections::HashMap;
7 use std::f64::consts::{E, PI};
8
9 use crate::error::{CasError, Result};
10 use crate::expr::{Expr, Rational, Sign, Symbol};
11 use crate::symbolic::{Differentiator, Factorer, Integrator, Limits, Simplifier, Solver};
12
/// Variable bindings for evaluation: maps a variable name to the expression
/// bound to it. An [`Evaluator`] consults this map when it meets a symbol
/// that is not one of the built-in constants.
pub type Environment = HashMap<String, Expr>;
15
16 /// Expression evaluator
17 pub struct Evaluator {
18 env: Environment,
19 /// If true, try to keep results symbolic when possible
20 exact_mode: bool,
21 /// Angle mode for trig functions
22 angle_mode: AngleMode,
23 }
24
/// Unit used for angles by the trigonometric functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AngleMode {
    /// Angles are measured in radians.
    Radians,
    /// Angles are measured in degrees: trig inputs are converted to radians
    /// first, and inverse-trig results are converted back to degrees.
    Degrees,
}

impl Default for AngleMode {
    /// Radians, the convention of the `f64` trigonometric methods.
    fn default() -> Self {
        AngleMode::Radians
    }
}
36
37 impl Default for Evaluator {
38 fn default() -> Self {
39 Self::new()
40 }
41 }
42
43 impl Evaluator {
44 pub fn new() -> Self {
45 Self {
46 env: Environment::new(),
47 exact_mode: false,
48 angle_mode: AngleMode::Radians,
49 }
50 }
51
52 pub fn with_exact_mode(mut self, exact: bool) -> Self {
53 self.exact_mode = exact;
54 self
55 }
56
57 pub fn with_angle_mode(mut self, mode: AngleMode) -> Self {
58 self.angle_mode = mode;
59 self
60 }
61
62 pub fn set_var(&mut self, name: impl Into<String>, value: Expr) {
63 self.env.insert(name.into(), value);
64 }
65
/// Returns the expression currently bound to `name`, or `None` if the name
/// is unbound. The stored expression is returned as-is, without evaluation.
pub fn get_var(&self, name: &str) -> Option<&Expr> {
    self.env.get(name)
}
69
/// Removes every variable binding. Exact mode and angle mode are left
/// unchanged.
pub fn clear_vars(&mut self) {
    self.env.clear();
}
73
74 /// Evaluate an expression to a numeric result
75 pub fn eval(&self, expr: &Expr) -> Result<Expr> {
76 match expr {
77 Expr::Integer(n) => Ok(Expr::Integer(*n)),
78 Expr::Rational(r) => Ok(Expr::Rational(*r)),
79 Expr::Float(x) => Ok(Expr::Float(*x)),
80 Expr::Complex(re, im) => Ok(Expr::Complex(*re, *im)),
81
82 Expr::Symbol(sym) => {
83 // Check for constants
84 match sym.as_str() {
85 "pi" => Ok(Expr::Float(PI)),
86 "e" => Ok(Expr::Float(E)),
87 _ => {
88 // Check environment
89 if let Some(value) = self.env.get(sym.as_str()) {
90 self.eval(value)
91 } else if self.exact_mode {
92 // In exact mode, keep undefined symbols
93 Ok(expr.clone())
94 } else {
95 Err(CasError::UndefinedVariable(sym.to_string()))
96 }
97 }
98 }
99 }
100
101 Expr::Neg(e) => {
102 let val = self.eval(e)?;
103 self.negate(&val)
104 }
105
106 Expr::Add(terms) => {
107 let mut sum = Expr::integer(0);
108 for term in terms {
109 let val = self.eval(term)?;
110 sum = self.add(&sum, &val)?;
111 }
112 Ok(sum)
113 }
114
115 Expr::Mul(factors) => {
116 let mut product = Expr::integer(1);
117 for factor in factors {
118 let val = self.eval(factor)?;
119 product = self.multiply(&product, &val)?;
120 }
121 Ok(product)
122 }
123
124 Expr::Pow(base, exp) => {
125 let base_val = self.eval(base)?;
126 let exp_val = self.eval(exp)?;
127 self.power(&base_val, &exp_val)
128 }
129
130 Expr::Func(name, args) => {
131 // Symbolic functions operate on unevaluated expressions
132 match name.as_str() {
133 "diff" | "derivative" | "integrate" | "integral" | "limit" | "lim"
134 | "solve" | "simplify" | "expand" | "factor" | "substitute" | "subs" => {
135 self.call_function(name, args)
136 }
137 _ => {
138 let evaluated_args: Result<Vec<_>> =
139 args.iter().map(|a| self.eval(a)).collect();
140 self.call_function(name, &evaluated_args?)
141 }
142 }
143 }
144
145 Expr::Equation(lhs, rhs) => {
146 let lhs_val = self.eval(lhs)?;
147 let rhs_val = self.eval(rhs)?;
148 Ok(Expr::Equation(Box::new(lhs_val), Box::new(rhs_val)))
149 }
150
151 Expr::Vector(elems) => {
152 let evaluated: Result<Vec<_>> = elems.iter().map(|e| self.eval(e)).collect();
153 Ok(Expr::Vector(evaluated?))
154 }
155
156 Expr::Matrix(rows) => {
157 let evaluated: Result<Vec<Vec<_>>> = rows
158 .iter()
159 .map(|row| row.iter().map(|e| self.eval(e)).collect())
160 .collect();
161 Ok(Expr::Matrix(evaluated?))
162 }
163
164 // Symbolic operations - perform symbolic computation then try to evaluate
165 Expr::Derivative {
166 expr: inner,
167 var,
168 order,
169 } => {
170 let result = Differentiator::diff_n(inner, var, *order)?;
171 let simplified = Simplifier::simplify(&result);
172 // Try to evaluate the result
173 if simplified.contains_var(var) {
174 // Still has variable - return symbolic result
175 Ok(simplified)
176 } else {
177 self.eval(&simplified)
178 }
179 }
180
181 Expr::Integral {
182 expr: inner,
183 var,
184 lower,
185 upper,
186 } => {
187 if let (Some(l), Some(u)) = (lower, upper) {
188 // Definite integral - try symbolic antiderivative first.
189 let result = Integrator::integrate_definite(inner, var, l, u)?;
190 let simplified = Simplifier::simplify(&result);
191 if Self::is_unevaluated_definite_integral(&simplified) {
192 // If no closed form is available, fall back to numerical quadrature.
193 match self.eval_definite_integral_numeric(inner, var, l, u) {
194 Ok(value) => Ok(value),
195 Err(_) => Ok(simplified),
196 }
197 } else {
198 self.eval(&simplified)
199 }
200 } else {
201 // Indefinite integral - return symbolic result
202 let result = Integrator::integrate(inner, var)?;
203 Ok(Simplifier::simplify(&result))
204 }
205 }
206
207 Expr::Limit {
208 expr: inner,
209 var,
210 point,
211 direction,
212 } => {
213 let result = Limits::limit(inner, var, point, *direction)?;
214 let simplified = Simplifier::simplify(&result);
215 self.eval(&simplified)
216 }
217
218 Expr::Sum {
219 expr: inner,
220 var,
221 lower,
222 upper,
223 } => self.eval_sum(inner, var, lower, upper),
224
225 Expr::Product {
226 expr: inner,
227 var,
228 lower,
229 upper,
230 } => self.eval_product(inner, var, lower, upper),
231
232 Expr::Inequality { lhs, op, rhs } => {
233 let lhs_val = self.eval(lhs)?;
234 let rhs_val = self.eval(rhs)?;
235 Ok(Expr::Inequality {
236 lhs: Box::new(lhs_val),
237 op: *op,
238 rhs: Box::new(rhs_val),
239 })
240 }
241
242 Expr::Undefined => Ok(Expr::Undefined),
243 Expr::Infinity(sign) => Ok(Expr::Infinity(*sign)),
244 }
245 }
246
247 fn to_f64(&self, expr: &Expr) -> Result<f64> {
248 match expr {
249 Expr::Integer(n) => Ok(*n as f64),
250 Expr::Rational(r) => Ok(r.to_f64()),
251 Expr::Float(x) => Ok(*x),
252 _ => Err(CasError::Type(format!("expected number, got {expr}"))),
253 }
254 }
255
256 fn negate(&self, expr: &Expr) -> Result<Expr> {
257 match expr {
258 Expr::Integer(n) => Ok(Expr::Integer(-n)),
259 Expr::Rational(r) => Ok(Expr::Rational(Rational::new(-r.num, r.den))),
260 Expr::Float(x) => Ok(Expr::Float(-x)),
261 Expr::Complex(re, im) => Ok(Expr::Complex(-re, -im)),
262 Expr::Infinity(Sign::Positive) => Ok(Expr::Infinity(Sign::Negative)),
263 Expr::Infinity(Sign::Negative) => Ok(Expr::Infinity(Sign::Positive)),
264 _ => Ok(Expr::neg(expr.clone())),
265 }
266 }
267
268 fn add(&self, a: &Expr, b: &Expr) -> Result<Expr> {
269 match (a, b) {
270 (Expr::Integer(x), Expr::Integer(y)) => Ok(Expr::Integer(x + y)),
271 (Expr::Float(x), Expr::Float(y)) => Ok(Expr::Float(x + y)),
272 (Expr::Integer(x), Expr::Float(y)) | (Expr::Float(y), Expr::Integer(x)) => {
273 Ok(Expr::Float(*x as f64 + y))
274 }
275 (Expr::Complex(r1, i1), Expr::Complex(r2, i2)) => Ok(Expr::Complex(r1 + r2, i1 + i2)),
276 (Expr::Complex(re, im), Expr::Float(x)) | (Expr::Float(x), Expr::Complex(re, im)) => {
277 Ok(Expr::Complex(re + x, *im))
278 }
279 (Expr::Complex(re, im), Expr::Integer(n))
280 | (Expr::Integer(n), Expr::Complex(re, im)) => Ok(Expr::Complex(re + *n as f64, *im)),
281 (Expr::Rational(r1), Expr::Rational(r2)) => {
282 let num = r1.num * r2.den + r2.num * r1.den;
283 let den = r1.den * r2.den;
284 Ok(Expr::Rational(Rational::new(num, den)))
285 }
286 (Expr::Rational(r), Expr::Integer(n)) | (Expr::Integer(n), Expr::Rational(r)) => {
287 let num = r.num + n * r.den;
288 Ok(Expr::Rational(Rational::new(num, r.den)))
289 }
290 _ => {
291 // Try converting to floats
292 if let (Ok(x), Ok(y)) = (self.to_f64(a), self.to_f64(b)) {
293 Ok(Expr::Float(x + y))
294 } else if self.exact_mode {
295 Ok(Expr::add(vec![a.clone(), b.clone()]))
296 } else {
297 Err(CasError::Type(format!("cannot add {a} and {b}")))
298 }
299 }
300 }
301 }
302
303 fn multiply(&self, a: &Expr, b: &Expr) -> Result<Expr> {
304 if a.is_one() {
305 return Ok(b.clone());
306 }
307 if b.is_one() {
308 return Ok(a.clone());
309 }
310 if a.is_negative_one() {
311 return self.negate(b);
312 }
313 if b.is_negative_one() {
314 return self.negate(a);
315 }
316
317 match (a, b) {
318 (Expr::Integer(x), Expr::Integer(y)) => Ok(Expr::Integer(x * y)),
319 (Expr::Float(x), Expr::Float(y)) => Ok(Expr::Float(x * y)),
320 (Expr::Integer(x), Expr::Float(y)) | (Expr::Float(y), Expr::Integer(x)) => {
321 Ok(Expr::Float(*x as f64 * y))
322 }
323 (Expr::Complex(r1, i1), Expr::Complex(r2, i2)) => {
324 // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
325 Ok(Expr::Complex(r1 * r2 - i1 * i2, r1 * i2 + i1 * r2))
326 }
327 (Expr::Complex(re, im), Expr::Float(x)) | (Expr::Float(x), Expr::Complex(re, im)) => {
328 Ok(Expr::Complex(re * x, im * x))
329 }
330 (Expr::Complex(re, im), Expr::Integer(n))
331 | (Expr::Integer(n), Expr::Complex(re, im)) => {
332 let n = *n as f64;
333 Ok(Expr::Complex(re * n, im * n))
334 }
335 (Expr::Rational(r1), Expr::Rational(r2)) => Ok(Expr::Rational(Rational::new(
336 r1.num * r2.num,
337 r1.den * r2.den,
338 ))),
339 (Expr::Rational(r), Expr::Integer(n)) | (Expr::Integer(n), Expr::Rational(r)) => {
340 Ok(Expr::Rational(Rational::new(r.num * n, r.den)))
341 }
342 _ => {
343 if let (Ok(x), Ok(y)) = (self.to_f64(a), self.to_f64(b)) {
344 Ok(Expr::Float(x * y))
345 } else if self.exact_mode {
346 Ok(Expr::mul(vec![a.clone(), b.clone()]))
347 } else {
348 Err(CasError::Type(format!("cannot multiply {a} and {b}")))
349 }
350 }
351 }
352 }
353
354 fn power(&self, base: &Expr, exp: &Expr) -> Result<Expr> {
355 // Special cases
356 if exp.is_zero() {
357 return Ok(Expr::integer(1));
358 }
359 if exp.is_one() {
360 return Ok(base.clone());
361 }
362 if base.is_zero() {
363 return Ok(Expr::integer(0));
364 }
365 if base.is_one() {
366 return Ok(Expr::integer(1));
367 }
368
369 match (base, exp) {
370 (Expr::Integer(b), Expr::Integer(e)) => {
371 if *e >= 0 {
372 Ok(Expr::Integer(b.pow(*e as u32)))
373 } else {
374 // Negative exponent -> rational or float
375 let denom = b.pow((-e) as u32);
376 if self.exact_mode {
377 Ok(Expr::Rational(Rational::new(1, denom)))
378 } else {
379 Ok(Expr::Float(1.0 / denom as f64))
380 }
381 }
382 }
383 (Expr::Float(b), Expr::Integer(e)) => Ok(Expr::Float(b.powi(*e as i32))),
384 (Expr::Float(b), Expr::Float(e)) => Ok(Expr::Float(b.powf(*e))),
385 (Expr::Integer(b), Expr::Float(e)) => Ok(Expr::Float((*b as f64).powf(*e))),
386 _ => {
387 if let (Ok(b), Ok(e)) = (self.to_f64(base), self.to_f64(exp)) {
388 Ok(Expr::Float(b.powf(e)))
389 } else if self.exact_mode {
390 Ok(Expr::pow(base.clone(), exp.clone()))
391 } else {
392 Err(CasError::Type(format!("cannot compute {base}^{exp}")))
393 }
394 }
395 }
396 }
397
398 fn call_function(&self, name: &str, args: &[Expr]) -> Result<Expr> {
399 // Get numeric argument if single-arg function
400 let arg = if args.len() == 1 {
401 self.to_f64(&args[0]).ok()
402 } else {
403 None
404 };
405
406 // Angle conversion for trig functions
407 let angle = |x: f64| match self.angle_mode {
408 AngleMode::Radians => x,
409 AngleMode::Degrees => x.to_radians(),
410 };
411
412 let from_angle = |x: f64| match self.angle_mode {
413 AngleMode::Radians => x,
414 AngleMode::Degrees => x.to_degrees(),
415 };
416
417 match (name, args.len(), arg) {
418 // Trigonometric
419 ("sin", 1, Some(x)) => Ok(Expr::Float(angle(x).sin())),
420 ("cos", 1, Some(x)) => Ok(Expr::Float(angle(x).cos())),
421 ("tan", 1, Some(x)) => Ok(Expr::Float(angle(x).tan())),
422 ("asin", 1, Some(x)) => Ok(Expr::Float(from_angle(x.asin()))),
423 ("acos", 1, Some(x)) => Ok(Expr::Float(from_angle(x.acos()))),
424 ("atan", 1, Some(x)) => Ok(Expr::Float(from_angle(x.atan()))),
425 ("sinh", 1, Some(x)) => Ok(Expr::Float(x.sinh())),
426 ("cosh", 1, Some(x)) => Ok(Expr::Float(x.cosh())),
427 ("tanh", 1, Some(x)) => Ok(Expr::Float(x.tanh())),
428 ("asinh", 1, Some(x)) => Ok(Expr::Float(x.asinh())),
429 ("acosh", 1, Some(x)) => Ok(Expr::Float(x.acosh())),
430 ("atanh", 1, Some(x)) => Ok(Expr::Float(x.atanh())),
431
432 // Exponential/logarithmic
433 ("exp", 1, Some(x)) => Ok(Expr::Float(x.exp())),
434 ("ln", 1, Some(x)) => Ok(Expr::Float(x.ln())),
435 ("log", 1, Some(x)) => Ok(Expr::Float(x.log10())),
436 ("log10", 1, Some(x)) => Ok(Expr::Float(x.log10())),
437 ("log2", 1, Some(x)) => Ok(Expr::Float(x.log2())),
438
439 // Roots
440 ("sqrt", 1, Some(x)) => {
441 if x >= 0.0 {
442 Ok(Expr::Float(x.sqrt()))
443 } else {
444 Ok(Expr::Complex(0.0, (-x).sqrt()))
445 }
446 }
447 ("cbrt", 1, Some(x)) => Ok(Expr::Float(x.cbrt())),
448
449 // Other
450 ("abs", 1, Some(x)) => Ok(Expr::Float(x.abs())),
451 ("floor", 1, Some(x)) => Ok(Expr::Integer(x.floor() as i64)),
452 ("ceil", 1, Some(x)) => Ok(Expr::Integer(x.ceil() as i64)),
453 ("round", 1, Some(x)) => Ok(Expr::Integer(x.round() as i64)),
454 ("sign", 1, Some(x)) => Ok(Expr::Integer(if x > 0.0 {
455 1
456 } else if x < 0.0 {
457 -1
458 } else {
459 0
460 })),
461
462 // Factorial
463 ("factorial", 1, _) => {
464 if let Expr::Integer(n) = &args[0] {
465 if *n < 0 {
466 Err(CasError::Domain("factorial of negative number".to_string()))
467 } else if *n > 20 {
468 // Use Stirling's approximation for large n
469 Ok(Expr::Float(gamma(*n as f64 + 1.0)))
470 } else {
471 Ok(Expr::Integer(factorial(*n as u64) as i64))
472 }
473 } else if let Some(x) = arg {
474 Ok(Expr::Float(gamma(x + 1.0)))
475 } else {
476 Err(CasError::Type("factorial requires a number".to_string()))
477 }
478 }
479
480 // Two-argument functions
481 ("atan2", 2, _) => {
482 let y = self.to_f64(&args[0])?;
483 let x = self.to_f64(&args[1])?;
484 Ok(Expr::Float(from_angle(y.atan2(x))))
485 }
486 ("pow", 2, _) => self.power(&args[0], &args[1]),
487 ("mod", 2, _) => {
488 let a = self.to_f64(&args[0])?;
489 let b = self.to_f64(&args[1])?;
490 Ok(Expr::Float(a % b))
491 }
492 ("min", _, _) if !args.is_empty() => {
493 let values: Result<Vec<f64>> = args.iter().map(|a| self.to_f64(a)).collect();
494 let min = values?.into_iter().fold(f64::INFINITY, |a, b| a.min(b));
495 Ok(Expr::Float(min))
496 }
497 ("max", _, _) if !args.is_empty() => {
498 let values: Result<Vec<f64>> = args.iter().map(|a| self.to_f64(a)).collect();
499 let max = values?.into_iter().fold(f64::NEG_INFINITY, |a, b| a.max(b));
500 Ok(Expr::Float(max))
501 }
502 ("gcd", 2, _) => {
503 if let (Expr::Integer(a), Expr::Integer(b)) = (&args[0], &args[1]) {
504 Ok(Expr::Integer(gcd(*a, *b)))
505 } else {
506 Err(CasError::Type("gcd requires integers".to_string()))
507 }
508 }
509 ("lcm", 2, _) => {
510 if let (Expr::Integer(a), Expr::Integer(b)) = (&args[0], &args[1]) {
511 Ok(Expr::Integer(lcm(*a, *b)))
512 } else {
513 Err(CasError::Type("lcm requires integers".to_string()))
514 }
515 }
516
517 // Symbolic operations
518 ("diff", 2, _) | ("derivative", 2, _) => {
519 // diff(expr, var)
520 if let Expr::Symbol(var) = &args[1] {
521 let result = Differentiator::diff(&args[0], var)?;
522 Ok(Simplifier::simplify(&result))
523 } else {
524 Err(CasError::Type(
525 "diff requires variable as second argument".to_string(),
526 ))
527 }
528 }
529 ("diff", 3, _) | ("derivative", 3, _) => {
530 // diff(expr, var, order)
531 if let (Expr::Symbol(var), Expr::Integer(n)) = (&args[1], &args[2]) {
532 let result = Differentiator::diff_n(&args[0], var, *n as u32)?;
533 Ok(Simplifier::simplify(&result))
534 } else {
535 Err(CasError::Type(
536 "diff requires variable and integer order".to_string(),
537 ))
538 }
539 }
540
541 ("integrate", 2, _) | ("integral", 2, _) => {
542 // integrate(expr, var)
543 if let Expr::Symbol(var) = &args[1] {
544 let result = Integrator::integrate(&args[0], var)?;
545 Ok(Simplifier::simplify(&result))
546 } else {
547 Err(CasError::Type(
548 "integrate requires variable as second argument".to_string(),
549 ))
550 }
551 }
552 ("integrate", 4, _) | ("integral", 4, _) => {
553 // integrate(expr, var, lower, upper)
554 if let Expr::Symbol(var) = &args[1] {
555 let result = Integrator::integrate_definite(&args[0], var, &args[2], &args[3])?;
556 let simplified = Simplifier::simplify(&result);
557 if Self::is_unevaluated_definite_integral(&simplified) {
558 match self.eval_definite_integral_numeric(&args[0], var, &args[2], &args[3])
559 {
560 Ok(value) => Ok(value),
561 Err(_) => Ok(simplified),
562 }
563 } else {
564 self.eval(&simplified)
565 }
566 } else {
567 Err(CasError::Type(
568 "integrate requires variable as second argument".to_string(),
569 ))
570 }
571 }
572
573 ("solve", 2, _) => {
574 // solve([eq1, eq2], [x, y]) — system of equations
575 if let (Expr::Vector(equations), Expr::Vector(vars)) = (&args[0], &args[1]) {
576 let var_symbols: std::result::Result<Vec<Symbol>, _> = vars
577 .iter()
578 .map(|v| match v {
579 Expr::Symbol(s) => Ok(s.clone()),
580 _ => Err(CasError::Type(
581 "solve system requires variables as second argument".to_string(),
582 )),
583 })
584 .collect();
585 let var_symbols = var_symbols?;
586 let solutions = Solver::solve_system(equations, &var_symbols)?;
587 // Return as vector of equations: [x = val1, y = val2]
588 let result: Vec<Expr> = solutions
589 .into_iter()
590 .map(|(var, val)| Expr::Equation(
591 Box::new(Expr::Symbol(var)),
592 Box::new(val),
593 ))
594 .collect();
595 Ok(Expr::Vector(result))
596 }
597 // solve(expr, var) or solve(equation, var)
598 else if let Expr::Symbol(var) = &args[1] {
599 let solutions = Solver::solve(&args[0], var)?;
600 if solutions.len() == 1 {
601 Ok(solutions.into_iter().next().unwrap())
602 } else {
603 Ok(Expr::Vector(solutions))
604 }
605 } else {
606 Err(CasError::Type(
607 "solve requires variable as second argument".to_string(),
608 ))
609 }
610 }
611
612 ("sum", 4, _) => {
613 if let Expr::Symbol(var) = &args[1] {
614 self.eval_sum(&args[0], var, &args[2], &args[3])
615 } else {
616 Err(CasError::Type(
617 "sum requires variable as second argument".to_string(),
618 ))
619 }
620 }
621
622 ("product", 4, _) | ("prod", 4, _) => {
623 if let Expr::Symbol(var) = &args[1] {
624 self.eval_product(&args[0], var, &args[2], &args[3])
625 } else {
626 Err(CasError::Type(
627 "product requires variable as second argument".to_string(),
628 ))
629 }
630 }
631
632 ("simplify", 1, _) => Ok(Simplifier::simplify(&args[0])),
633
634 ("expand", 1, _) => Ok(Simplifier::simplify(&Simplifier::expand(&args[0]))),
635
636 ("factor", 1, _) => {
637 let var = Self::infer_primary_var(&args[0])
638 .unwrap_or_else(|| Symbol::new("x"));
639 Ok(Factorer::factor(&args[0], &var))
640 }
641
642 ("factor", 2, _) => {
643 if let Expr::Symbol(var) = &args[1] {
644 Ok(Factorer::factor(&args[0], var))
645 } else {
646 Err(CasError::Type(
647 "factor requires variable as second argument".to_string(),
648 ))
649 }
650 }
651
652 ("substitute", 3, _) | ("subs", 3, _) => {
653 // substitute(expr, var, replacement)
654 if let Expr::Symbol(var) = &args[1] {
655 Ok(Simplifier::simplify(&Simplifier::substitute(
656 &args[0], var, &args[2],
657 )))
658 } else {
659 Err(CasError::Type(
660 "substitute requires variable as second argument".to_string(),
661 ))
662 }
663 }
664
665 ("limit", 3, _) | ("lim", 3, _) => {
666 // limit(expr, var, point)
667 if let Expr::Symbol(var) = &args[1] {
668 let result = Limits::limit(&args[0], var, &args[2], None)?;
669 self.eval(&Simplifier::simplify(&result))
670 } else {
671 Err(CasError::Type(
672 "limit requires variable as second argument".to_string(),
673 ))
674 }
675 }
676 ("limit", 4, _) | ("lim", 4, _) => {
677 // limit(expr, var, point, direction) where direction is "left", "right", "-", "+"
678 if let Expr::Symbol(var) = &args[1] {
679 let direction = match &args[3] {
680 Expr::Symbol(s) if s.as_str() == "left" || s.as_str() == "-" => {
681 Some(crate::expr::LimitDirection::Left)
682 }
683 Expr::Symbol(s) if s.as_str() == "right" || s.as_str() == "+" => {
684 Some(crate::expr::LimitDirection::Right)
685 }
686 _ => None,
687 };
688 let result = Limits::limit(&args[0], var, &args[2], direction)?;
689 self.eval(&Simplifier::simplify(&result))
690 } else {
691 Err(CasError::Type(
692 "limit requires variable as second argument".to_string(),
693 ))
694 }
695 }
696
697 // Matrix operations
698 ("det", 1, _) | ("determinant", 1, _) => {
699 if let Expr::Matrix(rows) = &args[0] {
700 self.matrix_det(rows)
701 } else {
702 Err(CasError::Type("det requires a matrix argument".to_string()))
703 }
704 }
705
706 ("inv", 1, _) | ("inverse", 1, _) => {
707 if let Expr::Matrix(rows) = &args[0] {
708 self.matrix_inv(rows)
709 } else {
710 Err(CasError::Type("inv requires a matrix argument".to_string()))
711 }
712 }
713
714 ("transpose", 1, _) | ("T", 1, _) => {
715 if let Expr::Matrix(rows) = &args[0] {
716 self.matrix_transpose(rows)
717 } else {
718 Err(CasError::Type(
719 "transpose requires a matrix argument".to_string(),
720 ))
721 }
722 }
723
724 ("trace", 1, _) | ("tr", 1, _) => {
725 if let Expr::Matrix(rows) = &args[0] {
726 self.matrix_trace(rows)
727 } else {
728 Err(CasError::Type(
729 "trace requires a matrix argument".to_string(),
730 ))
731 }
732 }
733
734 ("matmul", 2, _) => {
735 if let (Expr::Matrix(a), Expr::Matrix(b)) = (&args[0], &args[1]) {
736 self.matrix_mul(a, b)
737 } else {
738 Err(CasError::Type(
739 "matmul requires two matrix arguments".to_string(),
740 ))
741 }
742 }
743
744 ("identity", 1, Some(n)) => {
745 let n = n as usize;
746 if n == 0 || n > 100 {
747 return Err(CasError::EvaluationError(
748 "identity matrix size must be 1-100".to_string(),
749 ));
750 }
751 let mut rows = Vec::with_capacity(n);
752 for i in 0..n {
753 let mut row = vec![Expr::Integer(0); n];
754 row[i] = Expr::Integer(1);
755 rows.push(row);
756 }
757 Ok(Expr::Matrix(rows))
758 }
759
760 _ => {
761 if self.exact_mode {
762 Ok(Expr::func(name, args.to_vec()))
763 } else {
764 Err(CasError::UndefinedFunction(name.to_string()))
765 }
766 }
767 }
768 }
769
770 /// Compute matrix determinant
771 fn matrix_det(&self, rows: &[Vec<Expr>]) -> Result<Expr> {
772 let n = rows.len();
773 if n == 0 {
774 return Err(CasError::EvaluationError("empty matrix".to_string()));
775 }
776 if rows.iter().any(|r| r.len() != n) {
777 return Err(CasError::EvaluationError(
778 "det requires square matrix".to_string(),
779 ));
780 }
781
782 // Convert to f64 for numerical computation
783 let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(n);
784 for row in rows {
785 let mut num_row = Vec::with_capacity(n);
786 for elem in row {
787 num_row.push(self.to_f64(elem)?);
788 }
789 matrix.push(num_row);
790 }
791
792 // LU decomposition for determinant
793 let det = self.det_lu(&mut matrix, n);
794
795 // Return as integer if close to integer
796 if det.fract().abs() < 1e-10 {
797 Ok(Expr::Integer(det.round() as i64))
798 } else {
799 Ok(Expr::Float(det))
800 }
801 }
802
803 /// LU decomposition determinant
804 fn det_lu(&self, matrix: &mut [Vec<f64>], n: usize) -> f64 {
805 let mut det = 1.0;
806
807 for col in 0..n {
808 // Find pivot
809 let mut max_row = col;
810 for row in (col + 1)..n {
811 if matrix[row][col].abs() > matrix[max_row][col].abs() {
812 max_row = row;
813 }
814 }
815
816 if max_row != col {
817 matrix.swap(col, max_row);
818 det = -det; // Swap changes sign
819 }
820
821 if matrix[col][col].abs() < 1e-15 {
822 return 0.0; // Singular matrix
823 }
824
825 det *= matrix[col][col];
826
827 for row in (col + 1)..n {
828 let factor = matrix[row][col] / matrix[col][col];
829 for j in col..n {
830 matrix[row][j] -= factor * matrix[col][j];
831 }
832 }
833 }
834
835 det
836 }
837
838 /// Compute matrix inverse using Gauss-Jordan elimination
839 fn matrix_inv(&self, rows: &[Vec<Expr>]) -> Result<Expr> {
840 let n = rows.len();
841 if n == 0 {
842 return Err(CasError::EvaluationError("empty matrix".to_string()));
843 }
844 if rows.iter().any(|r| r.len() != n) {
845 return Err(CasError::EvaluationError(
846 "inv requires square matrix".to_string(),
847 ));
848 }
849
850 // Convert to f64
851 let mut aug: Vec<Vec<f64>> = Vec::with_capacity(n);
852 for (i, row) in rows.iter().enumerate() {
853 let mut aug_row = Vec::with_capacity(2 * n);
854 for elem in row {
855 aug_row.push(self.to_f64(elem)?);
856 }
857 // Append identity matrix
858 for j in 0..n {
859 aug_row.push(if i == j { 1.0 } else { 0.0 });
860 }
861 aug.push(aug_row);
862 }
863
864 // Gauss-Jordan elimination
865 for col in 0..n {
866 // Find pivot
867 let mut max_row = col;
868 for row in (col + 1)..n {
869 if aug[row][col].abs() > aug[max_row][col].abs() {
870 max_row = row;
871 }
872 }
873 aug.swap(col, max_row);
874
875 if aug[col][col].abs() < 1e-15 {
876 return Err(CasError::EvaluationError("matrix is singular".to_string()));
877 }
878
879 // Scale pivot row
880 let pivot = aug[col][col];
881 for j in 0..(2 * n) {
882 aug[col][j] /= pivot;
883 }
884
885 // Eliminate column
886 for row in 0..n {
887 if row != col {
888 let factor = aug[row][col];
889 for j in 0..(2 * n) {
890 aug[row][j] -= factor * aug[col][j];
891 }
892 }
893 }
894 }
895
896 // Extract inverse from right half
897 let mut result = Vec::with_capacity(n);
898 for row in &aug {
899 let mut result_row = Vec::with_capacity(n);
900 for j in n..(2 * n) {
901 let val = row[j];
902 if val.fract().abs() < 1e-10 {
903 result_row.push(Expr::Integer(val.round() as i64));
904 } else {
905 result_row.push(Expr::Float(val));
906 }
907 }
908 result.push(result_row);
909 }
910
911 Ok(Expr::Matrix(result))
912 }
913
914 /// Transpose a matrix
915 fn matrix_transpose(&self, rows: &[Vec<Expr>]) -> Result<Expr> {
916 if rows.is_empty() {
917 return Ok(Expr::Matrix(vec![]));
918 }
919 let n_rows = rows.len();
920 let n_cols = rows[0].len();
921
922 let mut result = Vec::with_capacity(n_cols);
923 for j in 0..n_cols {
924 let mut new_row = Vec::with_capacity(n_rows);
925 for row in rows {
926 if j < row.len() {
927 new_row.push(row[j].clone());
928 } else {
929 new_row.push(Expr::Integer(0));
930 }
931 }
932 result.push(new_row);
933 }
934
935 Ok(Expr::Matrix(result))
936 }
937
938 /// Compute trace (sum of diagonal)
939 fn matrix_trace(&self, rows: &[Vec<Expr>]) -> Result<Expr> {
940 let n = rows.len();
941 if n == 0 {
942 return Err(CasError::EvaluationError("empty matrix".to_string()));
943 }
944 if rows.iter().any(|r| r.len() != n) {
945 return Err(CasError::EvaluationError(
946 "trace requires square matrix".to_string(),
947 ));
948 }
949
950 let mut sum = 0.0;
951 for i in 0..n {
952 sum += self.to_f64(&rows[i][i])?;
953 }
954
955 if sum.fract().abs() < 1e-10 {
956 Ok(Expr::Integer(sum.round() as i64))
957 } else {
958 Ok(Expr::Float(sum))
959 }
960 }
961
962 /// Matrix multiplication
963 fn matrix_mul(&self, a: &[Vec<Expr>], b: &[Vec<Expr>]) -> Result<Expr> {
964 if a.is_empty() || b.is_empty() {
965 return Err(CasError::EvaluationError("empty matrix".to_string()));
966 }
967
968 let m = a.len();
969 let n = a[0].len();
970 let p = b[0].len();
971
972 if b.len() != n {
973 return Err(CasError::EvaluationError(format!(
974 "matrix dimensions don't match for multiplication: {}x{} * {}x{}",
975 m,
976 n,
977 b.len(),
978 p
979 )));
980 }
981
982 // Convert to f64
983 let a_num: Vec<Vec<f64>> = a
984 .iter()
985 .map(|row| {
986 row.iter()
987 .map(|e| self.to_f64(e))
988 .collect::<Result<Vec<_>>>()
989 })
990 .collect::<Result<Vec<_>>>()?;
991 let b_num: Vec<Vec<f64>> = b
992 .iter()
993 .map(|row| {
994 row.iter()
995 .map(|e| self.to_f64(e))
996 .collect::<Result<Vec<_>>>()
997 })
998 .collect::<Result<Vec<_>>>()?;
999
1000 let mut result = Vec::with_capacity(m);
1001 for i in 0..m {
1002 let mut row = Vec::with_capacity(p);
1003 for j in 0..p {
1004 let mut sum = 0.0;
1005 for k in 0..n {
1006 sum += a_num[i][k] * b_num[k][j];
1007 }
1008 if sum.fract().abs() < 1e-10 {
1009 row.push(Expr::Integer(sum.round() as i64));
1010 } else {
1011 row.push(Expr::Float(sum));
1012 }
1013 }
1014 result.push(row);
1015 }
1016
1017 Ok(Expr::Matrix(result))
1018 }
1019
1020 fn eval_integer_bound(&self, bound: &Expr) -> Result<i64> {
1021 let value = self.eval(bound)?;
1022 match value {
1023 Expr::Integer(n) => Ok(n),
1024 Expr::Rational(r) if r.den == 1 => Ok(r.num),
1025 Expr::Float(x) => {
1026 if !x.is_finite() {
1027 return Err(CasError::Type("bound must be a finite number".to_string()));
1028 }
1029 let rounded = x.round();
1030 if (x - rounded).abs() < 1e-10
1031 && rounded >= i64::MIN as f64
1032 && rounded <= i64::MAX as f64
1033 {
1034 Ok(rounded as i64)
1035 } else {
1036 Err(CasError::Type(format!(
1037 "bound must be an integer, got {}",
1038 Expr::Float(x)
1039 )))
1040 }
1041 }
1042 _ => Err(CasError::Type(format!(
1043 "bound must be an integer, got {value}"
1044 ))),
1045 }
1046 }
1047
1048 fn is_unevaluated_definite_integral(expr: &Expr) -> bool {
1049 matches!(
1050 expr,
1051 Expr::Integral {
1052 lower: Some(_),
1053 upper: Some(_),
1054 ..
1055 }
1056 )
1057 }
1058
1059 fn eval_definite_integral_numeric(
1060 &self,
1061 body: &Expr,
1062 var: &Symbol,
1063 lower: &Expr,
1064 upper: &Expr,
1065 ) -> Result<Expr> {
1066 let lower_eval = self.eval(lower)?;
1067 let upper_eval = self.eval(upper)?;
1068 let mut a = self.to_f64(&lower_eval)?;
1069 let mut b = self.to_f64(&upper_eval)?;
1070
1071 if !a.is_finite() || !b.is_finite() {
1072 return Err(CasError::Type(
1073 "integral bounds must be finite numbers".to_string(),
1074 ));
1075 }
1076
1077 if (a - b).abs() < 1e-14 {
1078 return Ok(Expr::Integer(0));
1079 }
1080
1081 let mut sign = 1.0;
1082 if a > b {
1083 std::mem::swap(&mut a, &mut b);
1084 sign = -1.0;
1085 }
1086
1087 let mut slices = 64usize;
1088 let mut estimate = self.simpson_integral(body, var, a, b, slices)?;
1089 for _ in 0..8 {
1090 slices *= 2;
1091 let refined = self.simpson_integral(body, var, a, b, slices)?;
1092 if (refined - estimate).abs() <= 1e-10 * (1.0 + refined.abs()) {
1093 return Ok(Self::float_to_expr(sign * refined));
1094 }
1095 estimate = refined;
1096 }
1097
1098 Ok(Self::float_to_expr(sign * estimate))
1099 }
1100
1101 fn simpson_integral(
1102 &self,
1103 body: &Expr,
1104 var: &Symbol,
1105 a: f64,
1106 b: f64,
1107 slices: usize,
1108 ) -> Result<f64> {
1109 if slices == 0 || slices % 2 != 0 {
1110 return Err(CasError::EvaluationError(
1111 "simpson integration requires a positive even number of slices".to_string(),
1112 ));
1113 }
1114
1115 let h = (b - a) / slices as f64;
1116 let mut acc =
1117 self.eval_integrand_point(body, var, a)? + self.eval_integrand_point(body, var, b)?;
1118
1119 for i in 1..slices {
1120 let x = a + i as f64 * h;
1121 let fx = self.eval_integrand_point(body, var, x)?;
1122 if i % 2 == 0 {
1123 acc += 2.0 * fx;
1124 } else {
1125 acc += 4.0 * fx;
1126 }
1127 }
1128
1129 Ok(acc * h / 3.0)
1130 }
1131
1132 fn eval_integrand_point(&self, body: &Expr, var: &Symbol, x: f64) -> Result<f64> {
1133 let substituted = Simplifier::substitute(body, var, &Expr::Float(x));
1134 let evaluated = self.eval(&substituted)?;
1135 let value = self.to_f64(&evaluated)?;
1136 if value.is_finite() {
1137 Ok(value)
1138 } else {
1139 Err(CasError::EvaluationError(format!(
1140 "integrand is not finite at {x}"
1141 )))
1142 }
1143 }
1144
1145 fn float_to_expr(x: f64) -> Expr {
1146 if !x.is_finite() {
1147 return Expr::Float(x);
1148 }
1149 let rounded = x.round();
1150 if (x - rounded).abs() < 1e-10 && rounded >= i64::MIN as f64 && rounded <= i64::MAX as f64 {
1151 Expr::Integer(rounded as i64)
1152 } else {
1153 Expr::Float(x)
1154 }
1155 }
1156
1157 fn eval_sum(
1158 &self,
1159 body: &Expr,
1160 var: &crate::expr::Symbol,
1161 lower: &Expr,
1162 upper: &Expr,
1163 ) -> Result<Expr> {
1164 let inferred_var;
1165 let active_var = if body.contains_var(var) {
1166 var
1167 } else {
1168 inferred_var = Self::infer_iteration_var(body);
1169 inferred_var.as_ref().unwrap_or(var)
1170 };
1171
1172 match self.eval_discrete_series(body, active_var, lower, upper, false) {
1173 Ok(value) => Ok(value),
1174 Err(_err) => {
1175 if let Some(symbolic) = Self::symbolic_sum(body, active_var, lower, upper) {
1176 return Ok(Simplifier::simplify(&symbolic));
1177 }
1178
1179 Ok(Expr::Sum {
1180 expr: Box::new(body.clone()),
1181 var: active_var.clone(),
1182 lower: Box::new(lower.clone()),
1183 upper: Box::new(upper.clone()),
1184 })
1185 }
1186 }
1187 }
1188
1189 fn eval_product(
1190 &self,
1191 body: &Expr,
1192 var: &crate::expr::Symbol,
1193 lower: &Expr,
1194 upper: &Expr,
1195 ) -> Result<Expr> {
1196 let inferred_var;
1197 let active_var = if body.contains_var(var) {
1198 var
1199 } else {
1200 inferred_var = Self::infer_iteration_var(body);
1201 inferred_var.as_ref().unwrap_or(var)
1202 };
1203
1204 match self.eval_discrete_series(body, active_var, lower, upper, true) {
1205 Ok(value) => Ok(value),
1206 Err(_err) => {
1207 if let Some(symbolic) = Self::symbolic_product(body, active_var, lower, upper) {
1208 return Ok(Simplifier::simplify(&symbolic));
1209 }
1210
1211 Ok(Expr::Product {
1212 expr: Box::new(body.clone()),
1213 var: active_var.clone(),
1214 lower: Box::new(lower.clone()),
1215 upper: Box::new(upper.clone()),
1216 })
1217 }
1218 }
1219 }
1220
1221 fn infer_iteration_var(body: &Expr) -> Option<crate::expr::Symbol> {
1222 let mut vars = BTreeSet::new();
1223 Self::collect_symbols(body, &mut vars);
1224 if vars.len() == 1 {
1225 vars.into_iter().next().map(crate::expr::Symbol::new)
1226 } else {
1227 None
1228 }
1229 }
1230
1231 fn infer_primary_var(expr: &Expr) -> Option<crate::expr::Symbol> {
1232 let mut vars = BTreeSet::new();
1233 Self::collect_symbols(expr, &mut vars);
1234 vars.remove("pi");
1235 vars.remove("e");
1236 if vars.len() == 1 {
1237 vars.into_iter().next().map(crate::expr::Symbol::new)
1238 } else {
1239 None
1240 }
1241 }
1242
1243 fn collect_symbols(expr: &Expr, out: &mut BTreeSet<String>) {
1244 match expr {
1245 Expr::Symbol(s) => {
1246 out.insert(s.as_str().to_string());
1247 }
1248 Expr::Neg(inner) => Self::collect_symbols(inner, out),
1249 Expr::Add(terms) | Expr::Mul(terms) | Expr::Vector(terms) => {
1250 for term in terms {
1251 Self::collect_symbols(term, out);
1252 }
1253 }
1254 Expr::Pow(base, exp) => {
1255 Self::collect_symbols(base, out);
1256 Self::collect_symbols(exp, out);
1257 }
1258 Expr::Func(_, args) => {
1259 for arg in args {
1260 Self::collect_symbols(arg, out);
1261 }
1262 }
1263 Expr::Derivative { expr, .. } => Self::collect_symbols(expr, out),
1264 Expr::Integral {
1265 expr, lower, upper, ..
1266 } => {
1267 Self::collect_symbols(expr, out);
1268 if let Some(lo) = lower {
1269 Self::collect_symbols(lo, out);
1270 }
1271 if let Some(hi) = upper {
1272 Self::collect_symbols(hi, out);
1273 }
1274 }
1275 Expr::Limit { expr, point, .. } => {
1276 Self::collect_symbols(expr, out);
1277 Self::collect_symbols(point, out);
1278 }
1279 Expr::Sum {
1280 expr, lower, upper, ..
1281 }
1282 | Expr::Product {
1283 expr, lower, upper, ..
1284 } => {
1285 Self::collect_symbols(expr, out);
1286 Self::collect_symbols(lower, out);
1287 Self::collect_symbols(upper, out);
1288 }
1289 Expr::Equation(lhs, rhs) => {
1290 Self::collect_symbols(lhs, out);
1291 Self::collect_symbols(rhs, out);
1292 }
1293 Expr::Inequality { lhs, rhs, .. } => {
1294 Self::collect_symbols(lhs, out);
1295 Self::collect_symbols(rhs, out);
1296 }
1297 Expr::Matrix(rows) => {
1298 for row in rows {
1299 for elem in row {
1300 Self::collect_symbols(elem, out);
1301 }
1302 }
1303 }
1304 Expr::Integer(_)
1305 | Expr::Rational(_)
1306 | Expr::Float(_)
1307 | Expr::Complex(_, _)
1308 | Expr::Undefined
1309 | Expr::Infinity(_) => {}
1310 }
1311 }
1312
1313 fn symbolic_sum(
1314 body: &Expr,
1315 var: &crate::expr::Symbol,
1316 lower: &Expr,
1317 upper: &Expr,
1318 ) -> Option<Expr> {
1319 // Finite count for symbolic bounds: upper - lower + 1.
1320 let count = Expr::add(vec![
1321 Expr::sub(upper.clone(), lower.clone()),
1322 Expr::Integer(1),
1323 ]);
1324
1325 if !body.contains_var(var) {
1326 return Some(Expr::mul(vec![body.clone(), count]));
1327 }
1328
1329 match body {
1330 Expr::Symbol(s) if s == var => Some(Self::sum_linear(lower, upper)),
1331
1332 Expr::Pow(base, exp) if matches!(base.as_ref(), Expr::Symbol(s) if s == var) => {
1333 match exp.as_ref() {
1334 Expr::Integer(0) => Some(count),
1335 Expr::Integer(1) => Some(Self::sum_linear(lower, upper)),
1336 Expr::Integer(2) => Some(Self::sum_square(lower, upper)),
1337 Expr::Integer(3) => Some(Self::sum_cube(lower, upper)),
1338 _ => None,
1339 }
1340 }
1341
1342 Expr::Neg(inner) => Self::symbolic_sum(inner, var, lower, upper).map(Expr::neg),
1343
1344 Expr::Add(terms) => {
1345 let mut summed_terms = Vec::with_capacity(terms.len());
1346 for term in terms {
1347 summed_terms.push(Self::symbolic_sum(term, var, lower, upper)?);
1348 }
1349 Some(Expr::add(summed_terms))
1350 }
1351
1352 Expr::Mul(factors) => {
1353 let (mut independent, dependent): (Vec<Expr>, Vec<Expr>) =
1354 factors.iter().cloned().partition(|f| !f.contains_var(var));
1355
1356 if dependent.is_empty() {
1357 independent.push(count);
1358 return Some(Expr::mul(independent));
1359 }
1360
1361 if dependent.len() == 1 {
1362 let dep_sum = Self::symbolic_sum(&dependent[0], var, lower, upper)?;
1363 independent.push(dep_sum);
1364 return Some(Expr::mul(independent));
1365 }
1366
1367 None
1368 }
1369
1370 _ => None,
1371 }
1372 }
1373
1374 fn sum_linear(lower: &Expr, upper: &Expr) -> Expr {
1375 fn triangular(x: Expr) -> Expr {
1376 Expr::mul(vec![
1377 x.clone(),
1378 Expr::add(vec![x, Expr::Integer(1)]),
1379 Expr::Rational(Rational::new(1, 2)),
1380 ])
1381 }
1382
1383 let lo_minus_one = Expr::sub(lower.clone(), Expr::Integer(1));
1384 Expr::sub(triangular(upper.clone()), triangular(lo_minus_one))
1385 }
1386
1387 fn sum_square(lower: &Expr, upper: &Expr) -> Expr {
1388 fn square_sum_prefix(x: Expr) -> Expr {
1389 let two_x_plus_one = Expr::add(vec![
1390 Expr::mul(vec![Expr::Integer(2), x.clone()]),
1391 Expr::Integer(1),
1392 ]);
1393 Expr::mul(vec![
1394 x.clone(),
1395 Expr::add(vec![x, Expr::Integer(1)]),
1396 two_x_plus_one,
1397 Expr::Rational(Rational::new(1, 6)),
1398 ])
1399 }
1400
1401 let lo_minus_one = Expr::sub(lower.clone(), Expr::Integer(1));
1402 Expr::sub(
1403 square_sum_prefix(upper.clone()),
1404 square_sum_prefix(lo_minus_one),
1405 )
1406 }
1407
1408 fn sum_cube(lower: &Expr, upper: &Expr) -> Expr {
1409 fn cube_sum_prefix(x: Expr) -> Expr {
1410 let tri = Expr::mul(vec![
1411 x.clone(),
1412 Expr::add(vec![x, Expr::Integer(1)]),
1413 Expr::Rational(Rational::new(1, 2)),
1414 ]);
1415 Expr::pow(tri, Expr::Integer(2))
1416 }
1417
1418 let lo_minus_one = Expr::sub(lower.clone(), Expr::Integer(1));
1419 Expr::sub(
1420 cube_sum_prefix(upper.clone()),
1421 cube_sum_prefix(lo_minus_one),
1422 )
1423 }
1424
1425 fn symbolic_product(
1426 body: &Expr,
1427 var: &crate::expr::Symbol,
1428 lower: &Expr,
1429 upper: &Expr,
1430 ) -> Option<Expr> {
1431 let count = Expr::add(vec![
1432 Expr::sub(upper.clone(), lower.clone()),
1433 Expr::Integer(1),
1434 ]);
1435
1436 if !body.contains_var(var) {
1437 return Some(Expr::pow(body.clone(), count));
1438 }
1439
1440 if let Some(factored) = Self::factor_simple_product_body(body, var) {
1441 return Self::symbolic_product(&factored, var, lower, upper);
1442 }
1443
1444 match body {
1445 Expr::Symbol(s) if s == var => Some(Self::factorial_range(lower, upper)),
1446 Expr::Neg(inner) => {
1447 let inner_product = Self::symbolic_product(inner, var, lower, upper)?;
1448 Some(Expr::mul(vec![
1449 Expr::pow(Expr::Integer(-1), count),
1450 inner_product,
1451 ]))
1452 }
1453 Expr::Mul(factors) => {
1454 let (independent, dependent): (Vec<Expr>, Vec<Expr>) =
1455 factors.iter().cloned().partition(|f| !f.contains_var(var));
1456
1457 let mut parts = Vec::new();
1458 if !independent.is_empty() {
1459 parts.push(Expr::pow(Expr::mul(independent), count.clone()));
1460 }
1461
1462 for dep in dependent {
1463 parts.push(Self::symbolic_product(&dep, var, lower, upper)?);
1464 }
1465
1466 Some(Expr::mul(parts))
1467 }
1468 Expr::Add(_) => Self::product_linear_term(body, var, lower, upper),
1469 Expr::Pow(base, exp) => {
1470 let Expr::Integer(power) = exp.as_ref() else {
1471 return None;
1472 };
1473
1474 if matches!(base.as_ref(), Expr::Symbol(s) if s == var) {
1475 return Some(Expr::pow(
1476 Self::factorial_range(lower, upper),
1477 Expr::Integer(*power),
1478 ));
1479 }
1480
1481 if let Some(linear_base_product) =
1482 Self::product_linear_term(base, var, lower, upper)
1483 {
1484 return Some(Expr::pow(linear_base_product, Expr::Integer(*power)));
1485 }
1486
1487 None
1488 }
1489 _ => None,
1490 }
1491 }
1492
1493 fn factorial_range(lower: &Expr, upper: &Expr) -> Expr {
1494 Self::factorial_range_shifted(lower, upper, 0).unwrap_or_else(|| {
1495 let upper_fact = Expr::func("factorial", vec![upper.clone()]);
1496 if matches!(lower, Expr::Integer(1)) {
1497 upper_fact
1498 } else {
1499 let lower_minus_one = Expr::sub(lower.clone(), Expr::Integer(1));
1500 let lower_fact = Expr::func("factorial", vec![lower_minus_one]);
1501 Expr::mul(vec![upper_fact, Expr::pow(lower_fact, Expr::Integer(-1))])
1502 }
1503 })
1504 }
1505
1506 fn product_linear_term(
1507 expr: &Expr,
1508 var: &crate::expr::Symbol,
1509 lower: &Expr,
1510 upper: &Expr,
1511 ) -> Option<Expr> {
1512 let shift = Self::extract_linear_shift(expr, var)?;
1513 Self::factorial_range_shifted(lower, upper, shift)
1514 }
1515
1516 fn extract_linear_shift(expr: &Expr, var: &crate::expr::Symbol) -> Option<i64> {
1517 match expr {
1518 Expr::Symbol(s) if s == var => Some(0),
1519 Expr::Add(terms) => {
1520 let mut saw_var = false;
1521 let mut shift = 0_i64;
1522 for term in terms {
1523 match term {
1524 Expr::Symbol(s) if s == var => {
1525 if saw_var {
1526 return None;
1527 }
1528 saw_var = true;
1529 }
1530 Expr::Integer(n) => {
1531 shift = shift.checked_add(*n)?;
1532 }
1533 Expr::Neg(inner) => {
1534 if let Expr::Integer(n) = inner.as_ref() {
1535 shift = shift.checked_sub(*n)?;
1536 } else {
1537 return None;
1538 }
1539 }
1540 _ => return None,
1541 }
1542 }
1543 if saw_var { Some(shift) } else { None }
1544 }
1545 _ => None,
1546 }
1547 }
1548
1549 fn factorial_range_shifted(lower: &Expr, upper: &Expr, shift: i64) -> Option<Expr> {
1550 let shift_minus_one = shift.checked_sub(1)?;
1551 if let Expr::Integer(lo) = lower {
1552 if lo.checked_add(shift_minus_one)? < 0 {
1553 return None;
1554 }
1555 }
1556 if let Expr::Integer(hi) = upper {
1557 if hi.checked_add(shift)? < 0 {
1558 return None;
1559 }
1560 }
1561
1562 let shifted_upper =
1563 Simplifier::simplify(&Expr::add(vec![upper.clone(), Expr::Integer(shift)]));
1564 let shifted_lower_minus_one = Simplifier::simplify(&Expr::add(vec![
1565 lower.clone(),
1566 Expr::Integer(shift_minus_one),
1567 ]));
1568
1569 let upper_fact = Expr::func("factorial", vec![shifted_upper]);
1570 if matches!(shifted_lower_minus_one, Expr::Integer(0)) {
1571 Some(upper_fact)
1572 } else {
1573 let lower_fact = Expr::func("factorial", vec![shifted_lower_minus_one]);
1574 Some(Expr::mul(vec![
1575 upper_fact,
1576 Expr::pow(lower_fact, Expr::Integer(-1)),
1577 ]))
1578 }
1579 }
1580
1581 fn factor_simple_product_body(body: &Expr, var: &crate::expr::Symbol) -> Option<Expr> {
1582 let Expr::Add(terms) = body else {
1583 return None;
1584 };
1585 if terms.len() != 2 {
1586 return None;
1587 }
1588
1589 let mut has_square = false;
1590 let mut linear_sign = 0_i64;
1591 for term in terms {
1592 match term {
1593 Expr::Pow(base, exp)
1594 if matches!(base.as_ref(), Expr::Symbol(s) if s == var)
1595 && matches!(exp.as_ref(), Expr::Integer(2)) =>
1596 {
1597 has_square = true;
1598 }
1599 Expr::Symbol(s) if s == var => linear_sign += 1,
1600 Expr::Neg(inner) if matches!(inner.as_ref(), Expr::Symbol(s) if s == var) => {
1601 linear_sign -= 1;
1602 }
1603 _ => return None,
1604 }
1605 }
1606
1607 if !has_square {
1608 return None;
1609 }
1610
1611 let var_expr = Expr::Symbol(var.clone());
1612 match linear_sign {
1613 1 => Some(Expr::mul(vec![
1614 var_expr.clone(),
1615 Expr::add(vec![var_expr, Expr::Integer(1)]),
1616 ])),
1617 -1 => Some(Expr::mul(vec![
1618 var_expr.clone(),
1619 Expr::add(vec![var_expr, Expr::Integer(-1)]),
1620 ])),
1621 _ => None,
1622 }
1623 }
1624
1625 fn eval_discrete_series(
1626 &self,
1627 body: &Expr,
1628 var: &crate::expr::Symbol,
1629 lower: &Expr,
1630 upper: &Expr,
1631 is_product: bool,
1632 ) -> Result<Expr> {
1633 const MAX_TERMS: i64 = 100_000;
1634
1635 let lo = self.eval_integer_bound(lower)?;
1636 let hi = self.eval_integer_bound(upper)?;
1637
1638 if lo > hi {
1639 return Ok(if is_product {
1640 Expr::Integer(1)
1641 } else {
1642 Expr::Integer(0)
1643 });
1644 }
1645
1646 let count = hi.saturating_sub(lo).saturating_add(1);
1647 if count > MAX_TERMS {
1648 let kind = if is_product { "product" } else { "sum" };
1649 return Err(CasError::EvaluationError(format!(
1650 "{kind} has too many terms ({count}); limit is {MAX_TERMS}"
1651 )));
1652 }
1653
1654 let mut acc = if is_product {
1655 Expr::Integer(1)
1656 } else {
1657 Expr::Integer(0)
1658 };
1659
1660 for n in lo..=hi {
1661 let substituted = Simplifier::substitute(body, var, &Expr::Integer(n));
1662 let term = self.eval(&substituted)?;
1663 acc = if is_product {
1664 self.multiply(&acc, &term)?
1665 } else {
1666 self.add(&acc, &term)?
1667 };
1668 }
1669
1670 Ok(Simplifier::simplify(&acc))
1671 }
1672 }
1673
1674 // Helper functions
1675
/// Returns `n!` as the product `1·2·…·n`, with `0! = 1`.
///
/// NOTE(review): the result only fits in `u64` for `n <= 20`. Larger inputs overflow
/// (panic in debug builds, wrap in release). Call sites presumably guard against this;
/// confirm.
fn factorial(n: u64) -> u64 {
    let mut acc: u64 = 1;
    for k in 2..=n {
        acc *= k;
    }
    acc
}
1679
/// Greatest common divisor of `a` and `b` by Euclid's algorithm.
///
/// The result is non-negative, and `gcd(0, 0) == 0`. `i64::MIN` has no positive `i64`
/// counterpart, so it overflows in `abs`.
fn gcd(a: i64, b: i64) -> i64 {
    let (mut a, mut b) = (a.abs(), b.abs());
    while b != 0 {
        (a, b) = (b, a % b);
    }
    a
}

/// Least common multiple of `a` and `b`. The result is non-negative.
///
/// Returns 0 when either argument is 0, which also avoids dividing by `gcd(0, 0) == 0`.
/// Dividing by the gcd before multiplying keeps the intermediate no larger than the
/// result, so any pair whose lcm fits in `i64` is computed without overflow.
fn lcm(a: i64, b: i64) -> i64 {
    if a == 0 || b == 0 {
        return 0;
    }
    (a / gcd(a, b) * b).abs()
}
1688
/// Gamma function via the Lanczos approximation (g = 7, nine coefficients).
///
/// For `x < 0.5` it uses the reflection formula `Γ(x) = π / (sin(πx) · Γ(1 − x))`.
/// At non-positive integers `sin(πx)` is (numerically) zero, so the result is not a
/// meaningful finite value there.
fn gamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x < 0.5 {
        // Reflect into the half-plane where the Lanczos series is accurate.
        return PI / (PI * x).sin() / gamma(1.0 - x);
    }

    let z = x - 1.0;
    let series = LANCZOS_COEFFS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS_COEFFS[0], |acc, (offset, &coeff)| {
            acc + coeff / (z + (offset + 1) as f64)
        });
    let t = z + LANCZOS_G + 0.5;

    (2.0 * PI).sqrt() * t.powf(z + 0.5) * (-t).exp() * series
}
1716
#[cfg(test)]
mod tests {
    //! Evaluator tests: arithmetic, functions and constants, numeric and closed-form
    //! sums/products, definite integrals, and `solve`.

    use super::*;
    use crate::parser::parse;

    /// Parses `input` and evaluates it with a default [`Evaluator`].
    fn eval(input: &str) -> Result<Expr> {
        let expr = parse(input)?;
        Evaluator::new().eval(&expr)
    }

    /// Converts an integer, rational, or float expression to `f64`; panics on anything else.
    fn expr_to_f64(expr: &Expr) -> f64 {
        match expr {
            Expr::Integer(n) => *n as f64,
            Expr::Rational(r) => r.to_f64(),
            Expr::Float(x) => *x,
            other => panic!("expected numeric expression, got {other}"),
        }
    }

    /// Evaluates `input` and returns its numeric result as `f64`.
    fn eval_to_f64(input: &str) -> f64 {
        expr_to_f64(&eval(input).unwrap())
    }

    /// Evaluates `input` (expected to stay symbolic in `n`), then re-evaluates the result
    /// with `n` bound to `value` and returns the numeric answer.
    fn eval_symbolic_at_n(input: &str, value: i64) -> f64 {
        let symbolic = eval(input).unwrap();
        let mut evaluator = Evaluator::new();
        evaluator.set_var("n", Expr::Integer(value));
        expr_to_f64(&evaluator.eval(&symbolic).unwrap())
    }

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (±{tol}), got {actual}"
        );
    }

    #[test]
    fn test_arithmetic() {
        assert_eq!(eval("2 + 3").unwrap(), Expr::Integer(5));
        assert_eq!(eval("10 - 4").unwrap(), Expr::Integer(6));
        assert_eq!(eval("3 * 4").unwrap(), Expr::Integer(12));
        assert_eq!(eval("2^10").unwrap(), Expr::Integer(1024));
    }

    #[test]
    fn test_float() {
        assert_close(eval_to_f64("3.14 * 2"), 6.28, 1e-10);
    }

    #[test]
    fn test_functions() {
        assert_close(eval_to_f64("sin(0)"), 0.0, 1e-10);
        assert_close(eval_to_f64("cos(0)"), 1.0, 1e-10);
        assert_close(eval_to_f64("sqrt(4)"), 2.0, 1e-10);
        assert_close(eval_to_f64("ln(e)"), 1.0, 1e-10);
    }

    #[test]
    fn test_constants() {
        assert_close(eval_to_f64("pi"), PI, 1e-10);
        assert_close(eval_to_f64("e"), E, 1e-10);
    }

    #[test]
    fn test_factorial() {
        assert_eq!(eval("5!").unwrap(), Expr::Integer(120));
    }

    #[test]
    fn test_complex_expr() {
        assert_close(eval_to_f64("2 * sin(pi/2) + 1"), 3.0, 1e-10);
    }

    #[test]
    fn test_variables() {
        let expr = parse("x + 1").unwrap();
        let mut evaluator = Evaluator::new();
        evaluator.set_var("x", Expr::integer(5));
        let result = evaluator.eval(&expr).unwrap();
        assert_eq!(result, Expr::Integer(6));
    }

    #[test]
    fn test_sum_evaluation() {
        assert_eq!(eval("sum(n, n, 1, 5)").unwrap(), Expr::Integer(15));
    }

    #[test]
    fn test_product_evaluation() {
        assert_eq!(eval("product(n, n, 1, 4)").unwrap(), Expr::Integer(24));
    }

    #[test]
    fn test_definite_integral_numeric_fallback_for_non_elementary_antiderivative() {
        assert_close(
            eval_to_f64("integrate(x^(x+2), x, 0, 1)"),
            0.2781176122,
            1e-8,
        );
    }

    #[test]
    fn test_definite_integral_with_symbolic_bounds_stays_symbolic() {
        let symbolic = eval("integrate(x^(x+2), x, 0, n)").unwrap();
        assert!(matches!(
            symbolic,
            Expr::Integral {
                lower: Some(_),
                upper: Some(_),
                ..
            }
        ));
    }

    #[test]
    fn test_empty_discrete_range() {
        assert_eq!(eval("sum(n, n, 5, 1)").unwrap(), Expr::Integer(0));
        assert_eq!(eval("product(n, n, 5, 1)").unwrap(), Expr::Integer(1));
    }

    #[test]
    fn test_symbolic_solve_function() {
        let result = eval("solve(x^2 - 4, x)").unwrap();
        let Expr::Vector(solutions) = result else {
            panic!("expected vector of solutions");
        };
        assert_eq!(solutions.len(), 2);
        let values: Vec<f64> = solutions.iter().map(expr_to_f64).collect();
        assert!(values.iter().any(|v| (v - 2.0).abs() < 1e-10));
        assert!(values.iter().any(|v| (v + 2.0).abs() < 1e-10));
    }

    #[test]
    fn test_symbolic_sum_linear_closed_form() {
        assert_close(eval_symbolic_at_n("sum(k, k, 1, n)", 10), 55.0, 1e-10);
    }

    #[test]
    fn test_symbolic_sum_quadratic_closed_form() {
        assert_close(eval_symbolic_at_n("sum(k^2 + k, k, 1, n)", 5), 70.0, 1e-10);
    }

    #[test]
    fn test_symbolic_sum_fallback_for_unreduced_form() {
        let symbolic = eval("sum(1/(k-1), k, 1, n)").unwrap();
        assert!(matches!(symbolic, Expr::Sum { .. }));
    }

    #[test]
    fn test_symbolic_product_constant_closed_form() {
        assert_close(eval_symbolic_at_n("product(2, k, 1, n)", 5), 32.0, 1e-10);
    }

    #[test]
    fn test_symbolic_product_factorial_closed_form() {
        assert_close(eval_symbolic_at_n("product(k, k, 1, n)", 5), 120.0, 1e-10);
    }

    #[test]
    fn test_symbolic_product_linear_shift_closed_form() {
        assert_close(eval_symbolic_at_n("product(k+1, k, 1, n)", 5), 720.0, 1e-10);
    }

    #[test]
    fn test_symbolic_product_mul_decomposition_closed_form() {
        assert_close(eval_symbolic_at_n("product(2*k, k, 1, n)", 5), 3840.0, 1e-10);
    }

    #[test]
    fn test_symbolic_product_simple_quadratic_factoring() {
        assert_close(
            eval_symbolic_at_n("product(k^2 + k, k, 1, n)", 4),
            2880.0,
            1e-10,
        );
    }

    #[test]
    fn test_symbolic_product_telescoping_ratio() {
        assert_close(eval_symbolic_at_n("product(k/(k-1), k, 2, n)", 6), 6.0, 1e-10);
    }

    #[test]
    fn test_symbolic_product_fallback_for_unreduced_form() {
        let symbolic = eval("product(k^2 + 2, k, 1, n)").unwrap();
        assert!(matches!(symbolic, Expr::Product { .. }));
    }

    #[test]
    fn test_sum_infers_iteration_var_when_template_var_unused() {
        let expr = Expr::Sum {
            expr: Box::new(Expr::Symbol(Symbol::new("n"))),
            var: Symbol::new("i"),
            lower: Box::new(Expr::Integer(1)),
            upper: Box::new(Expr::Integer(5)),
        };
        let result = Evaluator::new().eval(&expr).unwrap();
        assert_eq!(result, Expr::Integer(15));
    }

    #[test]
    fn test_product_infers_iteration_var_when_template_var_unused() {
        let expr = Expr::Product {
            expr: Box::new(Expr::Symbol(Symbol::new("n"))),
            var: Symbol::new("i"),
            lower: Box::new(Expr::Integer(1)),
            upper: Box::new(Expr::Integer(5)),
        };
        let result = Evaluator::new().eval(&expr).unwrap();
        assert_eq!(result, Expr::Integer(120));
    }

    #[test]
    fn test_exact_multiply_rational_one_identity() {
        // Use the public builder rather than poking the private field.
        let evaluator = Evaluator::new().with_exact_mode(true);

        let expr = Expr::mul(vec![Expr::Rational(Rational::new(1, 1)), Expr::symbol("n")]);
        let result = evaluator.eval(&expr).unwrap();

        assert_eq!(result, Expr::symbol("n"));
    }
}
1980