!> Expression evaluator module for FORTBITE
!>
!> Evaluates Abstract Syntax Trees, performing mathematical operations
!> with proper type promotion and precision handling.
!>
!> Every evaluator reports failures through an `evaluation_error_t`; on any
!> failure path the returned value is the real scalar 0, so callers never
!> receive an undefined result.
module fortbite_evaluator_m
    use fortbite_types_m, only: value_t, variable_t, VALUE_SCALAR, VALUE_COMPLEX, VALUE_MATRIX, &
        create_scalar, create_complex, create_matrix, print_value, is_zero, is_real, &
        create_zeros_matrix, create_ones_matrix, create_eye_matrix
    use fortbite_ast_m, only: ast_node_t, ast_node_ptr_t, AST_LITERAL, AST_IDENTIFIER, AST_BINARY_OP, &
        AST_UNARY_OP, AST_FUNCTION_CALL, AST_ASSIGNMENT, AST_PRECISION_SPEC, AST_MATRIX_LITERAL, &
        OP_ADD, OP_SUB, OP_MUL, OP_DIV, OP_POW, OP_MOD, &
        OP_UNARY_PLUS, OP_UNARY_MINUS
    use fortbite_arithmetic_m, only: add_values, subtract_values, multiply_values, &
        divide_values, power_values, negate_value, abs_value
    use fortbite_matrix_m, only: matrix_transpose, matrix_determinant, matrix_inverse, &
        matrix_element_access, matrix_solve, matrix_rank, matrix_trace
    use fortbite_functions_m, only: eval_trigonometric, eval_hyperbolic, eval_logarithmic, &
        eval_exponential, eval_statistical, eval_special, eval_complex_functions
    use iso_fortran_env, only: real64
    implicit none
    private

    public :: evaluate_expression, evaluation_context_t, evaluation_error_t
    public :: create_context, destroy_context, set_variable, get_variable

    !> Default precision of a new context (presumably significant digits)
    integer, parameter :: DEFAULT_PRECISION_DIGITS = 15

    !> Evaluation context (variable storage)
    type :: evaluation_context_t
        !> Head of a singly linked list of user variables; new names are prepended
        type(variable_t), pointer :: variables => null()
        integer :: default_precision = DEFAULT_PRECISION_DIGITS
    end type evaluation_context_t

    !> Evaluation error information
    type :: evaluation_error_t
        logical :: has_error = .false.
        character(len=200) :: message = ''
    end type evaluation_error_t

contains

    !> Create a new, empty evaluation context with the default precision.
    function create_context() result(context)
        type(evaluation_context_t) :: context

        context%default_precision = DEFAULT_PRECISION_DIGITS
        nullify(context%variables)
    end function create_context

    !> Destroy an evaluation context, freeing every stored variable.
    subroutine destroy_context(context)
        type(evaluation_context_t), intent(inout) :: context

        call free_variables(context%variables)
    end subroutine destroy_context

    !> Return a pointer to the variable called `name`, or a disassociated
    !> pointer when no such variable is stored in `context`.
    function find_variable(context, name) result(match)
        type(evaluation_context_t), intent(in) :: context
        character(len=*), intent(in) :: name
        type(variable_t), pointer :: match

        match => context%variables
        do while (associated(match))
            if (match%name == name) return
            match => match%next
        end do
    end function find_variable

    !> Set a variable in the context.
    !>
    !> Overwrites the value when `name` already exists; otherwise a new
    !> binding is prepended to the variable list.
    subroutine set_variable(context, name, value)
        type(evaluation_context_t), intent(inout) :: context
        character(len=*), intent(in) :: name
        type(value_t), intent(in) :: value

        type(variable_t), pointer :: binding

        binding => find_variable(context, name)
        if (associated(binding)) then
            binding%value = value
            return
        end if

        allocate(binding)
        binding%name = trim(name)
        binding%value = value
        binding%next => context%variables
        context%variables => binding
    end subroutine set_variable

    !> Get a variable from the context.
    !>
    !> @return .true. and a copy in `value` when `name` exists, else .false.
    function get_variable(context, name, value) result(found)
        type(evaluation_context_t), intent(in) :: context
        character(len=*), intent(in) :: name
        type(value_t), intent(out) :: value
        logical :: found

        type(variable_t), pointer :: match

        match => find_variable(context, name)
        found = associated(match)
        if (found) value = match%value
    end function get_variable

    !> Evaluate an AST expression.
    !>
    !> @param node     root of the (sub)tree to evaluate; may be null
    !> @param context  variable storage; assignments update it
    !> @param error    optional; receives failure status and message
    !> @return         the computed value, or the scalar 0 on failure
    recursive function evaluate_expression(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out), optional :: error
        type(value_t) :: value

        type(evaluation_error_t) :: local_error

        local_error%has_error = .false.

        if (.not. associated(node)) then
            call set_eval_error(local_error, 'Null AST node')
            value = create_scalar(0.0_real64)
            if (present(error)) error = local_error
            return
        end if

        select case (node%node_type)
        case (AST_LITERAL)
            value = evaluate_literal(node, local_error)
        case (AST_IDENTIFIER)
            value = evaluate_identifier(node, context, local_error)
        case (AST_BINARY_OP)
            value = evaluate_binary_op(node, context, local_error)
        case (AST_UNARY_OP)
            value = evaluate_unary_op(node, context, local_error)
        case (AST_FUNCTION_CALL)
            value = evaluate_function_call(node, context, local_error)
        case (AST_ASSIGNMENT)
            value = evaluate_assignment(node, context, local_error)
        case (AST_PRECISION_SPEC)
            value = evaluate_precision_spec(node, context, local_error)
        case (AST_MATRIX_LITERAL)
            value = evaluate_matrix_literal(node, local_error)
        case default
            call set_eval_error(local_error, 'Unknown AST node type')
            value = create_scalar(0.0_real64)
        end select

        if (present(error)) error = local_error
    end function evaluate_expression

    !> Evaluate a literal value: returns the node's stored literal.
    function evaluate_literal(node, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.
        value = node%literal_value
    end function evaluate_literal

    !> Evaluate an identifier (built-in constant, matrix shortcut or variable).
    !>
    !> Built-in names are checked first, so they shadow user variables with
    !> the same name.
    function evaluate_identifier(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        logical :: found

        error%has_error = .false.

        select case (node%identifier)
        ! Mathematical constants
        case ('pi')
            value = create_scalar(4.0_real64 * atan(1.0_real64))
        case ('e')
            value = create_scalar(exp(1.0_real64))
        case ('i')
            value = create_complex(0.0_real64, 1.0_real64)
        ! Fixed-size matrix creation shortcuts
        case ('zeros2')
            value = create_zeros_matrix(2, 2)
        case ('ones2')
            value = create_ones_matrix(2, 2)
        case ('ones3')
            value = create_ones_matrix(3, 3)
        case ('eye2')
            value = create_eye_matrix(2)
        case ('testmat')
            ! Test matrix [[1,2],[3,4]]; reshape fills column-major, so the
            ! source list is given column by column.
            value = create_matrix(reshape([1.0_real64, 3.0_real64, 2.0_real64, 4.0_real64], [2, 2]))
        case default
            ! User-defined variables
            found = get_variable(context, node%identifier, value)
            if (.not. found) then
                call set_eval_error(error, 'Undefined variable: ' // node%identifier)
                value = create_scalar(0.0_real64)
            end if
        end select
    end function evaluate_identifier

    !> Evaluate a binary operation on the node's left and right operands.
    recursive function evaluate_binary_op(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t) :: left_val, right_val
        type(evaluation_error_t) :: operand_error

        error%has_error = .false.
        ! Placeholder result for every failure path below
        value = create_scalar(0.0_real64)

        left_val = evaluate_expression(node%left, context, operand_error)
        if (operand_error%has_error) then
            error = operand_error
            return
        end if

        right_val = evaluate_expression(node%right, context, operand_error)
        if (operand_error%has_error) then
            error = operand_error
            return
        end if

        select case (node%operator)
        case (OP_ADD)
            value = add_values(left_val, right_val)
        case (OP_SUB)
            value = subtract_values(left_val, right_val)
        case (OP_MUL)
            value = multiply_values(left_val, right_val)
        case (OP_DIV)
            value = divide_values(left_val, right_val)
        case (OP_POW)
            value = power_values(left_val, right_val)
        case (OP_MOD)
            ! Modulo (for now, only on real numbers). MOD takes the sign of the
            ! dividend; a zero divisor is rejected because MOD(a, 0) is
            ! processor-dependent in the Fortran standard.
            if (left_val%value_type /= VALUE_SCALAR .or. right_val%value_type /= VALUE_SCALAR) then
                call set_eval_error(error, 'Modulo operation only supported for real numbers')
            else if (right_val%scalar_val == 0.0_real64) then
                call set_eval_error(error, 'Modulo by zero')
            else
                value = create_scalar(mod(left_val%scalar_val, right_val%scalar_val))
            end if
        case default
            call set_eval_error(error, 'Unknown binary operator')
        end select
    end function evaluate_binary_op

    !> Evaluate a unary operation (+x or -x) on the node's operand.
    recursive function evaluate_unary_op(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t) :: operand_val
        type(evaluation_error_t) :: operand_error

        error%has_error = .false.
        value = create_scalar(0.0_real64)

        operand_val = evaluate_expression(node%operand, context, operand_error)
        if (operand_error%has_error) then
            error = operand_error
            return
        end if

        select case (node%operator)
        case (OP_UNARY_PLUS)
            value = operand_val
        case (OP_UNARY_MINUS)
            value = negate_value(operand_val)
        case default
            call set_eval_error(error, 'Unknown unary operator')
        end select
    end function evaluate_unary_op

    !> Check that a function call node has exactly `expected_count` arguments.
    !>
    !> @return .true. when the count matches; otherwise .false. with `error` set
    function validate_function_args(node, expected_count, error) result(is_valid)
        type(ast_node_t), pointer, intent(in) :: node
        integer, intent(in) :: expected_count
        type(evaluation_error_t), intent(out) :: error
        logical :: is_valid

        character(len=16) :: count_text

        error%has_error = .false.
        is_valid = node%arg_count == expected_count
        if (is_valid) return

        ! Built by concatenation so long function names cannot overflow a
        ! fixed-size internal-write buffer.
        if (expected_count == 1) then
            call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
        else
            write(count_text, '(I0)') expected_count
            call set_eval_error(error, trim(node%function_name) // '() expects ' // &
                trim(count_text) // ' arguments')
        end if
    end function validate_function_args

    !> True when the scalar payload of `arg` is a whole number in [1, huge(0)].
    !>
    !> NaN, infinities, fractions and values too large for a default integer
    !> all fail, so a following int() conversion is well-defined.
    function is_positive_int_scalar(arg) result(ok)
        type(value_t), intent(in) :: arg
        logical :: ok

        ok = arg%scalar_val >= 1 .and. arg%scalar_val <= huge(0) .and. &
             aint(arg%scalar_val) == arg%scalar_val
    end function is_positive_int_scalar

    !> Validate matrix dimension arguments (scalar, positive integers).
    !>
    !> @param arg1       first size argument
    !> @param arg2       optional second size argument
    !> @param func_name  name used in error messages
    !> @return           .true. when all present sizes are valid
    function validate_matrix_dims(arg1, arg2, func_name, error) result(is_valid)
        type(value_t), intent(in) :: arg1
        type(value_t), intent(in), optional :: arg2
        character(len=*), intent(in) :: func_name
        type(evaluation_error_t), intent(out) :: error
        logical :: is_valid

        error%has_error = .false.
        is_valid = .false.

        if (arg1%value_type /= VALUE_SCALAR) then
            call set_eval_error(error, trim(func_name) // '() expects numeric size argument')
            return
        end if
        if (.not. is_positive_int_scalar(arg1)) then
            call set_eval_error(error, trim(func_name) // '() size must be a positive integer')
            return
        end if

        ! Nested ifs: present(arg2) must be checked before arg2 is touched
        if (present(arg2)) then
            if (arg2%value_type /= VALUE_SCALAR) then
                call set_eval_error(error, trim(func_name) // '() expects numeric size arguments')
                return
            end if
            if (.not. is_positive_int_scalar(arg2)) then
                call set_eval_error(error, trim(func_name) // '() sizes must be positive integers')
                return
            end if
        end if

        is_valid = .true.
    end function validate_matrix_dims

    !> Handle matrix creation functions: zeros(n[,m]), ones(n[,m]), eye(n).
    !>
    !> A single size yields a square matrix. Returns the scalar 0 on failure.
    function eval_matrix_creation(func_name, node, args, error) result(value)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in) :: node
        type(value_t), intent(in) :: args(:)
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        integer :: rows, cols
        logical :: dims_ok

        error%has_error = .false.
        value = create_scalar(0.0_real64)

        select case (trim(func_name))
        case ('zeros', 'ones')
            select case (node%arg_count)
            case (1)
                ! Single size argument: square matrix
                dims_ok = validate_matrix_dims(args(1), func_name=func_name, error=error)
                if (dims_ok) then
                    rows = int(args(1)%scalar_val)
                    cols = rows
                end if
            case (2)
                ! Two size arguments: rows x cols
                dims_ok = validate_matrix_dims(args(1), args(2), func_name, error)
                if (dims_ok) then
                    rows = int(args(1)%scalar_val)
                    cols = int(args(2)%scalar_val)
                end if
            case default
                call set_eval_error(error, trim(func_name) // '() expects 1 or 2 arguments')
                return
            end select

            if (.not. dims_ok) return
            if (trim(func_name) == 'zeros') then
                value = create_zeros_matrix(rows, cols)
            else
                value = create_ones_matrix(rows, cols)
            end if

        case ('eye')
            if (node%arg_count /= 1) then
                call set_eval_error(error, 'eye() expects 1 argument')
                return
            end if
            if (validate_matrix_dims(args(1), func_name=func_name, error=error)) then
                value = create_eye_matrix(int(args(1)%scalar_val))
            end if

        case default
            call set_eval_error(error, 'Unknown matrix creation function')
        end select
    end function eval_matrix_creation

    !> Handle single-argument matrix operations: transpose, det, inv (and aliases).
    function eval_matrix_operation(func_name, node, args, error) result(value)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in) :: node
        type(value_t), intent(in) :: args(:)
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.
        value = create_scalar(0.0_real64)

        if (.not. validate_function_args(node, 1, error)) return

        if (args(1)%value_type /= VALUE_MATRIX) then
            call set_eval_error(error, trim(func_name) // '() expects a matrix argument')
            return
        end if

        select case (trim(func_name))
        case ('transpose', 'trans')
            value = matrix_transpose(args(1))
        case ('det', 'determinant')
            value = create_scalar(matrix_determinant(args(1)))
        case ('inv', 'inverse')
            value = matrix_inverse(args(1))
        case default
            call set_eval_error(error, 'Unknown matrix operation function')
        end select
    end function eval_matrix_operation

    !> Evaluate a function call: evaluate all arguments, then dispatch on name.
    !>
    !> Returns the scalar 0 with `error` set when an argument fails to
    !> evaluate, the argument count/types are wrong, or the name is unknown.
    recursive function evaluate_function_call(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t), allocatable :: args(:)
        type(evaluation_error_t) :: arg_error
        integer :: i
        real(real64) :: x

        error%has_error = .false.
        ! Placeholder result for every failure path below
        value = create_scalar(0.0_real64)

        ! Always allocate (possibly zero-sized): `args` is passed to helpers
        ! with assumed-shape dummies, which must not receive an unallocated array.
        allocate(args(max(node%arg_count, 0)))
        do i = 1, size(args)
            args(i) = evaluate_expression(node%arguments(i)%ptr, context, arg_error)
            if (arg_error%has_error) then
                error = arg_error
                return
            end if
        end do

        select case (node%function_name)
        ! Trigonometric functions
        case ('sin', 'cos', 'tan', 'asin', 'arcsin', 'acos', 'arccos', 'atan', 'arctan', &
              'sec', 'csc', 'cot')
            if (validate_function_args(node, 1, error)) &
                value = eval_trigonometric(node%function_name, args(1))

        ! Hyperbolic functions
        case ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', 'sech', 'csch', 'coth')
            if (validate_function_args(node, 1, error)) &
                value = eval_hyperbolic(node%function_name, args(1))

        ! Logarithmic functions
        case ('log', 'ln', 'log10', 'lg', 'log2')
            if (validate_function_args(node, 1, error)) &
                value = eval_logarithmic(node%function_name, args(1))

        ! Exponential functions
        case ('exp', 'exp2', 'exp10', 'expm1')
            if (validate_function_args(node, 1, error)) &
                value = eval_exponential(node%function_name, args(1))

        ! Statistical functions
        case ('mean', 'average', 'sum', 'std', 'stddev')
            if (validate_function_args(node, 1, error)) &
                value = eval_statistical(node%function_name, args(1))

        ! Special functions
        case ('gamma', 'lgamma', 'loggamma', 'factorial', 'fact', 'erf', 'erfc', &
              'ceil', 'ceiling', 'floor', 'round', 'nint', 'frac', 'fraction')
            if (validate_function_args(node, 1, error)) &
                value = eval_special(node%function_name, args(1))

        ! Complex functions
        case ('real', 're', 'imag', 'im', 'conj', 'conjugate', 'arg', 'phase', 'angle', 'cabs', 'modulus')
            if (validate_function_args(node, 1, error)) &
                value = eval_complex_functions(node%function_name, args(1))

        case ('sqrt')
            if (validate_function_args(node, 1, error)) then
                if (args(1)%value_type /= VALUE_SCALAR) then
                    call set_eval_error(error, 'sqrt() expects a real argument')
                else
                    x = args(1)%scalar_val
                    if (x >= 0.0_real64) then
                        value = create_scalar(sqrt(x))
                    else
                        call set_eval_error(error, 'sqrt() argument must be non-negative')
                    end if
                end if
            end if

        case ('abs')
            if (validate_function_args(node, 1, error)) value = abs_value(args(1))

        ! Matrix creation functions
        case ('zeros', 'ones', 'eye')
            value = eval_matrix_creation(node%function_name, node, args, error)

        ! Matrix functions
        case ('transpose', 'trans', 'det', 'determinant', 'inv', 'inverse')
            value = eval_matrix_operation(node%function_name, node, args, error)

        case ('solve')
            ! Solve linear system Ax = b
            if (node%arg_count /= 2) then
                call set_eval_error(error, 'solve() expects 2 arguments: solve(A, b)')
            else if (args(1)%value_type == VALUE_MATRIX .and. args(2)%value_type == VALUE_MATRIX) then
                value = matrix_solve(args(1), args(2))
            else
                call set_eval_error(error, 'solve() expects matrix arguments')
            end if

        case ('rank')
            ! Matrix rank; argument count checked before args(1) is read
            if (validate_function_args(node, 1, error)) then
                if (args(1)%value_type == VALUE_MATRIX) then
                    value = create_scalar(real(matrix_rank(args(1)), real64))
                else
                    call set_eval_error(error, 'rank() expects a matrix argument')
                end if
            end if

        case ('trace')
            ! Matrix trace (sum of diagonal elements)
            if (validate_function_args(node, 1, error)) then
                if (args(1)%value_type == VALUE_MATRIX) then
                    value = create_scalar(matrix_trace(args(1)))
                else
                    call set_eval_error(error, 'trace() expects a matrix argument')
                end if
            end if

        case default
            call set_eval_error(error, 'Unknown function: ' // node%function_name)
        end select
        ! `args` is a local allocatable and is deallocated automatically on return
    end function evaluate_function_call

    !> Evaluate an assignment `name = expr`, storing the result in `context`.
    !>
    !> @return the assigned value, or the scalar 0 on failure (nothing stored)
    recursive function evaluate_assignment(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(evaluation_error_t) :: expr_error
        character(len=:), allocatable :: var_name
        logical :: lhs_is_variable

        error%has_error = .false.
        value = create_scalar(0.0_real64)

        ! Fortran does not guarantee short-circuit evaluation of .or., so the
        ! association test must guard the read of node%left%node_type.
        lhs_is_variable = associated(node%left)
        if (lhs_is_variable) lhs_is_variable = (node%left%node_type == AST_IDENTIFIER)
        if (.not. lhs_is_variable) then
            call set_eval_error(error, 'Left side of assignment must be a variable')
            return
        end if

        var_name = node%left%identifier

        value = evaluate_expression(node%right, context, expr_error)
        if (expr_error%has_error) then
            error = expr_error
            value = create_scalar(0.0_real64)
            return
        end if

        call set_variable(context, var_name, value)
    end function evaluate_assignment

    !> Evaluate a precision specification: evaluate the inner expression and
    !> record the requested digit count in the result's `precision_kind`.
    recursive function evaluate_precision_spec(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(evaluation_error_t) :: expr_error

        error%has_error = .false.

        value = evaluate_expression(node%expression, context, expr_error)
        if (expr_error%has_error) then
            error = expr_error
            value = create_scalar(0.0_real64)
            return
        end if

        ! precision_kind is reused to carry the number of requested digits
        value%precision_kind = node%precision_digits
    end function evaluate_precision_spec

    !> Evaluate a matrix literal from the node's pre-parsed element array.
    function evaluate_matrix_literal(node, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.

        if (allocated(node%matrix_elements)) then
            value = create_matrix(node%matrix_elements)
        else
            call set_eval_error(error, 'Invalid matrix literal')
            value = create_scalar(0.0_real64)
        end if
    end function evaluate_matrix_literal

    !> Mark `error` as failed with `message` (truncated to the message length).
    subroutine set_eval_error(error, message)
        type(evaluation_error_t), intent(out) :: error
        character(len=*), intent(in) :: message

        error%has_error = .true.
        error%message = trim(message)
    end subroutine set_eval_error

    !> Free a variable linked list and leave `var` disassociated.
    !>
    !> Iterative, so stack depth does not grow with the number of variables.
    subroutine free_variables(var)
        type(variable_t), pointer, intent(inout) :: var

        type(variable_t), pointer :: next_var

        do while (associated(var))
            next_var => var%next
            deallocate(var)
            var => next_var
        end do
    end subroutine free_variables

end module fortbite_evaluator_m