add garcalc-cas with expression AST, recursive descent parser, evaluator
- SHA
e8ccf3f0011c089493c3a94c15df7ccb7b4cabfc- Parents
-
5861c28 - Tree
b07d58a
e8ccf3f
e8ccf3f0011c089493c3a94c15df7ccb7b4cabfc5861c28
b07d58a| Status | File | + | - |
|---|---|---|---|
| A |
garcalc-cas/Cargo.toml
|
13 | 0 |
| A |
garcalc-cas/src/error.rs
|
31 | 0 |
| A |
garcalc-cas/src/eval.rs
|
578 | 0 |
| A |
garcalc-cas/src/expr.rs
|
497 | 0 |
| A |
garcalc-cas/src/lib.rs
|
18 | 0 |
| A |
garcalc-cas/src/parser.rs
|
678 | 0 |
garcalc-cas/Cargo.tomladded@@ -0,0 +1,13 @@ | ||
| 1 | +[package] | |
| 2 | +name = "garcalc-cas" | |
| 3 | +version.workspace = true | |
| 4 | +edition.workspace = true | |
| 5 | +authors.workspace = true | |
| 6 | +license.workspace = true | |
| 7 | +description = "Computer algebra system for garcalc" | |
| 8 | + | |
| 9 | +[dependencies] | |
| 10 | +serde = { workspace = true } | |
| 11 | +thiserror = { workspace = true } | |
| 12 | + | |
| 13 | +[dev-dependencies] | |
garcalc-cas/src/error.rsadded@@ -0,0 +1,31 @@ | ||
| 1 | +use thiserror::Error; | |
| 2 | + | |
| 3 | +/// CAS error types | |
| 4 | +#[derive(Debug, Error)] | |
| 5 | +pub enum CasError { | |
| 6 | + #[error("Parse error at position {position}: {message}")] | |
| 7 | + Parse { position: usize, message: String }, | |
| 8 | + | |
| 9 | + #[error("Undefined variable: {0}")] | |
| 10 | + UndefinedVariable(String), | |
| 11 | + | |
| 12 | + #[error("Undefined function: {0}")] | |
| 13 | + UndefinedFunction(String), | |
| 14 | + | |
| 15 | + #[error("Type error: {0}")] | |
| 16 | + Type(String), | |
| 17 | + | |
| 18 | + #[error("Division by zero")] | |
| 19 | + DivisionByZero, | |
| 20 | + | |
| 21 | + #[error("Domain error: {0}")] | |
| 22 | + Domain(String), | |
| 23 | + | |
| 24 | + #[error("Not implemented: {0}")] | |
| 25 | + NotImplemented(String), | |
| 26 | + | |
| 27 | + #[error("Invalid argument: {0}")] | |
| 28 | + InvalidArgument(String), | |
| 29 | +} | |
| 30 | + | |
| 31 | +pub type Result<T> = std::result::Result<T, CasError>; | |
garcalc-cas/src/eval.rsadded@@ -0,0 +1,578 @@ | ||
| 1 | +//! Expression evaluation | |
| 2 | +//! | |
| 3 | +//! Supports both numeric and symbolic evaluation modes. | |
| 4 | + | |
| 5 | +use std::collections::HashMap; | |
| 6 | +use std::f64::consts::{E, PI}; | |
| 7 | + | |
| 8 | +use crate::error::{CasError, Result}; | |
| 9 | +use crate::expr::{Expr, Rational, Sign}; | |
| 10 | + | |
| 11 | +/// Variable bindings for evaluation | |
| 12 | +pub type Environment = HashMap<String, Expr>; | |
| 13 | + | |
| 14 | +/// Expression evaluator | |
| 15 | +pub struct Evaluator { | |
| 16 | + env: Environment, | |
| 17 | + /// If true, try to keep results symbolic when possible | |
| 18 | + exact_mode: bool, | |
| 19 | + /// Angle mode for trig functions | |
| 20 | + angle_mode: AngleMode, | |
| 21 | +} | |
| 22 | + | |
| 23 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] | |
| 24 | +pub enum AngleMode { | |
| 25 | + Radians, | |
| 26 | + Degrees, | |
| 27 | +} | |
| 28 | + | |
| 29 | +impl Default for AngleMode { | |
| 30 | + fn default() -> Self { | |
| 31 | + Self::Radians | |
| 32 | + } | |
| 33 | +} | |
| 34 | + | |
| 35 | +impl Default for Evaluator { | |
| 36 | + fn default() -> Self { | |
| 37 | + Self::new() | |
| 38 | + } | |
| 39 | +} | |
| 40 | + | |
| 41 | +impl Evaluator { | |
| 42 | + pub fn new() -> Self { | |
| 43 | + Self { | |
| 44 | + env: Environment::new(), | |
| 45 | + exact_mode: false, | |
| 46 | + angle_mode: AngleMode::Radians, | |
| 47 | + } | |
| 48 | + } | |
| 49 | + | |
| 50 | + pub fn with_exact_mode(mut self, exact: bool) -> Self { | |
| 51 | + self.exact_mode = exact; | |
| 52 | + self | |
| 53 | + } | |
| 54 | + | |
| 55 | + pub fn with_angle_mode(mut self, mode: AngleMode) -> Self { | |
| 56 | + self.angle_mode = mode; | |
| 57 | + self | |
| 58 | + } | |
| 59 | + | |
| 60 | + pub fn set_var(&mut self, name: impl Into<String>, value: Expr) { | |
| 61 | + self.env.insert(name.into(), value); | |
| 62 | + } | |
| 63 | + | |
| 64 | + pub fn get_var(&self, name: &str) -> Option<&Expr> { | |
| 65 | + self.env.get(name) | |
| 66 | + } | |
| 67 | + | |
| 68 | + pub fn clear_vars(&mut self) { | |
| 69 | + self.env.clear(); | |
| 70 | + } | |
| 71 | + | |
| 72 | + /// Evaluate an expression to a numeric result | |
| 73 | + pub fn eval(&self, expr: &Expr) -> Result<Expr> { | |
| 74 | + match expr { | |
| 75 | + Expr::Integer(n) => Ok(Expr::Integer(*n)), | |
| 76 | + Expr::Rational(r) => Ok(Expr::Rational(*r)), | |
| 77 | + Expr::Float(x) => Ok(Expr::Float(*x)), | |
| 78 | + Expr::Complex(re, im) => Ok(Expr::Complex(*re, *im)), | |
| 79 | + | |
| 80 | + Expr::Symbol(sym) => { | |
| 81 | + // Check for constants | |
| 82 | + match sym.as_str() { | |
| 83 | + "pi" => Ok(Expr::Float(PI)), | |
| 84 | + "e" => Ok(Expr::Float(E)), | |
| 85 | + _ => { | |
| 86 | + // Check environment | |
| 87 | + if let Some(value) = self.env.get(sym.as_str()) { | |
| 88 | + self.eval(value) | |
| 89 | + } else if self.exact_mode { | |
| 90 | + // In exact mode, keep undefined symbols | |
| 91 | + Ok(expr.clone()) | |
| 92 | + } else { | |
| 93 | + Err(CasError::UndefinedVariable(sym.to_string())) | |
| 94 | + } | |
| 95 | + } | |
| 96 | + } | |
| 97 | + } | |
| 98 | + | |
| 99 | + Expr::Neg(e) => { | |
| 100 | + let val = self.eval(e)?; | |
| 101 | + self.negate(&val) | |
| 102 | + } | |
| 103 | + | |
| 104 | + Expr::Add(terms) => { | |
| 105 | + let mut sum = Expr::integer(0); | |
| 106 | + for term in terms { | |
| 107 | + let val = self.eval(term)?; | |
| 108 | + sum = self.add(&sum, &val)?; | |
| 109 | + } | |
| 110 | + Ok(sum) | |
| 111 | + } | |
| 112 | + | |
| 113 | + Expr::Mul(factors) => { | |
| 114 | + let mut product = Expr::integer(1); | |
| 115 | + for factor in factors { | |
| 116 | + let val = self.eval(factor)?; | |
| 117 | + product = self.multiply(&product, &val)?; | |
| 118 | + } | |
| 119 | + Ok(product) | |
| 120 | + } | |
| 121 | + | |
| 122 | + Expr::Pow(base, exp) => { | |
| 123 | + let base_val = self.eval(base)?; | |
| 124 | + let exp_val = self.eval(exp)?; | |
| 125 | + self.power(&base_val, &exp_val) | |
| 126 | + } | |
| 127 | + | |
| 128 | + Expr::Func(name, args) => { | |
| 129 | + let evaluated_args: Result<Vec<_>> = | |
| 130 | + args.iter().map(|a| self.eval(a)).collect(); | |
| 131 | + self.call_function(name, &evaluated_args?) | |
| 132 | + } | |
| 133 | + | |
| 134 | + Expr::Equation(lhs, rhs) => { | |
| 135 | + let lhs_val = self.eval(lhs)?; | |
| 136 | + let rhs_val = self.eval(rhs)?; | |
| 137 | + Ok(Expr::Equation(Box::new(lhs_val), Box::new(rhs_val))) | |
| 138 | + } | |
| 139 | + | |
| 140 | + Expr::Vector(elems) => { | |
| 141 | + let evaluated: Result<Vec<_>> = elems.iter().map(|e| self.eval(e)).collect(); | |
| 142 | + Ok(Expr::Vector(evaluated?)) | |
| 143 | + } | |
| 144 | + | |
| 145 | + Expr::Matrix(rows) => { | |
| 146 | + let evaluated: Result<Vec<Vec<_>>> = rows | |
| 147 | + .iter() | |
| 148 | + .map(|row| row.iter().map(|e| self.eval(e)).collect()) | |
| 149 | + .collect(); | |
| 150 | + Ok(Expr::Matrix(evaluated?)) | |
| 151 | + } | |
| 152 | + | |
| 153 | + // Symbolic operations - keep as-is for now (will implement in later sprints) | |
| 154 | + Expr::Derivative { .. } | |
| 155 | + | Expr::Integral { .. } | |
| 156 | + | Expr::Limit { .. } | |
| 157 | + | Expr::Sum { .. } | |
| 158 | + | Expr::Product { .. } => { | |
| 159 | + if self.exact_mode { | |
| 160 | + Ok(expr.clone()) | |
| 161 | + } else { | |
| 162 | + Err(CasError::NotImplemented( | |
| 163 | + "symbolic operations in numeric mode".to_string(), | |
| 164 | + )) | |
| 165 | + } | |
| 166 | + } | |
| 167 | + | |
| 168 | + Expr::Inequality { lhs, op, rhs } => { | |
| 169 | + let lhs_val = self.eval(lhs)?; | |
| 170 | + let rhs_val = self.eval(rhs)?; | |
| 171 | + Ok(Expr::Inequality { | |
| 172 | + lhs: Box::new(lhs_val), | |
| 173 | + op: *op, | |
| 174 | + rhs: Box::new(rhs_val), | |
| 175 | + }) | |
| 176 | + } | |
| 177 | + | |
| 178 | + Expr::Undefined => Ok(Expr::Undefined), | |
| 179 | + Expr::Infinity(sign) => Ok(Expr::Infinity(*sign)), | |
| 180 | + } | |
| 181 | + } | |
| 182 | + | |
| 183 | + fn to_f64(&self, expr: &Expr) -> Result<f64> { | |
| 184 | + match expr { | |
| 185 | + Expr::Integer(n) => Ok(*n as f64), | |
| 186 | + Expr::Rational(r) => Ok(r.to_f64()), | |
| 187 | + Expr::Float(x) => Ok(*x), | |
| 188 | + _ => Err(CasError::Type(format!("expected number, got {expr}"))), | |
| 189 | + } | |
| 190 | + } | |
| 191 | + | |
| 192 | + fn negate(&self, expr: &Expr) -> Result<Expr> { | |
| 193 | + match expr { | |
| 194 | + Expr::Integer(n) => Ok(Expr::Integer(-n)), | |
| 195 | + Expr::Rational(r) => Ok(Expr::Rational(Rational::new(-r.num, r.den))), | |
| 196 | + Expr::Float(x) => Ok(Expr::Float(-x)), | |
| 197 | + Expr::Complex(re, im) => Ok(Expr::Complex(-re, -im)), | |
| 198 | + Expr::Infinity(Sign::Positive) => Ok(Expr::Infinity(Sign::Negative)), | |
| 199 | + Expr::Infinity(Sign::Negative) => Ok(Expr::Infinity(Sign::Positive)), | |
| 200 | + _ => Ok(Expr::neg(expr.clone())), | |
| 201 | + } | |
| 202 | + } | |
| 203 | + | |
| 204 | + fn add(&self, a: &Expr, b: &Expr) -> Result<Expr> { | |
| 205 | + match (a, b) { | |
| 206 | + (Expr::Integer(x), Expr::Integer(y)) => Ok(Expr::Integer(x + y)), | |
| 207 | + (Expr::Float(x), Expr::Float(y)) => Ok(Expr::Float(x + y)), | |
| 208 | + (Expr::Integer(x), Expr::Float(y)) | (Expr::Float(y), Expr::Integer(x)) => { | |
| 209 | + Ok(Expr::Float(*x as f64 + y)) | |
| 210 | + } | |
| 211 | + (Expr::Complex(r1, i1), Expr::Complex(r2, i2)) => Ok(Expr::Complex(r1 + r2, i1 + i2)), | |
| 212 | + (Expr::Complex(re, im), Expr::Float(x)) | (Expr::Float(x), Expr::Complex(re, im)) => { | |
| 213 | + Ok(Expr::Complex(re + x, *im)) | |
| 214 | + } | |
| 215 | + (Expr::Complex(re, im), Expr::Integer(n)) | |
| 216 | + | (Expr::Integer(n), Expr::Complex(re, im)) => { | |
| 217 | + Ok(Expr::Complex(re + *n as f64, *im)) | |
| 218 | + } | |
| 219 | + (Expr::Rational(r1), Expr::Rational(r2)) => { | |
| 220 | + let num = r1.num * r2.den + r2.num * r1.den; | |
| 221 | + let den = r1.den * r2.den; | |
| 222 | + Ok(Expr::Rational(Rational::new(num, den))) | |
| 223 | + } | |
| 224 | + (Expr::Rational(r), Expr::Integer(n)) | (Expr::Integer(n), Expr::Rational(r)) => { | |
| 225 | + let num = r.num + n * r.den; | |
| 226 | + Ok(Expr::Rational(Rational::new(num, r.den))) | |
| 227 | + } | |
| 228 | + _ => { | |
| 229 | + // Try converting to floats | |
| 230 | + if let (Ok(x), Ok(y)) = (self.to_f64(a), self.to_f64(b)) { | |
| 231 | + Ok(Expr::Float(x + y)) | |
| 232 | + } else if self.exact_mode { | |
| 233 | + Ok(Expr::add(vec![a.clone(), b.clone()])) | |
| 234 | + } else { | |
| 235 | + Err(CasError::Type(format!("cannot add {a} and {b}"))) | |
| 236 | + } | |
| 237 | + } | |
| 238 | + } | |
| 239 | + } | |
| 240 | + | |
| 241 | + fn multiply(&self, a: &Expr, b: &Expr) -> Result<Expr> { | |
| 242 | + match (a, b) { | |
| 243 | + (Expr::Integer(x), Expr::Integer(y)) => Ok(Expr::Integer(x * y)), | |
| 244 | + (Expr::Float(x), Expr::Float(y)) => Ok(Expr::Float(x * y)), | |
| 245 | + (Expr::Integer(x), Expr::Float(y)) | (Expr::Float(y), Expr::Integer(x)) => { | |
| 246 | + Ok(Expr::Float(*x as f64 * y)) | |
| 247 | + } | |
| 248 | + (Expr::Complex(r1, i1), Expr::Complex(r2, i2)) => { | |
| 249 | + // (a+bi)(c+di) = (ac-bd) + (ad+bc)i | |
| 250 | + Ok(Expr::Complex(r1 * r2 - i1 * i2, r1 * i2 + i1 * r2)) | |
| 251 | + } | |
| 252 | + (Expr::Complex(re, im), Expr::Float(x)) | (Expr::Float(x), Expr::Complex(re, im)) => { | |
| 253 | + Ok(Expr::Complex(re * x, im * x)) | |
| 254 | + } | |
| 255 | + (Expr::Complex(re, im), Expr::Integer(n)) | |
| 256 | + | (Expr::Integer(n), Expr::Complex(re, im)) => { | |
| 257 | + let n = *n as f64; | |
| 258 | + Ok(Expr::Complex(re * n, im * n)) | |
| 259 | + } | |
| 260 | + (Expr::Rational(r1), Expr::Rational(r2)) => { | |
| 261 | + Ok(Expr::Rational(Rational::new(r1.num * r2.num, r1.den * r2.den))) | |
| 262 | + } | |
| 263 | + (Expr::Rational(r), Expr::Integer(n)) | (Expr::Integer(n), Expr::Rational(r)) => { | |
| 264 | + Ok(Expr::Rational(Rational::new(r.num * n, r.den))) | |
| 265 | + } | |
| 266 | + _ => { | |
| 267 | + if let (Ok(x), Ok(y)) = (self.to_f64(a), self.to_f64(b)) { | |
| 268 | + Ok(Expr::Float(x * y)) | |
| 269 | + } else if self.exact_mode { | |
| 270 | + Ok(Expr::mul(vec![a.clone(), b.clone()])) | |
| 271 | + } else { | |
| 272 | + Err(CasError::Type(format!("cannot multiply {a} and {b}"))) | |
| 273 | + } | |
| 274 | + } | |
| 275 | + } | |
| 276 | + } | |
| 277 | + | |
| 278 | + fn power(&self, base: &Expr, exp: &Expr) -> Result<Expr> { | |
| 279 | + // Special cases | |
| 280 | + if exp.is_zero() { | |
| 281 | + return Ok(Expr::integer(1)); | |
| 282 | + } | |
| 283 | + if exp.is_one() { | |
| 284 | + return Ok(base.clone()); | |
| 285 | + } | |
| 286 | + if base.is_zero() { | |
| 287 | + return Ok(Expr::integer(0)); | |
| 288 | + } | |
| 289 | + if base.is_one() { | |
| 290 | + return Ok(Expr::integer(1)); | |
| 291 | + } | |
| 292 | + | |
| 293 | + match (base, exp) { | |
| 294 | + (Expr::Integer(b), Expr::Integer(e)) => { | |
| 295 | + if *e >= 0 { | |
| 296 | + Ok(Expr::Integer(b.pow(*e as u32))) | |
| 297 | + } else { | |
| 298 | + // Negative exponent -> rational or float | |
| 299 | + let denom = b.pow((-e) as u32); | |
| 300 | + if self.exact_mode { | |
| 301 | + Ok(Expr::Rational(Rational::new(1, denom))) | |
| 302 | + } else { | |
| 303 | + Ok(Expr::Float(1.0 / denom as f64)) | |
| 304 | + } | |
| 305 | + } | |
| 306 | + } | |
| 307 | + (Expr::Float(b), Expr::Integer(e)) => Ok(Expr::Float(b.powi(*e as i32))), | |
| 308 | + (Expr::Float(b), Expr::Float(e)) => Ok(Expr::Float(b.powf(*e))), | |
| 309 | + (Expr::Integer(b), Expr::Float(e)) => Ok(Expr::Float((*b as f64).powf(*e))), | |
| 310 | + _ => { | |
| 311 | + if let (Ok(b), Ok(e)) = (self.to_f64(base), self.to_f64(exp)) { | |
| 312 | + Ok(Expr::Float(b.powf(e))) | |
| 313 | + } else if self.exact_mode { | |
| 314 | + Ok(Expr::pow(base.clone(), exp.clone())) | |
| 315 | + } else { | |
| 316 | + Err(CasError::Type(format!("cannot compute {base}^{exp}"))) | |
| 317 | + } | |
| 318 | + } | |
| 319 | + } | |
| 320 | + } | |
| 321 | + | |
| 322 | + fn call_function(&self, name: &str, args: &[Expr]) -> Result<Expr> { | |
| 323 | + // Get numeric argument if single-arg function | |
| 324 | + let arg = if args.len() == 1 { | |
| 325 | + self.to_f64(&args[0]).ok() | |
| 326 | + } else { | |
| 327 | + None | |
| 328 | + }; | |
| 329 | + | |
| 330 | + // Angle conversion for trig functions | |
| 331 | + let angle = |x: f64| match self.angle_mode { | |
| 332 | + AngleMode::Radians => x, | |
| 333 | + AngleMode::Degrees => x.to_radians(), | |
| 334 | + }; | |
| 335 | + | |
| 336 | + let from_angle = |x: f64| match self.angle_mode { | |
| 337 | + AngleMode::Radians => x, | |
| 338 | + AngleMode::Degrees => x.to_degrees(), | |
| 339 | + }; | |
| 340 | + | |
| 341 | + match (name, args.len(), arg) { | |
| 342 | + // Trigonometric | |
| 343 | + ("sin", 1, Some(x)) => Ok(Expr::Float(angle(x).sin())), | |
| 344 | + ("cos", 1, Some(x)) => Ok(Expr::Float(angle(x).cos())), | |
| 345 | + ("tan", 1, Some(x)) => Ok(Expr::Float(angle(x).tan())), | |
| 346 | + ("asin", 1, Some(x)) => Ok(Expr::Float(from_angle(x.asin()))), | |
| 347 | + ("acos", 1, Some(x)) => Ok(Expr::Float(from_angle(x.acos()))), | |
| 348 | + ("atan", 1, Some(x)) => Ok(Expr::Float(from_angle(x.atan()))), | |
| 349 | + ("sinh", 1, Some(x)) => Ok(Expr::Float(x.sinh())), | |
| 350 | + ("cosh", 1, Some(x)) => Ok(Expr::Float(x.cosh())), | |
| 351 | + ("tanh", 1, Some(x)) => Ok(Expr::Float(x.tanh())), | |
| 352 | + ("asinh", 1, Some(x)) => Ok(Expr::Float(x.asinh())), | |
| 353 | + ("acosh", 1, Some(x)) => Ok(Expr::Float(x.acosh())), | |
| 354 | + ("atanh", 1, Some(x)) => Ok(Expr::Float(x.atanh())), | |
| 355 | + | |
| 356 | + // Exponential/logarithmic | |
| 357 | + ("exp", 1, Some(x)) => Ok(Expr::Float(x.exp())), | |
| 358 | + ("ln", 1, Some(x)) => Ok(Expr::Float(x.ln())), | |
| 359 | + ("log", 1, Some(x)) => Ok(Expr::Float(x.log10())), | |
| 360 | + ("log10", 1, Some(x)) => Ok(Expr::Float(x.log10())), | |
| 361 | + ("log2", 1, Some(x)) => Ok(Expr::Float(x.log2())), | |
| 362 | + | |
| 363 | + // Roots | |
| 364 | + ("sqrt", 1, Some(x)) => { | |
| 365 | + if x >= 0.0 { | |
| 366 | + Ok(Expr::Float(x.sqrt())) | |
| 367 | + } else { | |
| 368 | + Ok(Expr::Complex(0.0, (-x).sqrt())) | |
| 369 | + } | |
| 370 | + } | |
| 371 | + ("cbrt", 1, Some(x)) => Ok(Expr::Float(x.cbrt())), | |
| 372 | + | |
| 373 | + // Other | |
| 374 | + ("abs", 1, Some(x)) => Ok(Expr::Float(x.abs())), | |
| 375 | + ("floor", 1, Some(x)) => Ok(Expr::Integer(x.floor() as i64)), | |
| 376 | + ("ceil", 1, Some(x)) => Ok(Expr::Integer(x.ceil() as i64)), | |
| 377 | + ("round", 1, Some(x)) => Ok(Expr::Integer(x.round() as i64)), | |
| 378 | + ("sign", 1, Some(x)) => Ok(Expr::Integer(if x > 0.0 { | |
| 379 | + 1 | |
| 380 | + } else if x < 0.0 { | |
| 381 | + -1 | |
| 382 | + } else { | |
| 383 | + 0 | |
| 384 | + })), | |
| 385 | + | |
| 386 | + // Factorial | |
| 387 | + ("factorial", 1, _) => { | |
| 388 | + if let Expr::Integer(n) = &args[0] { | |
| 389 | + if *n < 0 { | |
| 390 | + Err(CasError::Domain("factorial of negative number".to_string())) | |
| 391 | + } else if *n > 20 { | |
| 392 | + // Use Stirling's approximation for large n | |
| 393 | + Ok(Expr::Float(gamma(*n as f64 + 1.0))) | |
| 394 | + } else { | |
| 395 | + Ok(Expr::Integer(factorial(*n as u64) as i64)) | |
| 396 | + } | |
| 397 | + } else if let Some(x) = arg { | |
| 398 | + Ok(Expr::Float(gamma(x + 1.0))) | |
| 399 | + } else { | |
| 400 | + Err(CasError::Type("factorial requires a number".to_string())) | |
| 401 | + } | |
| 402 | + } | |
| 403 | + | |
| 404 | + // Two-argument functions | |
| 405 | + ("atan2", 2, _) => { | |
| 406 | + let y = self.to_f64(&args[0])?; | |
| 407 | + let x = self.to_f64(&args[1])?; | |
| 408 | + Ok(Expr::Float(from_angle(y.atan2(x)))) | |
| 409 | + } | |
| 410 | + ("pow", 2, _) => self.power(&args[0], &args[1]), | |
| 411 | + ("mod", 2, _) => { | |
| 412 | + let a = self.to_f64(&args[0])?; | |
| 413 | + let b = self.to_f64(&args[1])?; | |
| 414 | + Ok(Expr::Float(a % b)) | |
| 415 | + } | |
| 416 | + ("min", _, _) if !args.is_empty() => { | |
| 417 | + let values: Result<Vec<f64>> = args.iter().map(|a| self.to_f64(a)).collect(); | |
| 418 | + let min = values? | |
| 419 | + .into_iter() | |
| 420 | + .fold(f64::INFINITY, |a, b| a.min(b)); | |
| 421 | + Ok(Expr::Float(min)) | |
| 422 | + } | |
| 423 | + ("max", _, _) if !args.is_empty() => { | |
| 424 | + let values: Result<Vec<f64>> = args.iter().map(|a| self.to_f64(a)).collect(); | |
| 425 | + let max = values? | |
| 426 | + .into_iter() | |
| 427 | + .fold(f64::NEG_INFINITY, |a, b| a.max(b)); | |
| 428 | + Ok(Expr::Float(max)) | |
| 429 | + } | |
| 430 | + ("gcd", 2, _) => { | |
| 431 | + if let (Expr::Integer(a), Expr::Integer(b)) = (&args[0], &args[1]) { | |
| 432 | + Ok(Expr::Integer(gcd(*a, *b))) | |
| 433 | + } else { | |
| 434 | + Err(CasError::Type("gcd requires integers".to_string())) | |
| 435 | + } | |
| 436 | + } | |
| 437 | + ("lcm", 2, _) => { | |
| 438 | + if let (Expr::Integer(a), Expr::Integer(b)) = (&args[0], &args[1]) { | |
| 439 | + Ok(Expr::Integer(lcm(*a, *b))) | |
| 440 | + } else { | |
| 441 | + Err(CasError::Type("lcm requires integers".to_string())) | |
| 442 | + } | |
| 443 | + } | |
| 444 | + | |
| 445 | + _ => { | |
| 446 | + if self.exact_mode { | |
| 447 | + Ok(Expr::func(name, args.to_vec())) | |
| 448 | + } else { | |
| 449 | + Err(CasError::UndefinedFunction(name.to_string())) | |
| 450 | + } | |
| 451 | + } | |
| 452 | + } | |
| 453 | + } | |
| 454 | +} | |
| 455 | + | |
| 456 | +// Helper functions | |
| 457 | + | |
| 458 | +fn factorial(n: u64) -> u64 { | |
| 459 | + (1..=n).product() | |
| 460 | +} | |
| 461 | + | |
| 462 | +fn gcd(a: i64, b: i64) -> i64 { | |
| 463 | + let (a, b) = (a.abs(), b.abs()); | |
| 464 | + if b == 0 { | |
| 465 | + a | |
| 466 | + } else { | |
| 467 | + gcd(b, a % b) | |
| 468 | + } | |
| 469 | +} | |
| 470 | + | |
| 471 | +fn lcm(a: i64, b: i64) -> i64 { | |
| 472 | + (a * b).abs() / gcd(a, b) | |
| 473 | +} | |
| 474 | + | |
| 475 | +/// Gamma function approximation (Lanczos) | |
| 476 | +fn gamma(x: f64) -> f64 { | |
| 477 | + if x < 0.5 { | |
| 478 | + PI / (PI * x).sin() / gamma(1.0 - x) | |
| 479 | + } else { | |
| 480 | + let x = x - 1.0; | |
| 481 | + let g = 7.0; | |
| 482 | + let c = [ | |
| 483 | + 0.99999999999980993, | |
| 484 | + 676.5203681218851, | |
| 485 | + -1259.1392167224028, | |
| 486 | + 771.32342877765313, | |
| 487 | + -176.61502916214059, | |
| 488 | + 12.507343278686905, | |
| 489 | + -0.13857109526572012, | |
| 490 | + 9.9843695780195716e-6, | |
| 491 | + 1.5056327351493116e-7, | |
| 492 | + ]; | |
| 493 | + | |
| 494 | + let mut sum = c[0]; | |
| 495 | + for (i, &ci) in c.iter().enumerate().skip(1) { | |
| 496 | + sum += ci / (x + i as f64); | |
| 497 | + } | |
| 498 | + | |
| 499 | + (2.0 * PI).sqrt() * (x + g + 0.5).powf(x + 0.5) * (-(x + g + 0.5)).exp() * sum | |
| 500 | + } | |
| 501 | +} | |
| 502 | + | |
| 503 | +#[cfg(test)] | |
| 504 | +mod tests { | |
| 505 | + use super::*; | |
| 506 | + use crate::parser::parse; | |
| 507 | + | |
| 508 | + fn eval(input: &str) -> Result<Expr> { | |
| 509 | + let expr = parse(input)?; | |
| 510 | + Evaluator::new().eval(&expr) | |
| 511 | + } | |
| 512 | + | |
| 513 | + fn eval_to_f64(input: &str) -> f64 { | |
| 514 | + match eval(input).unwrap() { | |
| 515 | + Expr::Integer(n) => n as f64, | |
| 516 | + Expr::Float(x) => x, | |
| 517 | + other => panic!("expected number, got {other}"), | |
| 518 | + } | |
| 519 | + } | |
| 520 | + | |
| 521 | + #[test] | |
| 522 | + fn test_arithmetic() { | |
| 523 | + assert_eq!(eval("2 + 3").unwrap(), Expr::Integer(5)); | |
| 524 | + assert_eq!(eval("10 - 4").unwrap(), Expr::Integer(6)); | |
| 525 | + assert_eq!(eval("3 * 4").unwrap(), Expr::Integer(12)); | |
| 526 | + assert_eq!(eval("2^10").unwrap(), Expr::Integer(1024)); | |
| 527 | + } | |
| 528 | + | |
| 529 | + #[test] | |
| 530 | + fn test_float() { | |
| 531 | + let result = eval_to_f64("3.14 * 2"); | |
| 532 | + assert!((result - 6.28).abs() < 1e-10); | |
| 533 | + } | |
| 534 | + | |
| 535 | + #[test] | |
| 536 | + fn test_functions() { | |
| 537 | + let result = eval_to_f64("sin(0)"); | |
| 538 | + assert!(result.abs() < 1e-10); | |
| 539 | + | |
| 540 | + let result = eval_to_f64("cos(0)"); | |
| 541 | + assert!((result - 1.0).abs() < 1e-10); | |
| 542 | + | |
| 543 | + let result = eval_to_f64("sqrt(4)"); | |
| 544 | + assert!((result - 2.0).abs() < 1e-10); | |
| 545 | + | |
| 546 | + let result = eval_to_f64("ln(e)"); | |
| 547 | + assert!((result - 1.0).abs() < 1e-10); | |
| 548 | + } | |
| 549 | + | |
| 550 | + #[test] | |
| 551 | + fn test_constants() { | |
| 552 | + let result = eval_to_f64("pi"); | |
| 553 | + assert!((result - PI).abs() < 1e-10); | |
| 554 | + | |
| 555 | + let result = eval_to_f64("e"); | |
| 556 | + assert!((result - E).abs() < 1e-10); | |
| 557 | + } | |
| 558 | + | |
| 559 | + #[test] | |
| 560 | + fn test_factorial() { | |
| 561 | + assert_eq!(eval("5!").unwrap(), Expr::Integer(120)); | |
| 562 | + } | |
| 563 | + | |
| 564 | + #[test] | |
| 565 | + fn test_complex_expr() { | |
| 566 | + let result = eval_to_f64("2 * sin(pi/2) + 1"); | |
| 567 | + assert!((result - 3.0).abs() < 1e-10); | |
| 568 | + } | |
| 569 | + | |
| 570 | + #[test] | |
| 571 | + fn test_variables() { | |
| 572 | + let expr = parse("x + 1").unwrap(); | |
| 573 | + let mut evaluator = Evaluator::new(); | |
| 574 | + evaluator.set_var("x", Expr::integer(5)); | |
| 575 | + let result = evaluator.eval(&expr).unwrap(); | |
| 576 | + assert_eq!(result, Expr::Integer(6)); | |
| 577 | + } | |
| 578 | +} | |
garcalc-cas/src/expr.rsadded@@ -0,0 +1,497 @@ | ||
| 1 | +use serde::{Deserialize, Serialize}; | |
| 2 | +use std::fmt; | |
| 3 | + | |
| 4 | +/// A symbol (variable name) | |
| 5 | +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] | |
| 6 | +pub struct Symbol(pub String); | |
| 7 | + | |
| 8 | +impl Symbol { | |
| 9 | + pub fn new(name: impl Into<String>) -> Self { | |
| 10 | + Self(name.into()) | |
| 11 | + } | |
| 12 | + | |
| 13 | + pub fn as_str(&self) -> &str { | |
| 14 | + &self.0 | |
| 15 | + } | |
| 16 | +} | |
| 17 | + | |
| 18 | +impl fmt::Display for Symbol { | |
| 19 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| 20 | + write!(f, "{}", self.0) | |
| 21 | + } | |
| 22 | +} | |
| 23 | + | |
| 24 | +impl From<&str> for Symbol { | |
| 25 | + fn from(s: &str) -> Self { | |
| 26 | + Self(s.to_string()) | |
| 27 | + } | |
| 28 | +} | |
| 29 | + | |
| 30 | +/// A rational number (numerator/denominator) | |
| 31 | +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] | |
| 32 | +pub struct Rational { | |
| 33 | + pub num: i64, | |
| 34 | + pub den: i64, | |
| 35 | +} | |
| 36 | + | |
| 37 | +impl Rational { | |
| 38 | + pub fn new(num: i64, den: i64) -> Self { | |
| 39 | + let g = gcd(num.abs(), den.abs()); | |
| 40 | + let sign = if den < 0 { -1 } else { 1 }; | |
| 41 | + Self { | |
| 42 | + num: sign * num / g, | |
| 43 | + den: den.abs() / g, | |
| 44 | + } | |
| 45 | + } | |
| 46 | + | |
| 47 | + pub fn integer(n: i64) -> Self { | |
| 48 | + Self { num: n, den: 1 } | |
| 49 | + } | |
| 50 | + | |
| 51 | + pub fn is_integer(&self) -> bool { | |
| 52 | + self.den == 1 | |
| 53 | + } | |
| 54 | + | |
| 55 | + pub fn to_f64(self) -> f64 { | |
| 56 | + self.num as f64 / self.den as f64 | |
| 57 | + } | |
| 58 | +} | |
| 59 | + | |
| 60 | +impl fmt::Display for Rational { | |
| 61 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| 62 | + if self.den == 1 { | |
| 63 | + write!(f, "{}", self.num) | |
| 64 | + } else { | |
| 65 | + write!(f, "{}/{}", self.num, self.den) | |
| 66 | + } | |
| 67 | + } | |
| 68 | +} | |
| 69 | + | |
| 70 | +fn gcd(a: i64, b: i64) -> i64 { | |
| 71 | + if b == 0 { | |
| 72 | + a | |
| 73 | + } else { | |
| 74 | + gcd(b, a % b) | |
| 75 | + } | |
| 76 | +} | |
| 77 | + | |
| 78 | +/// Mathematical expression AST | |
| 79 | +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | |
| 80 | +pub enum Expr { | |
| 81 | + // Atoms | |
| 82 | + /// Integer value | |
| 83 | + Integer(i64), | |
| 84 | + /// Rational number | |
| 85 | + Rational(Rational), | |
| 86 | + /// Floating point | |
| 87 | + Float(f64), | |
| 88 | + /// Complex number (real, imaginary) | |
| 89 | + Complex(f64, f64), | |
| 90 | + /// Symbolic variable | |
| 91 | + Symbol(Symbol), | |
| 92 | + | |
| 93 | + // Arithmetic operations | |
| 94 | + /// Negation: -x | |
| 95 | + Neg(Box<Expr>), | |
| 96 | + /// Addition: a + b + c + ... | |
| 97 | + Add(Vec<Expr>), | |
| 98 | + /// Multiplication: a * b * c * ... | |
| 99 | + Mul(Vec<Expr>), | |
| 100 | + /// Power: base^exponent | |
| 101 | + Pow(Box<Expr>, Box<Expr>), | |
| 102 | + | |
| 103 | + // Function application | |
| 104 | + /// Function call: f(args...) | |
| 105 | + Func(String, Vec<Expr>), | |
| 106 | + | |
| 107 | + // Calculus (symbolic) | |
| 108 | + /// Derivative: d/dx(expr) with order | |
| 109 | + Derivative { | |
| 110 | + expr: Box<Expr>, | |
| 111 | + var: Symbol, | |
| 112 | + order: u32, | |
| 113 | + }, | |
| 114 | + /// Integral: ∫ expr dx, optionally with bounds | |
| 115 | + Integral { | |
| 116 | + expr: Box<Expr>, | |
| 117 | + var: Symbol, | |
| 118 | + lower: Option<Box<Expr>>, | |
| 119 | + upper: Option<Box<Expr>>, | |
| 120 | + }, | |
| 121 | + /// Limit: lim_{x→point} expr | |
| 122 | + Limit { | |
| 123 | + expr: Box<Expr>, | |
| 124 | + var: Symbol, | |
| 125 | + point: Box<Expr>, | |
| 126 | + direction: Option<LimitDirection>, | |
| 127 | + }, | |
| 128 | + /// Summation: Σ_{var=lower}^{upper} expr | |
| 129 | + Sum { | |
| 130 | + expr: Box<Expr>, | |
| 131 | + var: Symbol, | |
| 132 | + lower: Box<Expr>, | |
| 133 | + upper: Box<Expr>, | |
| 134 | + }, | |
| 135 | + /// Product: Π_{var=lower}^{upper} expr | |
| 136 | + Product { | |
| 137 | + expr: Box<Expr>, | |
| 138 | + var: Symbol, | |
| 139 | + lower: Box<Expr>, | |
| 140 | + upper: Box<Expr>, | |
| 141 | + }, | |
| 142 | + | |
| 143 | + // Algebra | |
| 144 | + /// Equation: lhs = rhs | |
| 145 | + Equation(Box<Expr>, Box<Expr>), | |
| 146 | + /// Inequality: lhs < rhs (or >, <=, >=, !=) | |
| 147 | + Inequality { | |
| 148 | + lhs: Box<Expr>, | |
| 149 | + op: InequalityOp, | |
| 150 | + rhs: Box<Expr>, | |
| 151 | + }, | |
| 152 | + | |
| 153 | + // Linear algebra | |
| 154 | + /// Matrix: rows of expressions | |
| 155 | + Matrix(Vec<Vec<Expr>>), | |
| 156 | + /// Vector: list of expressions | |
| 157 | + Vector(Vec<Expr>), | |
| 158 | + | |
| 159 | + // Special values | |
| 160 | + /// Undefined result | |
| 161 | + Undefined, | |
| 162 | + /// Positive or negative infinity | |
| 163 | + Infinity(Sign), | |
| 164 | +} | |
| 165 | + | |
| 166 | +/// Direction for limits | |
| 167 | +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] | |
| 168 | +pub enum LimitDirection { | |
| 169 | + Left, | |
| 170 | + Right, | |
| 171 | +} | |
| 172 | + | |
| 173 | +/// Sign for infinity | |
| 174 | +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] | |
| 175 | +pub enum Sign { | |
| 176 | + Positive, | |
| 177 | + Negative, | |
| 178 | +} | |
| 179 | + | |
| 180 | +/// Inequality operators | |
| 181 | +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] | |
| 182 | +pub enum InequalityOp { | |
| 183 | + Lt, | |
| 184 | + Le, | |
| 185 | + Gt, | |
| 186 | + Ge, | |
| 187 | + Ne, | |
| 188 | +} | |
| 189 | + | |
| 190 | +impl Expr { | |
| 191 | + // Constructors for common expressions | |
| 192 | + | |
| 193 | + pub fn integer(n: i64) -> Self { | |
| 194 | + Self::Integer(n) | |
| 195 | + } | |
| 196 | + | |
| 197 | + pub fn float(x: f64) -> Self { | |
| 198 | + Self::Float(x) | |
| 199 | + } | |
| 200 | + | |
| 201 | + pub fn symbol(name: impl Into<String>) -> Self { | |
| 202 | + Self::Symbol(Symbol::new(name)) | |
| 203 | + } | |
| 204 | + | |
| 205 | + pub fn neg(expr: Expr) -> Self { | |
| 206 | + Self::Neg(Box::new(expr)) | |
| 207 | + } | |
| 208 | + | |
| 209 | + pub fn add(terms: Vec<Expr>) -> Self { | |
| 210 | + if terms.len() == 1 { | |
| 211 | + terms.into_iter().next().unwrap() | |
| 212 | + } else { | |
| 213 | + Self::Add(terms) | |
| 214 | + } | |
| 215 | + } | |
| 216 | + | |
| 217 | + pub fn sub(a: Expr, b: Expr) -> Self { | |
| 218 | + Self::add(vec![a, Self::neg(b)]) | |
| 219 | + } | |
| 220 | + | |
| 221 | + pub fn mul(factors: Vec<Expr>) -> Self { | |
| 222 | + if factors.len() == 1 { | |
| 223 | + factors.into_iter().next().unwrap() | |
| 224 | + } else { | |
| 225 | + Self::Mul(factors) | |
| 226 | + } | |
| 227 | + } | |
| 228 | + | |
| 229 | + pub fn div(a: Expr, b: Expr) -> Self { | |
| 230 | + Self::mul(vec![a, Self::pow(b, Self::integer(-1))]) | |
| 231 | + } | |
| 232 | + | |
| 233 | + pub fn pow(base: Expr, exp: Expr) -> Self { | |
| 234 | + Self::Pow(Box::new(base), Box::new(exp)) | |
| 235 | + } | |
| 236 | + | |
| 237 | + pub fn func(name: impl Into<String>, args: Vec<Expr>) -> Self { | |
| 238 | + Self::Func(name.into(), args) | |
| 239 | + } | |
| 240 | + | |
| 241 | + // Common mathematical constants | |
| 242 | + pub fn pi() -> Self { | |
| 243 | + Self::symbol("pi") | |
| 244 | + } | |
| 245 | + | |
| 246 | + pub fn e() -> Self { | |
| 247 | + Self::symbol("e") | |
| 248 | + } | |
| 249 | + | |
| 250 | + pub fn i() -> Self { | |
| 251 | + Self::Complex(0.0, 1.0) | |
| 252 | + } | |
| 253 | + | |
| 254 | + // Predicate methods | |
| 255 | + | |
| 256 | + pub fn is_zero(&self) -> bool { | |
| 257 | + match self { | |
| 258 | + Self::Integer(0) => true, | |
| 259 | + Self::Float(x) if *x == 0.0 => true, | |
| 260 | + _ => false, | |
| 261 | + } | |
| 262 | + } | |
| 263 | + | |
| 264 | + pub fn is_one(&self) -> bool { | |
| 265 | + match self { | |
| 266 | + Self::Integer(1) => true, | |
| 267 | + Self::Float(x) if *x == 1.0 => true, | |
| 268 | + _ => false, | |
| 269 | + } | |
| 270 | + } | |
| 271 | + | |
| 272 | + pub fn is_negative_one(&self) -> bool { | |
| 273 | + match self { | |
| 274 | + Self::Integer(-1) => true, | |
| 275 | + Self::Float(x) if *x == -1.0 => true, | |
| 276 | + _ => false, | |
| 277 | + } | |
| 278 | + } | |
| 279 | + | |
| 280 | + pub fn is_number(&self) -> bool { | |
| 281 | + matches!( | |
| 282 | + self, | |
| 283 | + Self::Integer(_) | Self::Rational(_) | Self::Float(_) | Self::Complex(_, _) | |
| 284 | + ) | |
| 285 | + } | |
| 286 | + | |
| 287 | + pub fn is_symbol(&self) -> bool { | |
| 288 | + matches!(self, Self::Symbol(_)) | |
| 289 | + } | |
| 290 | + | |
| 291 | + /// Check if expression contains a variable | |
| 292 | + pub fn contains_var(&self, var: &Symbol) -> bool { | |
| 293 | + match self { | |
| 294 | + Self::Symbol(s) => s == var, | |
| 295 | + Self::Neg(e) => e.contains_var(var), | |
| 296 | + Self::Add(terms) => terms.iter().any(|t| t.contains_var(var)), | |
| 297 | + Self::Mul(factors) => factors.iter().any(|f| f.contains_var(var)), | |
| 298 | + Self::Pow(base, exp) => base.contains_var(var) || exp.contains_var(var), | |
| 299 | + Self::Func(_, args) => args.iter().any(|a| a.contains_var(var)), | |
| 300 | + Self::Derivative { expr, .. } => expr.contains_var(var), | |
| 301 | + Self::Integral { expr, .. } => expr.contains_var(var), | |
| 302 | + Self::Limit { expr, point, .. } => expr.contains_var(var) || point.contains_var(var), | |
| 303 | + Self::Sum { expr, lower, upper, .. } | Self::Product { expr, lower, upper, .. } => { | |
| 304 | + expr.contains_var(var) || lower.contains_var(var) || upper.contains_var(var) | |
| 305 | + } | |
| 306 | + Self::Equation(lhs, rhs) => lhs.contains_var(var) || rhs.contains_var(var), | |
| 307 | + Self::Inequality { lhs, rhs, .. } => lhs.contains_var(var) || rhs.contains_var(var), | |
| 308 | + Self::Matrix(rows) => rows.iter().any(|row| row.iter().any(|e| e.contains_var(var))), | |
| 309 | + Self::Vector(elems) => elems.iter().any(|e| e.contains_var(var)), | |
| 310 | + _ => false, | |
| 311 | + } | |
| 312 | + } | |
| 313 | +} | |
| 314 | + | |
| 315 | +impl fmt::Display for Expr { | |
| 316 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| 317 | + match self { | |
| 318 | + Self::Integer(n) => write!(f, "{n}"), | |
| 319 | + Self::Rational(r) => write!(f, "{r}"), | |
| 320 | + Self::Float(x) => { | |
| 321 | + if x.fract() == 0.0 && x.abs() < 1e15 { | |
| 322 | + write!(f, "{}", *x as i64) | |
| 323 | + } else { | |
| 324 | + write!(f, "{x}") | |
| 325 | + } | |
| 326 | + } | |
| 327 | + Self::Complex(re, im) => { | |
| 328 | + if *re == 0.0 { | |
| 329 | + if *im == 1.0 { | |
| 330 | + write!(f, "i") | |
| 331 | + } else if *im == -1.0 { | |
| 332 | + write!(f, "-i") | |
| 333 | + } else { | |
| 334 | + write!(f, "{im}i") | |
| 335 | + } | |
| 336 | + } else if *im >= 0.0 { | |
| 337 | + write!(f, "{re}+{im}i") | |
| 338 | + } else { | |
| 339 | + write!(f, "{re}{im}i") | |
| 340 | + } | |
| 341 | + } | |
| 342 | + Self::Symbol(s) => write!(f, "{s}"), | |
| 343 | + Self::Neg(e) => write!(f, "-{e}"), | |
| 344 | + Self::Add(terms) => { | |
| 345 | + if terms.is_empty() { | |
| 346 | + write!(f, "0") | |
| 347 | + } else { | |
| 348 | + write!(f, "(")?; | |
| 349 | + for (i, term) in terms.iter().enumerate() { | |
| 350 | + if i > 0 { | |
| 351 | + write!(f, "+")?; | |
| 352 | + } | |
| 353 | + write!(f, "{term}")?; | |
| 354 | + } | |
| 355 | + write!(f, ")") | |
| 356 | + } | |
| 357 | + } | |
| 358 | + Self::Mul(factors) => { | |
| 359 | + if factors.is_empty() { | |
| 360 | + write!(f, "1") | |
| 361 | + } else { | |
| 362 | + write!(f, "(")?; | |
| 363 | + for (i, factor) in factors.iter().enumerate() { | |
| 364 | + if i > 0 { | |
| 365 | + write!(f, "*")?; | |
| 366 | + } | |
| 367 | + write!(f, "{factor}")?; | |
| 368 | + } | |
| 369 | + write!(f, ")") | |
| 370 | + } | |
| 371 | + } | |
| 372 | + Self::Pow(base, exp) => write!(f, "{base}^{exp}"), | |
| 373 | + Self::Func(name, args) => { | |
| 374 | + write!(f, "{name}(")?; | |
| 375 | + for (i, arg) in args.iter().enumerate() { | |
| 376 | + if i > 0 { | |
| 377 | + write!(f, ", ")?; | |
| 378 | + } | |
| 379 | + write!(f, "{arg}")?; | |
| 380 | + } | |
| 381 | + write!(f, ")") | |
| 382 | + } | |
| 383 | + Self::Derivative { expr, var, order } => { | |
| 384 | + if *order == 1 { | |
| 385 | + write!(f, "d/d{var}({expr})") | |
| 386 | + } else { | |
| 387 | + write!(f, "d^{order}/d{var}^{order}({expr})") | |
| 388 | + } | |
| 389 | + } | |
| 390 | + Self::Integral { | |
| 391 | + expr, | |
| 392 | + var, | |
| 393 | + lower, | |
| 394 | + upper, | |
| 395 | + } => { | |
| 396 | + if let (Some(l), Some(u)) = (lower, upper) { | |
| 397 | + write!(f, "integrate({expr}, {var}, {l}, {u})") | |
| 398 | + } else { | |
| 399 | + write!(f, "integrate({expr}, {var})") | |
| 400 | + } | |
| 401 | + } | |
| 402 | + Self::Limit { | |
| 403 | + expr, | |
| 404 | + var, | |
| 405 | + point, | |
| 406 | + direction, | |
| 407 | + } => { | |
| 408 | + let dir = match direction { | |
| 409 | + Some(LimitDirection::Left) => "-", | |
| 410 | + Some(LimitDirection::Right) => "+", | |
| 411 | + None => "", | |
| 412 | + }; | |
| 413 | + write!(f, "lim({expr}, {var}, {point}{dir})") | |
| 414 | + } | |
| 415 | + Self::Sum { | |
| 416 | + expr, | |
| 417 | + var, | |
| 418 | + lower, | |
| 419 | + upper, | |
| 420 | + } => write!(f, "sum({expr}, {var}, {lower}, {upper})"), | |
| 421 | + Self::Product { | |
| 422 | + expr, | |
| 423 | + var, | |
| 424 | + lower, | |
| 425 | + upper, | |
| 426 | + } => write!(f, "product({expr}, {var}, {lower}, {upper})"), | |
| 427 | + Self::Equation(lhs, rhs) => write!(f, "{lhs} = {rhs}"), | |
| 428 | + Self::Inequality { lhs, op, rhs } => { | |
| 429 | + let op_str = match op { | |
| 430 | + InequalityOp::Lt => "<", | |
| 431 | + InequalityOp::Le => "<=", | |
| 432 | + InequalityOp::Gt => ">", | |
| 433 | + InequalityOp::Ge => ">=", | |
| 434 | + InequalityOp::Ne => "!=", | |
| 435 | + }; | |
| 436 | + write!(f, "{lhs} {op_str} {rhs}") | |
| 437 | + } | |
| 438 | + Self::Matrix(rows) => { | |
| 439 | + write!(f, "[")?; | |
| 440 | + for (i, row) in rows.iter().enumerate() { | |
| 441 | + if i > 0 { | |
| 442 | + write!(f, "; ")?; | |
| 443 | + } | |
| 444 | + for (j, elem) in row.iter().enumerate() { | |
| 445 | + if j > 0 { | |
| 446 | + write!(f, ", ")?; | |
| 447 | + } | |
| 448 | + write!(f, "{elem}")?; | |
| 449 | + } | |
| 450 | + } | |
| 451 | + write!(f, "]") | |
| 452 | + } | |
| 453 | + Self::Vector(elems) => { | |
| 454 | + write!(f, "[")?; | |
| 455 | + for (i, elem) in elems.iter().enumerate() { | |
| 456 | + if i > 0 { | |
| 457 | + write!(f, ", ")?; | |
| 458 | + } | |
| 459 | + write!(f, "{elem}")?; | |
| 460 | + } | |
| 461 | + write!(f, "]") | |
| 462 | + } | |
| 463 | + Self::Undefined => write!(f, "undefined"), | |
| 464 | + Self::Infinity(Sign::Positive) => write!(f, "infinity"), | |
| 465 | + Self::Infinity(Sign::Negative) => write!(f, "-infinity"), | |
| 466 | + } | |
| 467 | + } | |
| 468 | +} | |
| 469 | + | |
| 470 | +#[cfg(test)] | |
| 471 | +mod tests { | |
| 472 | + use super::*; | |
| 473 | + | |
| 474 | + #[test] | |
| 475 | + fn test_rational() { | |
| 476 | + let r = Rational::new(4, 6); | |
| 477 | + assert_eq!(r.num, 2); | |
| 478 | + assert_eq!(r.den, 3); | |
| 479 | + assert_eq!(r.to_string(), "2/3"); | |
| 480 | + } | |
| 481 | + | |
| 482 | + #[test] | |
| 483 | + fn test_expr_display() { | |
| 484 | + let expr = Expr::add(vec![ | |
| 485 | + Expr::symbol("x"), | |
| 486 | + Expr::mul(vec![Expr::integer(2), Expr::symbol("y")]), | |
| 487 | + ]); | |
| 488 | + assert_eq!(expr.to_string(), "(x+(2*y))"); | |
| 489 | + } | |
| 490 | + | |
| 491 | + #[test] | |
| 492 | + fn test_contains_var() { | |
| 493 | + let expr = Expr::add(vec![Expr::symbol("x"), Expr::integer(1)]); | |
| 494 | + assert!(expr.contains_var(&Symbol::new("x"))); | |
| 495 | + assert!(!expr.contains_var(&Symbol::new("y"))); | |
| 496 | + } | |
| 497 | +} | |
garcalc-cas/src/lib.rsadded@@ -0,0 +1,18 @@ | ||
| 1 | +//! garcalc-cas: Computer Algebra System for garcalc | |
| 2 | +//! | |
| 3 | +//! A custom CAS engine providing: | |
| 4 | +//! - Expression parsing and representation | |
| 5 | +//! - Numeric and symbolic evaluation | |
| 6 | +//! - Symbolic differentiation and integration | |
| 7 | +//! - Equation solving | |
| 8 | +//! - Limits and series expansions | |
| 9 | + | |
| 10 | +pub mod expr; | |
| 11 | +pub mod parser; | |
| 12 | +pub mod eval; | |
| 13 | +pub mod error; | |
| 14 | + | |
| 15 | +pub use expr::{Expr, Symbol, Rational}; | |
| 16 | +pub use parser::Parser; | |
| 17 | +pub use eval::Evaluator; | |
| 18 | +pub use error::{CasError, Result}; | |
garcalc-cas/src/parser.rsadded@@ -0,0 +1,678 @@ | ||
| 1 | +//! Expression parser using recursive descent | |
| 2 | +//! | |
| 3 | +//! Grammar: | |
| 4 | +//! ```text | |
| 5 | +//! expr = equation | |
| 6 | +//! equation = additive ("=" additive)? | |
| 7 | +//! additive = multiplicative (("+"|"-") multiplicative)* | |
| 8 | +//! multiplicative = power (("*"|"/") power)* | |
| 9 | +//! power = unary ("^" power)? | |
| 10 | +//! unary = "-" unary | postfix | |
| 11 | +//! postfix = primary ("!")* | |
| 12 | +//! primary = NUMBER | SYMBOL | "(" expr ")" | function | "[" matrix "]" | |
| 13 | +//! function = SYMBOL "(" args? ")" | |
| 14 | +//! args = expr ("," expr)* | |
| 15 | +//! matrix = row (";" row)* | |
| 16 | +//! row = expr ("," expr)* | |
| 17 | +//! ``` | |
| 18 | + | |
| 19 | +use crate::error::{CasError, Result}; | |
| 20 | +use crate::expr::{Expr, Symbol}; | |
| 21 | + | |
| 22 | +/// Token types | |
| 23 | +#[derive(Debug, Clone, PartialEq)] | |
| 24 | +pub enum Token { | |
| 25 | + Number(f64), | |
| 26 | + Symbol(String), | |
| 27 | + Plus, | |
| 28 | + Minus, | |
| 29 | + Star, | |
| 30 | + Slash, | |
| 31 | + Caret, | |
| 32 | + Bang, | |
| 33 | + Eq, | |
| 34 | + LParen, | |
| 35 | + RParen, | |
| 36 | + LBracket, | |
| 37 | + RBracket, | |
| 38 | + Comma, | |
| 39 | + Semicolon, | |
| 40 | + Eof, | |
| 41 | +} | |
| 42 | + | |
| 43 | +/// Tokenizer for mathematical expressions | |
| 44 | +pub struct Lexer<'a> { | |
| 45 | + input: &'a str, | |
| 46 | + pos: usize, | |
| 47 | +} | |
| 48 | + | |
| 49 | +impl<'a> Lexer<'a> { | |
| 50 | + pub fn new(input: &'a str) -> Self { | |
| 51 | + Self { input, pos: 0 } | |
| 52 | + } | |
| 53 | + | |
| 54 | + fn peek_char(&self) -> Option<char> { | |
| 55 | + self.input[self.pos..].chars().next() | |
| 56 | + } | |
| 57 | + | |
| 58 | + fn advance(&mut self) -> Option<char> { | |
| 59 | + let c = self.peek_char()?; | |
| 60 | + self.pos += c.len_utf8(); | |
| 61 | + Some(c) | |
| 62 | + } | |
| 63 | + | |
| 64 | + fn skip_whitespace(&mut self) { | |
| 65 | + while let Some(c) = self.peek_char() { | |
| 66 | + if c.is_whitespace() { | |
| 67 | + self.advance(); | |
| 68 | + } else { | |
| 69 | + break; | |
| 70 | + } | |
| 71 | + } | |
| 72 | + } | |
| 73 | + | |
| 74 | + fn read_number(&mut self) -> Token { | |
| 75 | + let start = self.pos; | |
| 76 | + let mut has_dot = false; | |
| 77 | + let mut has_exp = false; | |
| 78 | + | |
| 79 | + while let Some(c) = self.peek_char() { | |
| 80 | + if c.is_ascii_digit() { | |
| 81 | + self.advance(); | |
| 82 | + } else if c == '.' && !has_dot && !has_exp { | |
| 83 | + has_dot = true; | |
| 84 | + self.advance(); | |
| 85 | + } else if (c == 'e' || c == 'E') && !has_exp { | |
| 86 | + has_exp = true; | |
| 87 | + self.advance(); | |
| 88 | + // Handle optional sign after exponent | |
| 89 | + if let Some(sign) = self.peek_char() { | |
| 90 | + if sign == '+' || sign == '-' { | |
| 91 | + self.advance(); | |
| 92 | + } | |
| 93 | + } | |
| 94 | + } else { | |
| 95 | + break; | |
| 96 | + } | |
| 97 | + } | |
| 98 | + | |
| 99 | + let num_str = &self.input[start..self.pos]; | |
| 100 | + let value: f64 = num_str.parse().unwrap_or(f64::NAN); | |
| 101 | + Token::Number(value) | |
| 102 | + } | |
| 103 | + | |
| 104 | + fn read_symbol(&mut self) -> Token { | |
| 105 | + let start = self.pos; | |
| 106 | + | |
| 107 | + while let Some(c) = self.peek_char() { | |
| 108 | + if c.is_alphanumeric() || c == '_' { | |
| 109 | + self.advance(); | |
| 110 | + } else { | |
| 111 | + break; | |
| 112 | + } | |
| 113 | + } | |
| 114 | + | |
| 115 | + let name = self.input[start..self.pos].to_string(); | |
| 116 | + Token::Symbol(name) | |
| 117 | + } | |
| 118 | + | |
| 119 | + pub fn next_token(&mut self) -> Result<Token> { | |
| 120 | + self.skip_whitespace(); | |
| 121 | + | |
| 122 | + let Some(c) = self.peek_char() else { | |
| 123 | + return Ok(Token::Eof); | |
| 124 | + }; | |
| 125 | + | |
| 126 | + let token = match c { | |
| 127 | + '+' => { | |
| 128 | + self.advance(); | |
| 129 | + Token::Plus | |
| 130 | + } | |
| 131 | + '-' => { | |
| 132 | + self.advance(); | |
| 133 | + Token::Minus | |
| 134 | + } | |
| 135 | + '*' => { | |
| 136 | + self.advance(); | |
| 137 | + Token::Star | |
| 138 | + } | |
| 139 | + '/' => { | |
| 140 | + self.advance(); | |
| 141 | + Token::Slash | |
| 142 | + } | |
| 143 | + '^' => { | |
| 144 | + self.advance(); | |
| 145 | + Token::Caret | |
| 146 | + } | |
| 147 | + '!' => { | |
| 148 | + self.advance(); | |
| 149 | + Token::Bang | |
| 150 | + } | |
| 151 | + '=' => { | |
| 152 | + self.advance(); | |
| 153 | + Token::Eq | |
| 154 | + } | |
| 155 | + '(' => { | |
| 156 | + self.advance(); | |
| 157 | + Token::LParen | |
| 158 | + } | |
| 159 | + ')' => { | |
| 160 | + self.advance(); | |
| 161 | + Token::RParen | |
| 162 | + } | |
| 163 | + '[' => { | |
| 164 | + self.advance(); | |
| 165 | + Token::LBracket | |
| 166 | + } | |
| 167 | + ']' => { | |
| 168 | + self.advance(); | |
| 169 | + Token::RBracket | |
| 170 | + } | |
| 171 | + ',' => { | |
| 172 | + self.advance(); | |
| 173 | + Token::Comma | |
| 174 | + } | |
| 175 | + ';' => { | |
| 176 | + self.advance(); | |
| 177 | + Token::Semicolon | |
| 178 | + } | |
| 179 | + _ if c.is_ascii_digit() || c == '.' => self.read_number(), | |
| 180 | + _ if c.is_alphabetic() || c == '_' => self.read_symbol(), | |
| 181 | + _ => { | |
| 182 | + return Err(CasError::Parse { | |
| 183 | + position: self.pos, | |
| 184 | + message: format!("unexpected character: '{c}'"), | |
| 185 | + }); | |
| 186 | + } | |
| 187 | + }; | |
| 188 | + | |
| 189 | + Ok(token) | |
| 190 | + } | |
| 191 | + | |
| 192 | + pub fn position(&self) -> usize { | |
| 193 | + self.pos | |
| 194 | + } | |
| 195 | +} | |
| 196 | + | |
| 197 | +/// Recursive descent parser | |
| 198 | +pub struct Parser<'a> { | |
| 199 | + lexer: Lexer<'a>, | |
| 200 | + current: Token, | |
| 201 | +} | |
| 202 | + | |
| 203 | +impl<'a> Parser<'a> { | |
| 204 | + pub fn new(input: &'a str) -> Result<Self> { | |
| 205 | + let mut lexer = Lexer::new(input); | |
| 206 | + let current = lexer.next_token()?; | |
| 207 | + Ok(Self { lexer, current }) | |
| 208 | + } | |
| 209 | + | |
| 210 | + fn advance(&mut self) -> Result<Token> { | |
| 211 | + let prev = std::mem::replace(&mut self.current, self.lexer.next_token()?); | |
| 212 | + Ok(prev) | |
| 213 | + } | |
| 214 | + | |
| 215 | + fn expect(&mut self, expected: Token) -> Result<()> { | |
| 216 | + if self.current == expected { | |
| 217 | + self.advance()?; | |
| 218 | + Ok(()) | |
| 219 | + } else { | |
| 220 | + Err(CasError::Parse { | |
| 221 | + position: self.lexer.position(), | |
| 222 | + message: format!("expected {expected:?}, found {:?}", self.current), | |
| 223 | + }) | |
| 224 | + } | |
| 225 | + } | |
| 226 | + | |
| 227 | + /// Parse a complete expression | |
| 228 | + pub fn parse(&mut self) -> Result<Expr> { | |
| 229 | + let expr = self.parse_equation()?; | |
| 230 | + if self.current != Token::Eof { | |
| 231 | + return Err(CasError::Parse { | |
| 232 | + position: self.lexer.position(), | |
| 233 | + message: format!("unexpected token: {:?}", self.current), | |
| 234 | + }); | |
| 235 | + } | |
| 236 | + Ok(expr) | |
| 237 | + } | |
| 238 | + | |
| 239 | + fn parse_equation(&mut self) -> Result<Expr> { | |
| 240 | + let lhs = self.parse_additive()?; | |
| 241 | + if self.current == Token::Eq { | |
| 242 | + self.advance()?; | |
| 243 | + let rhs = self.parse_additive()?; | |
| 244 | + Ok(Expr::Equation(Box::new(lhs), Box::new(rhs))) | |
| 245 | + } else { | |
| 246 | + Ok(lhs) | |
| 247 | + } | |
| 248 | + } | |
| 249 | + | |
| 250 | + fn parse_additive(&mut self) -> Result<Expr> { | |
| 251 | + let mut terms = vec![self.parse_multiplicative()?]; | |
| 252 | + | |
| 253 | + loop { | |
| 254 | + match &self.current { | |
| 255 | + Token::Plus => { | |
| 256 | + self.advance()?; | |
| 257 | + terms.push(self.parse_multiplicative()?); | |
| 258 | + } | |
| 259 | + Token::Minus => { | |
| 260 | + self.advance()?; | |
| 261 | + terms.push(Expr::neg(self.parse_multiplicative()?)); | |
| 262 | + } | |
| 263 | + _ => break, | |
| 264 | + } | |
| 265 | + } | |
| 266 | + | |
| 267 | + Ok(Expr::add(terms)) | |
| 268 | + } | |
| 269 | + | |
| 270 | + fn parse_multiplicative(&mut self) -> Result<Expr> { | |
| 271 | + let mut factors = vec![self.parse_power()?]; | |
| 272 | + | |
| 273 | + loop { | |
| 274 | + match &self.current { | |
| 275 | + Token::Star => { | |
| 276 | + self.advance()?; | |
| 277 | + factors.push(self.parse_power()?); | |
| 278 | + } | |
| 279 | + Token::Slash => { | |
| 280 | + self.advance()?; | |
| 281 | + let divisor = self.parse_power()?; | |
| 282 | + factors.push(Expr::pow(divisor, Expr::integer(-1))); | |
| 283 | + } | |
| 284 | + // Implicit multiplication: 2x, xy, x(y+1) | |
| 285 | + Token::Symbol(_) | Token::LParen | Token::Number(_) => { | |
| 286 | + // Check if this could be implicit multiplication | |
| 287 | + if let Some(last) = factors.last() { | |
| 288 | + if matches!( | |
| 289 | + last, | |
| 290 | + Expr::Symbol(_) | |
| 291 | + | Expr::Integer(_) | |
| 292 | + | Expr::Float(_) | |
| 293 | + | Expr::Func(_, _) | |
| 294 | + ) { | |
| 295 | + factors.push(self.parse_power()?); | |
| 296 | + continue; | |
| 297 | + } | |
| 298 | + } | |
| 299 | + break; | |
| 300 | + } | |
| 301 | + _ => break, | |
| 302 | + } | |
| 303 | + } | |
| 304 | + | |
| 305 | + Ok(Expr::mul(factors)) | |
| 306 | + } | |
| 307 | + | |
| 308 | + fn parse_power(&mut self) -> Result<Expr> { | |
| 309 | + let base = self.parse_unary()?; | |
| 310 | + if self.current == Token::Caret { | |
| 311 | + self.advance()?; | |
| 312 | + let exp = self.parse_power()?; // Right associative | |
| 313 | + Ok(Expr::pow(base, exp)) | |
| 314 | + } else { | |
| 315 | + Ok(base) | |
| 316 | + } | |
| 317 | + } | |
| 318 | + | |
| 319 | + fn parse_unary(&mut self) -> Result<Expr> { | |
| 320 | + if self.current == Token::Minus { | |
| 321 | + self.advance()?; | |
| 322 | + let expr = self.parse_unary()?; | |
| 323 | + Ok(Expr::neg(expr)) | |
| 324 | + } else { | |
| 325 | + self.parse_postfix() | |
| 326 | + } | |
| 327 | + } | |
| 328 | + | |
| 329 | + fn parse_postfix(&mut self) -> Result<Expr> { | |
| 330 | + let mut expr = self.parse_primary()?; | |
| 331 | + | |
| 332 | + while self.current == Token::Bang { | |
| 333 | + self.advance()?; | |
| 334 | + expr = Expr::func("factorial", vec![expr]); | |
| 335 | + } | |
| 336 | + | |
| 337 | + Ok(expr) | |
| 338 | + } | |
| 339 | + | |
| 340 | + fn parse_primary(&mut self) -> Result<Expr> { | |
| 341 | + match &self.current { | |
| 342 | + Token::Number(n) => { | |
| 343 | + let n = *n; | |
| 344 | + self.advance()?; | |
| 345 | + // Convert to integer if possible | |
| 346 | + if n.fract() == 0.0 && n.abs() < i64::MAX as f64 { | |
| 347 | + Ok(Expr::integer(n as i64)) | |
| 348 | + } else { | |
| 349 | + Ok(Expr::float(n)) | |
| 350 | + } | |
| 351 | + } | |
| 352 | + Token::Symbol(name) => { | |
| 353 | + let name = name.clone(); | |
| 354 | + self.advance()?; | |
| 355 | + | |
| 356 | + // Check for function call | |
| 357 | + if self.current == Token::LParen { | |
| 358 | + self.advance()?; | |
| 359 | + let args = self.parse_args()?; | |
| 360 | + self.expect(Token::RParen)?; | |
| 361 | + Ok(self.handle_special_function(&name, args)) | |
| 362 | + } else { | |
| 363 | + // Constants | |
| 364 | + match name.as_str() { | |
| 365 | + "pi" | "PI" => Ok(Expr::symbol("pi")), | |
| 366 | + "e" | "E" => Ok(Expr::symbol("e")), | |
| 367 | + "i" => Ok(Expr::Complex(0.0, 1.0)), | |
| 368 | + "inf" | "infinity" => Ok(Expr::Infinity(crate::expr::Sign::Positive)), | |
| 369 | + _ => Ok(Expr::symbol(name)), | |
| 370 | + } | |
| 371 | + } | |
| 372 | + } | |
| 373 | + Token::LParen => { | |
| 374 | + self.advance()?; | |
| 375 | + let expr = self.parse_equation()?; | |
| 376 | + self.expect(Token::RParen)?; | |
| 377 | + Ok(expr) | |
| 378 | + } | |
| 379 | + Token::LBracket => { | |
| 380 | + self.advance()?; | |
| 381 | + let matrix = self.parse_matrix()?; | |
| 382 | + self.expect(Token::RBracket)?; | |
| 383 | + Ok(matrix) | |
| 384 | + } | |
| 385 | + _ => Err(CasError::Parse { | |
| 386 | + position: self.lexer.position(), | |
| 387 | + message: format!("unexpected token: {:?}", self.current), | |
| 388 | + }), | |
| 389 | + } | |
| 390 | + } | |
| 391 | + | |
| 392 | + fn parse_args(&mut self) -> Result<Vec<Expr>> { | |
| 393 | + if self.current == Token::RParen { | |
| 394 | + return Ok(vec![]); | |
| 395 | + } | |
| 396 | + | |
| 397 | + let mut args = vec![self.parse_equation()?]; | |
| 398 | + | |
| 399 | + while self.current == Token::Comma { | |
| 400 | + self.advance()?; | |
| 401 | + args.push(self.parse_equation()?); | |
| 402 | + } | |
| 403 | + | |
| 404 | + Ok(args) | |
| 405 | + } | |
| 406 | + | |
| 407 | + fn parse_matrix(&mut self) -> Result<Expr> { | |
| 408 | + let mut rows = vec![self.parse_row()?]; | |
| 409 | + | |
| 410 | + while self.current == Token::Semicolon { | |
| 411 | + self.advance()?; | |
| 412 | + rows.push(self.parse_row()?); | |
| 413 | + } | |
| 414 | + | |
| 415 | + // Single row = vector, multiple rows = matrix | |
| 416 | + if rows.len() == 1 { | |
| 417 | + Ok(Expr::Vector(rows.into_iter().next().unwrap())) | |
| 418 | + } else { | |
| 419 | + Ok(Expr::Matrix(rows)) | |
| 420 | + } | |
| 421 | + } | |
| 422 | + | |
| 423 | + fn parse_row(&mut self) -> Result<Vec<Expr>> { | |
| 424 | + let mut elems = vec![self.parse_equation()?]; | |
| 425 | + | |
| 426 | + while self.current == Token::Comma { | |
| 427 | + self.advance()?; | |
| 428 | + elems.push(self.parse_equation()?); | |
| 429 | + } | |
| 430 | + | |
| 431 | + Ok(elems) | |
| 432 | + } | |
| 433 | + | |
| 434 | + /// Handle special functions that need custom AST nodes | |
| 435 | + fn handle_special_function(&self, name: &str, mut args: Vec<Expr>) -> Expr { | |
| 436 | + match name { | |
| 437 | + "diff" | "derivative" => { | |
| 438 | + if args.len() >= 2 { | |
| 439 | + let expr = args.remove(0); | |
| 440 | + let var = match args.remove(0) { | |
| 441 | + Expr::Symbol(s) => s, | |
| 442 | + _ => Symbol::new("x"), | |
| 443 | + }; | |
| 444 | + let order = if args.is_empty() { | |
| 445 | + 1 | |
| 446 | + } else { | |
| 447 | + match &args[0] { | |
| 448 | + Expr::Integer(n) => *n as u32, | |
| 449 | + _ => 1, | |
| 450 | + } | |
| 451 | + }; | |
| 452 | + Expr::Derivative { | |
| 453 | + expr: Box::new(expr), | |
| 454 | + var, | |
| 455 | + order, | |
| 456 | + } | |
| 457 | + } else if args.len() == 1 { | |
| 458 | + Expr::Derivative { | |
| 459 | + expr: Box::new(args.remove(0)), | |
| 460 | + var: Symbol::new("x"), | |
| 461 | + order: 1, | |
| 462 | + } | |
| 463 | + } else { | |
| 464 | + Expr::func(name, args) | |
| 465 | + } | |
| 466 | + } | |
| 467 | + "integrate" | "int" => { | |
| 468 | + if args.len() >= 2 { | |
| 469 | + let expr = args.remove(0); | |
| 470 | + let var = match args.remove(0) { | |
| 471 | + Expr::Symbol(s) => s, | |
| 472 | + _ => Symbol::new("x"), | |
| 473 | + }; | |
| 474 | + let (lower, upper) = if args.len() >= 2 { | |
| 475 | + (Some(Box::new(args.remove(0))), Some(Box::new(args.remove(0)))) | |
| 476 | + } else { | |
| 477 | + (None, None) | |
| 478 | + }; | |
| 479 | + Expr::Integral { | |
| 480 | + expr: Box::new(expr), | |
| 481 | + var, | |
| 482 | + lower, | |
| 483 | + upper, | |
| 484 | + } | |
| 485 | + } else { | |
| 486 | + Expr::func(name, args) | |
| 487 | + } | |
| 488 | + } | |
| 489 | + "lim" | "limit" => { | |
| 490 | + if args.len() >= 3 { | |
| 491 | + let expr = args.remove(0); | |
| 492 | + let var = match args.remove(0) { | |
| 493 | + Expr::Symbol(s) => s, | |
| 494 | + _ => Symbol::new("x"), | |
| 495 | + }; | |
| 496 | + let point = args.remove(0); | |
| 497 | + Expr::Limit { | |
| 498 | + expr: Box::new(expr), | |
| 499 | + var, | |
| 500 | + point: Box::new(point), | |
| 501 | + direction: None, | |
| 502 | + } | |
| 503 | + } else { | |
| 504 | + Expr::func(name, args) | |
| 505 | + } | |
| 506 | + } | |
| 507 | + "sum" => { | |
| 508 | + if args.len() >= 4 { | |
| 509 | + let expr = args.remove(0); | |
| 510 | + let var = match args.remove(0) { | |
| 511 | + Expr::Symbol(s) => s, | |
| 512 | + _ => Symbol::new("n"), | |
| 513 | + }; | |
| 514 | + let lower = args.remove(0); | |
| 515 | + let upper = args.remove(0); | |
| 516 | + Expr::Sum { | |
| 517 | + expr: Box::new(expr), | |
| 518 | + var, | |
| 519 | + lower: Box::new(lower), | |
| 520 | + upper: Box::new(upper), | |
| 521 | + } | |
| 522 | + } else { | |
| 523 | + Expr::func(name, args) | |
| 524 | + } | |
| 525 | + } | |
| 526 | + "product" | "prod" => { | |
| 527 | + if args.len() >= 4 { | |
| 528 | + let expr = args.remove(0); | |
| 529 | + let var = match args.remove(0) { | |
| 530 | + Expr::Symbol(s) => s, | |
| 531 | + _ => Symbol::new("n"), | |
| 532 | + }; | |
| 533 | + let lower = args.remove(0); | |
| 534 | + let upper = args.remove(0); | |
| 535 | + Expr::Product { | |
| 536 | + expr: Box::new(expr), | |
| 537 | + var, | |
| 538 | + lower: Box::new(lower), | |
| 539 | + upper: Box::new(upper), | |
| 540 | + } | |
| 541 | + } else { | |
| 542 | + Expr::func(name, args) | |
| 543 | + } | |
| 544 | + } | |
| 545 | + _ => Expr::func(name, args), | |
| 546 | + } | |
| 547 | + } | |
| 548 | +} | |
| 549 | + | |
| 550 | +/// Parse a string into an expression | |
| 551 | +pub fn parse(input: &str) -> Result<Expr> { | |
| 552 | + Parser::new(input)?.parse() | |
| 553 | +} | |
| 554 | + | |
| 555 | +#[cfg(test)] | |
| 556 | +mod tests { | |
| 557 | + use super::*; | |
| 558 | + | |
| 559 | + #[test] | |
| 560 | + fn test_number() { | |
| 561 | + let expr = parse("42").unwrap(); | |
| 562 | + assert_eq!(expr, Expr::integer(42)); | |
| 563 | + } | |
| 564 | + | |
| 565 | + #[test] | |
| 566 | + fn test_float() { | |
| 567 | + let expr = parse("3.14").unwrap(); | |
| 568 | + assert_eq!(expr, Expr::float(3.14)); | |
| 569 | + } | |
| 570 | + | |
| 571 | + #[test] | |
| 572 | + fn test_addition() { | |
| 573 | + let expr = parse("1 + 2 + 3").unwrap(); | |
| 574 | + assert_eq!( | |
| 575 | + expr, | |
| 576 | + Expr::add(vec![Expr::integer(1), Expr::integer(2), Expr::integer(3)]) | |
| 577 | + ); | |
| 578 | + } | |
| 579 | + | |
| 580 | + #[test] | |
| 581 | + fn test_subtraction() { | |
| 582 | + let expr = parse("5 - 3").unwrap(); | |
| 583 | + assert_eq!( | |
| 584 | + expr, | |
| 585 | + Expr::add(vec![Expr::integer(5), Expr::neg(Expr::integer(3))]) | |
| 586 | + ); | |
| 587 | + } | |
| 588 | + | |
| 589 | + #[test] | |
| 590 | + fn test_multiplication() { | |
| 591 | + let expr = parse("2 * 3").unwrap(); | |
| 592 | + assert_eq!(expr, Expr::mul(vec![Expr::integer(2), Expr::integer(3)])); | |
| 593 | + } | |
| 594 | + | |
| 595 | + #[test] | |
| 596 | + fn test_division() { | |
| 597 | + let expr = parse("6 / 2").unwrap(); | |
| 598 | + assert_eq!( | |
| 599 | + expr, | |
| 600 | + Expr::mul(vec![ | |
| 601 | + Expr::integer(6), | |
| 602 | + Expr::pow(Expr::integer(2), Expr::integer(-1)) | |
| 603 | + ]) | |
| 604 | + ); | |
| 605 | + } | |
| 606 | + | |
| 607 | + #[test] | |
| 608 | + fn test_power() { | |
| 609 | + let expr = parse("2^3").unwrap(); | |
| 610 | + assert_eq!(expr, Expr::pow(Expr::integer(2), Expr::integer(3))); | |
| 611 | + } | |
| 612 | + | |
| 613 | + #[test] | |
| 614 | + fn test_implicit_multiplication() { | |
| 615 | + let expr = parse("2x").unwrap(); | |
| 616 | + assert_eq!(expr, Expr::mul(vec![Expr::integer(2), Expr::symbol("x")])); | |
| 617 | + } | |
| 618 | + | |
| 619 | + #[test] | |
| 620 | + fn test_function() { | |
| 621 | + let expr = parse("sin(x)").unwrap(); | |
| 622 | + assert_eq!(expr, Expr::func("sin", vec![Expr::symbol("x")])); | |
| 623 | + } | |
| 624 | + | |
| 625 | + #[test] | |
| 626 | + fn test_nested() { | |
| 627 | + let expr = parse("sin(x + 1)").unwrap(); | |
| 628 | + assert_eq!( | |
| 629 | + expr, | |
| 630 | + Expr::func( | |
| 631 | + "sin", | |
| 632 | + vec![Expr::add(vec![Expr::symbol("x"), Expr::integer(1)])] | |
| 633 | + ) | |
| 634 | + ); | |
| 635 | + } | |
| 636 | + | |
| 637 | + #[test] | |
| 638 | + fn test_equation() { | |
| 639 | + let expr = parse("x + 1 = 5").unwrap(); | |
| 640 | + assert!(matches!(expr, Expr::Equation(_, _))); | |
| 641 | + } | |
| 642 | + | |
| 643 | + #[test] | |
| 644 | + fn test_derivative() { | |
| 645 | + let expr = parse("diff(x^2, x)").unwrap(); | |
| 646 | + assert!(matches!(expr, Expr::Derivative { .. })); | |
| 647 | + } | |
| 648 | + | |
| 649 | + #[test] | |
| 650 | + fn test_integral() { | |
| 651 | + let expr = parse("integrate(x^2, x)").unwrap(); | |
| 652 | + assert!(matches!(expr, Expr::Integral { lower: None, .. })); | |
| 653 | + } | |
| 654 | + | |
| 655 | + #[test] | |
| 656 | + fn test_definite_integral() { | |
| 657 | + let expr = parse("integrate(x^2, x, 0, 1)").unwrap(); | |
| 658 | + assert!(matches!(expr, Expr::Integral { lower: Some(_), .. })); | |
| 659 | + } | |
| 660 | + | |
| 661 | + #[test] | |
| 662 | + fn test_vector() { | |
| 663 | + let expr = parse("[1, 2, 3]").unwrap(); | |
| 664 | + assert!(matches!(expr, Expr::Vector(_))); | |
| 665 | + } | |
| 666 | + | |
| 667 | + #[test] | |
| 668 | + fn test_matrix() { | |
| 669 | + let expr = parse("[1, 2; 3, 4]").unwrap(); | |
| 670 | + assert!(matches!(expr, Expr::Matrix(_))); | |
| 671 | + } | |
| 672 | + | |
| 673 | + #[test] | |
| 674 | + fn test_factorial() { | |
| 675 | + let expr = parse("5!").unwrap(); | |
| 676 | + assert_eq!(expr, Expr::func("factorial", vec![Expr::integer(5)])); | |
| 677 | + } | |
| 678 | +} | |