Fortran · 27714 bytes Raw Blame History
!> Expression evaluator module for FORTBITE
!>
!> Evaluates Abstract Syntax Trees, performing mathematical operations
!> with proper type promotion and precision handling.
module fortbite_evaluator_m
    use fortbite_types_m, only: value_t, variable_t, VALUE_SCALAR, VALUE_COMPLEX, VALUE_MATRIX, &
        create_scalar, create_complex, create_matrix, print_value, is_zero, is_real, &
        create_zeros_matrix, create_ones_matrix, create_eye_matrix
    use fortbite_ast_m, only: ast_node_t, ast_node_ptr_t, AST_LITERAL, AST_IDENTIFIER, AST_BINARY_OP, &
        AST_UNARY_OP, AST_FUNCTION_CALL, AST_ASSIGNMENT, AST_PRECISION_SPEC, AST_MATRIX_LITERAL, &
        OP_ADD, OP_SUB, OP_MUL, OP_DIV, OP_POW, OP_MOD, &
        OP_UNARY_PLUS, OP_UNARY_MINUS
    use fortbite_arithmetic_m, only: add_values, subtract_values, multiply_values, &
        divide_values, power_values, negate_value, abs_value
    use fortbite_matrix_m, only: matrix_transpose, matrix_determinant, matrix_inverse, &
        matrix_element_access, matrix_solve, matrix_rank, matrix_trace
    use fortbite_functions_m, only: eval_trigonometric, eval_hyperbolic, eval_logarithmic, &
        eval_exponential, eval_statistical, eval_special, eval_complex_functions
    use iso_fortran_env, only: real64
    implicit none
    private

    public :: evaluate_expression, evaluation_context_t, evaluation_error_t
    public :: create_context, destroy_context, set_variable, get_variable

    !> Precision setting given to every new evaluation context.
    integer, parameter :: DEFAULT_PRECISION_DIGITS = 15

    !> Evaluation context (variable storage)
    type :: evaluation_context_t
        !> Head of the singly linked list of user-defined variables
        type(variable_t), pointer :: variables => null()
        !> Default precision setting for this context
        integer :: default_precision = DEFAULT_PRECISION_DIGITS
    end type evaluation_context_t

    !> Evaluation error information
    type :: evaluation_error_t
        !> .true. once an error has been recorded
        logical :: has_error = .false.
        !> Human-readable description (truncated to 200 characters)
        character(len=200) :: message = ''
    end type evaluation_error_t

contains

    !> Create a new, empty evaluation context.
    !>
    !> @return context with no variables and the default precision
    function create_context() result(context)
        type(evaluation_context_t) :: context

        context%default_precision = DEFAULT_PRECISION_DIGITS
        nullify(context%variables)
    end function create_context

    !> Destroy an evaluation context, freeing every stored variable.
    !>
    !> @param context  context to clear; its variable list is left unassociated
    subroutine destroy_context(context)
        type(evaluation_context_t), intent(inout) :: context

        call free_variables(context%variables)
    end subroutine destroy_context

    !> Bind `name` to a copy of `value` in the context.
    !>
    !> An existing binding with the same name is overwritten; otherwise a new
    !> node is prepended to the variable list.
    !>
    !> @param context  context that owns the variable list
    !> @param name     variable name (trailing blanks are trimmed when stored)
    !> @param value    value to store (copied)
    subroutine set_variable(context, name, value)
        type(evaluation_context_t), intent(inout) :: context
        character(len=*), intent(in) :: name
        type(value_t), intent(in) :: value

        type(variable_t), pointer :: var, current

        ! Overwrite an existing binding if the name is already defined
        current => context%variables
        do while (associated(current))
            if (current%name == name) then
                current%value = value
                return
            end if
            current => current%next
        end do

        ! Otherwise prepend a new node to the list
        allocate(var)
        var%name = trim(name)
        var%value = value
        var%next => context%variables
        context%variables => var
    end subroutine set_variable

    !> Look up a user-defined variable.
    !>
    !> @param context  context to search
    !> @param name     variable name
    !> @param value    receives a copy of the stored value when found
    !>                 (left undefined otherwise)
    !> @return .true. when the variable exists
    function get_variable(context, name, value) result(found)
        type(evaluation_context_t), intent(in) :: context
        character(len=*), intent(in) :: name
        type(value_t), intent(out) :: value
        logical :: found

        type(variable_t), pointer :: current

        found = .false.
        current => context%variables

        do while (associated(current))
            if (current%name == name) then
                value = current%value
                found = .true.
                return
            end if
            current => current%next
        end do
    end function get_variable

    !> Evaluate an AST expression.
    !>
    !> Dispatches on the node type. Whenever this module detects an error the
    !> returned value is the scalar 0 and, if `error` is present, it carries
    !> `has_error = .true.` and a message.
    !>
    !> @param node     root of the (sub)tree to evaluate; may be unassociated
    !> @param context  variable storage; modified by assignments
    !> @param error    optional error report
    !> @return the computed value
    recursive function evaluate_expression(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out), optional :: error
        type(value_t) :: value

        type(evaluation_error_t) :: local_error

        local_error%has_error = .false.

        if (.not. associated(node)) then
            call set_eval_error(local_error, 'Null AST node')
            value = create_scalar(0.0_real64)
            if (present(error)) error = local_error
            return
        end if

        select case (node%node_type)
        case (AST_LITERAL)
            value = evaluate_literal(node, local_error)
        case (AST_IDENTIFIER)
            value = evaluate_identifier(node, context, local_error)
        case (AST_BINARY_OP)
            value = evaluate_binary_op(node, context, local_error)
        case (AST_UNARY_OP)
            value = evaluate_unary_op(node, context, local_error)
        case (AST_FUNCTION_CALL)
            value = evaluate_function_call(node, context, local_error)
        case (AST_ASSIGNMENT)
            value = evaluate_assignment(node, context, local_error)
        case (AST_PRECISION_SPEC)
            value = evaluate_precision_spec(node, context, local_error)
        case (AST_MATRIX_LITERAL)
            value = evaluate_matrix_literal(node, local_error)
        case default
            call set_eval_error(local_error, 'Unknown AST node type')
            value = create_scalar(0.0_real64)
        end select

        if (present(error)) error = local_error
    end function evaluate_expression

    !> Evaluate a literal value: returns the node's stored literal unchanged.
    function evaluate_literal(node, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.
        value = node%literal_value
    end function evaluate_literal

    !> Resolve a built-in identifier (mathematical constant or matrix shortcut).
    !>
    !> Built-ins are resolved before user variables, so this list is also the
    !> set of names that cannot usefully be assigned to.
    !>
    !> @param name   identifier to look up (trailing blanks ignored)
    !> @param value  receives the built-in's value when present and found
    !> @return .true. when `name` is a built-in
    function lookup_builtin(name, value) result(found)
        character(len=*), intent(in) :: name
        type(value_t), intent(out), optional :: value
        logical :: found

        found = .true.
        select case (trim(name))
        case ('pi')
            if (present(value)) value = create_scalar(4.0_real64 * atan(1.0_real64))
        case ('e')
            if (present(value)) value = create_scalar(exp(1.0_real64))
        case ('i')
            ! Imaginary unit
            if (present(value)) value = create_complex(0.0_real64, 1.0_real64)
        case ('zeros2')
            if (present(value)) value = create_zeros_matrix(2, 2)
        case ('ones2')
            if (present(value)) value = create_ones_matrix(2, 2)
        case ('ones3')
            if (present(value)) value = create_ones_matrix(3, 3)
        case ('eye2')
            if (present(value)) value = create_eye_matrix(2)
        case ('testmat')
            ! Test matrix [[1,2],[3,4]]; reshape fills column-major, hence 1,3,2,4
            if (present(value)) then
                value = create_matrix(reshape([1.0_real64, 3.0_real64, 2.0_real64, 4.0_real64], [2, 2]))
            end if
        case default
            found = .false.
        end select
    end function lookup_builtin

    !> Evaluate an identifier: built-in constants/shortcuts first, then user
    !> variables. An unknown name is an error and yields scalar 0.
    function evaluate_identifier(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.

        if (lookup_builtin(node%identifier, value)) return

        if (.not. get_variable(context, node%identifier, value)) then
            call set_eval_error(error, 'Undefined variable: ' // trim(node%identifier))
            value = create_scalar(0.0_real64)
        end if
    end function evaluate_identifier

    !> Evaluate a binary operation after evaluating both operands (left first).
    !> Any operand error is propagated and the result is scalar 0.
    recursive function evaluate_binary_op(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t) :: left_val, right_val
        type(evaluation_error_t) :: left_error, right_error

        error%has_error = .false.

        left_val = evaluate_expression(node%left, context, left_error)
        if (left_error%has_error) then
            error = left_error
            value = create_scalar(0.0_real64)
            return
        end if

        right_val = evaluate_expression(node%right, context, right_error)
        if (right_error%has_error) then
            error = right_error
            value = create_scalar(0.0_real64)
            return
        end if

        select case (node%operator)
        case (OP_ADD)
            value = add_values(left_val, right_val)
        case (OP_SUB)
            value = subtract_values(left_val, right_val)
        case (OP_MUL)
            value = multiply_values(left_val, right_val)
        case (OP_DIV)
            value = divide_values(left_val, right_val)
        case (OP_POW)
            value = power_values(left_val, right_val)
        case (OP_MOD)
            ! Only real scalars are supported. Fortran MOD takes the sign of the
            ! dividend: mod(-7, 3) = -1.
            if (left_val%value_type /= VALUE_SCALAR .or. right_val%value_type /= VALUE_SCALAR) then
                call set_eval_error(error, 'Modulo operation only supported for real numbers')
                value = create_scalar(0.0_real64)
            else if (right_val%scalar_val == 0.0_real64) then
                ! mod(x, 0) is processor dependent in Fortran; report it instead
                call set_eval_error(error, 'Modulo by zero')
                value = create_scalar(0.0_real64)
            else
                value = create_scalar(mod(left_val%scalar_val, right_val%scalar_val))
            end if
        case default
            call set_eval_error(error, 'Unknown binary operator')
            value = create_scalar(0.0_real64)
        end select
    end function evaluate_binary_op

    !> Evaluate a unary plus/minus. Operand errors are propagated and the
    !> result is scalar 0.
    recursive function evaluate_unary_op(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t) :: operand_val
        type(evaluation_error_t) :: operand_error

        error%has_error = .false.

        operand_val = evaluate_expression(node%operand, context, operand_error)
        if (operand_error%has_error) then
            error = operand_error
            value = create_scalar(0.0_real64)
            return
        end if

        select case (node%operator)
        case (OP_UNARY_PLUS)
            value = operand_val
        case (OP_UNARY_MINUS)
            value = negate_value(operand_val)
        case default
            call set_eval_error(error, 'Unknown unary operator')
            value = create_scalar(0.0_real64)
        end select
    end function evaluate_unary_op

    !> Check that a function call node has exactly `expected_count` arguments.
    !>
    !> @return .true. when the count matches; otherwise .false. with `error` set
    function validate_function_args(node, expected_count, error) result(is_valid)
        type(ast_node_t), pointer, intent(in) :: node
        integer, intent(in) :: expected_count
        type(evaluation_error_t), intent(out) :: error
        logical :: is_valid

        character(len=12) :: count_text

        error%has_error = .false.
        is_valid = node%arg_count == expected_count
        if (is_valid) return

        if (expected_count == 1) then
            call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
        else
            write(count_text, '(I0)') expected_count
            call set_eval_error(error, trim(node%function_name) // '() expects ' // &
                trim(count_text) // ' arguments')
        end if
    end function validate_function_args

    !> True when `arg%scalar_val` is an integral value in [1, huge(0)].
    !>
    !> The range is checked before any integer conversion happens: INT of a
    !> NaN, an infinity or an out-of-range real is not defined by the standard.
    pure function is_positive_int_size(arg) result(ok)
        type(value_t), intent(in) :: arg
        logical :: ok

        ok = .false.
        ! Written as .not.(x >= 1) so that NaN is rejected as well
        if (.not. (arg%scalar_val >= 1)) return
        if (arg%scalar_val > huge(0)) return
        ok = arg%scalar_val == aint(arg%scalar_val)
    end function is_positive_int_size

    !> Validate matrix dimension arguments (real scalars holding positive integers).
    !>
    !> @param arg1       first (or only) size argument
    !> @param arg2       optional second size argument
    !> @param func_name  function name used in error messages
    !> @return .true. when all given sizes are valid; otherwise `error` is set
    function validate_matrix_dims(arg1, arg2, func_name, error) result(is_valid)
        type(value_t), intent(in) :: arg1
        type(value_t), intent(in), optional :: arg2
        character(len=*), intent(in) :: func_name
        type(evaluation_error_t), intent(out) :: error
        logical :: is_valid

        error%has_error = .false.
        is_valid = .false.

        if (arg1%value_type /= VALUE_SCALAR) then
            call set_eval_error(error, trim(func_name) // '() expects numeric size argument')
            return
        end if
        if (.not. is_positive_int_size(arg1)) then
            call set_eval_error(error, trim(func_name) // '() size must be a positive integer')
            return
        end if

        if (present(arg2)) then
            if (arg2%value_type /= VALUE_SCALAR) then
                call set_eval_error(error, trim(func_name) // '() expects numeric size arguments')
                return
            end if
            if (.not. is_positive_int_size(arg2)) then
                call set_eval_error(error, trim(func_name) // '() sizes must be positive integers')
                return
            end if
        end if

        is_valid = .true.
    end function validate_matrix_dims

    !> Handle the matrix creation functions zeros(n[, m]), ones(n[, m]) and eye(n).
    !> On any error the result is scalar 0 and `error` is set.
    function eval_matrix_creation(func_name, node, args, error) result(value)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in) :: node
        type(value_t), intent(in) :: args(:)
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        integer :: rows, cols
        character(len=:), allocatable :: fname

        fname = trim(func_name)
        value = create_scalar(0.0_real64)

        select case (fname)
        case ('zeros', 'ones')
            select case (node%arg_count)
            case (1)
                ! Square matrix
                if (.not. validate_matrix_dims(args(1), func_name=fname, error=error)) return
                rows = int(args(1)%scalar_val)
                cols = rows
            case (2)
                ! Rectangular matrix
                if (.not. validate_matrix_dims(args(1), args(2), fname, error)) return
                rows = int(args(1)%scalar_val)
                cols = int(args(2)%scalar_val)
            case default
                call set_eval_error(error, fname // '() expects 1 or 2 arguments')
                return
            end select

            if (fname == 'zeros') then
                value = create_zeros_matrix(rows, cols)
            else
                value = create_ones_matrix(rows, cols)
            end if

        case ('eye')
            if (node%arg_count /= 1) then
                call set_eval_error(error, 'eye() expects 1 argument')
                return
            end if
            if (.not. validate_matrix_dims(args(1), func_name=fname, error=error)) return
            value = create_eye_matrix(int(args(1)%scalar_val))

        case default
            call set_eval_error(error, 'Unknown matrix creation function: ' // fname)
        end select
    end function eval_matrix_creation

    !> Handle single-argument matrix functions (transpose, det, inv, rank, trace).
    !> The call must have exactly one matrix argument; otherwise the result is
    !> scalar 0 and `error` is set.
    function eval_matrix_operation(func_name, node, args, error) result(value)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in) :: node
        type(value_t), intent(in) :: args(:)
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        value = create_scalar(0.0_real64)

        if (.not. validate_function_args(node, 1, error)) return

        if (args(1)%value_type /= VALUE_MATRIX) then
            call set_eval_error(error, trim(func_name) // '() expects a matrix argument')
            return
        end if

        select case (trim(func_name))
        case ('transpose', 'trans')
            value = matrix_transpose(args(1))
        case ('det', 'determinant')
            value = create_scalar(matrix_determinant(args(1)))
        case ('inv', 'inverse')
            value = matrix_inverse(args(1))
        case ('rank')
            value = create_scalar(real(matrix_rank(args(1)), real64))
        case ('trace')
            ! Sum of diagonal elements
            value = create_scalar(matrix_trace(args(1)))
        case default
            call set_eval_error(error, 'Unknown matrix operation function')
        end select
    end function eval_matrix_operation

    !> Evaluate a function call.
    !>
    !> Arguments are evaluated left to right; the first argument error is
    !> propagated. Dispatch is by function name. Every error path returns
    !> scalar 0.
    recursive function evaluate_function_call(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(value_t), allocatable :: args(:)
        type(evaluation_error_t) :: arg_error
        integer :: i
        real(real64) :: x

        error%has_error = .false.
        ! Fallback result for every error path below
        value = create_scalar(0.0_real64)

        ! Always allocate (zero-sized when there are no arguments) so that `args`
        ! can legally be passed to the helpers' assumed-shape dummies.
        allocate(args(max(node%arg_count, 0)))
        do i = 1, size(args)
            args(i) = evaluate_expression(node%arguments(i)%ptr, context, arg_error)
            if (arg_error%has_error) then
                error = arg_error
                return
            end if
        end do

        select case (node%function_name)
        ! Trigonometric functions
        case ('sin', 'cos', 'tan', 'asin', 'arcsin', 'acos', 'arccos', 'atan', 'arctan', &
              'sec', 'csc', 'cot')
            if (validate_function_args(node, 1, error)) then
                value = eval_trigonometric(node%function_name, args(1))
            end if

        ! Hyperbolic functions
        case ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', 'sech', 'csch', 'coth')
            if (validate_function_args(node, 1, error)) then
                value = eval_hyperbolic(node%function_name, args(1))
            end if

        ! Logarithmic functions
        case ('log', 'ln', 'log10', 'lg', 'log2')
            if (validate_function_args(node, 1, error)) then
                value = eval_logarithmic(node%function_name, args(1))
            end if

        ! Exponential functions
        case ('exp', 'exp2', 'exp10', 'expm1')
            if (validate_function_args(node, 1, error)) then
                value = eval_exponential(node%function_name, args(1))
            end if

        ! Statistical functions
        case ('mean', 'average', 'sum', 'std', 'stddev')
            if (validate_function_args(node, 1, error)) then
                value = eval_statistical(node%function_name, args(1))
            end if

        ! Special functions
        case ('gamma', 'lgamma', 'loggamma', 'factorial', 'fact', 'erf', 'erfc', &
              'ceil', 'ceiling', 'floor', 'round', 'nint', 'frac', 'fraction')
            if (validate_function_args(node, 1, error)) then
                value = eval_special(node%function_name, args(1))
            end if

        ! Complex functions
        case ('real', 're', 'imag', 'im', 'conj', 'conjugate', 'arg', 'phase', 'angle', 'cabs', 'modulus')
            if (validate_function_args(node, 1, error)) then
                value = eval_complex_functions(node%function_name, args(1))
            end if

        case ('sqrt')
            if (validate_function_args(node, 1, error)) then
                if (args(1)%value_type /= VALUE_SCALAR) then
                    call set_eval_error(error, 'sqrt() expects a real argument')
                else
                    x = args(1)%scalar_val
                    if (x >= 0.0_real64) then
                        value = create_scalar(sqrt(x))
                    else
                        call set_eval_error(error, 'sqrt() argument must be non-negative')
                    end if
                end if
            end if

        case ('abs')
            if (validate_function_args(node, 1, error)) then
                value = abs_value(args(1))
            end if

        ! Matrix creation functions
        case ('zeros', 'ones', 'eye')
            value = eval_matrix_creation(node%function_name, node, args, error)

        ! Single-argument matrix functions
        case ('transpose', 'trans', 'det', 'determinant', 'inv', 'inverse', 'rank', 'trace')
            value = eval_matrix_operation(node%function_name, node, args, error)

        case ('solve')
            ! Solve linear system Ax = b
            if (node%arg_count /= 2) then
                call set_eval_error(error, 'solve() expects 2 arguments: solve(A, b)')
            else if (args(1)%value_type == VALUE_MATRIX .and. args(2)%value_type == VALUE_MATRIX) then
                value = matrix_solve(args(1), args(2))
            else
                call set_eval_error(error, 'solve() expects matrix arguments')
            end if

        case default
            call set_eval_error(error, 'Unknown function: ' // trim(node%function_name))
        end select
    end function evaluate_function_call

    !> Evaluate an assignment `name = expr`.
    !>
    !> The left side must be an identifier that is not a built-in name. The
    !> right side is evaluated, stored in the context, and returned. Every
    !> error path returns scalar 0.
    recursive function evaluate_assignment(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(evaluation_error_t) :: expr_error
        character(len=:), allocatable :: var_name
        logical :: target_ok

        error%has_error = .false.
        value = create_scalar(0.0_real64)

        ! Fortran does not guarantee short-circuit evaluation of .or., so the
        ! association check must happen before node%left is dereferenced.
        target_ok = associated(node%left)
        if (target_ok) target_ok = node%left%node_type == AST_IDENTIFIER
        if (.not. target_ok) then
            call set_eval_error(error, 'Left side of assignment must be a variable')
            return
        end if

        var_name = trim(node%left%identifier)

        ! Built-ins are resolved before user variables, so such a binding
        ! could never be read back; reject it explicitly.
        if (lookup_builtin(var_name)) then
            call set_eval_error(error, 'Cannot assign to built-in name: ' // var_name)
            return
        end if

        value = evaluate_expression(node%right, context, expr_error)
        if (expr_error%has_error) then
            error = expr_error
            value = create_scalar(0.0_real64)
            return
        end if

        call set_variable(context, var_name, value)
    end function evaluate_assignment

    !> Evaluate a precision specification: evaluates the inner expression and
    !> records the requested digit count in the result's `precision_kind` field.
    recursive function evaluate_precision_spec(node, context, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_context_t), intent(inout) :: context
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        type(evaluation_error_t) :: expr_error

        error%has_error = .false.

        value = evaluate_expression(node%expression, context, expr_error)
        if (expr_error%has_error) then
            error = expr_error
            value = create_scalar(0.0_real64)
            return
        end if

        ! precision_kind is reused to carry the number of requested digits
        value%precision_kind = node%precision_digits
    end function evaluate_precision_spec

    !> Evaluate a matrix literal from the node's pre-parsed element array.
    function evaluate_matrix_literal(node, error) result(value)
        type(ast_node_t), pointer, intent(in) :: node
        type(evaluation_error_t), intent(out) :: error
        type(value_t) :: value

        error%has_error = .false.

        if (allocated(node%matrix_elements)) then
            value = create_matrix(node%matrix_elements)
        else
            call set_eval_error(error, 'Invalid matrix literal')
            value = create_scalar(0.0_real64)
        end if
    end function evaluate_matrix_literal

    !> Record an evaluation error (message truncated to the component length).
    subroutine set_eval_error(error, message)
        type(evaluation_error_t), intent(out) :: error
        character(len=*), intent(in) :: message

        error%has_error = .true.
        error%message = trim(message)
    end subroutine set_eval_error

    !> Free a variable linked list iteratively (no recursion depth limit).
    !>
    !> @param var  head of the list; unassociated on return
    subroutine free_variables(var)
        type(variable_t), pointer, intent(inout) :: var

        type(variable_t), pointer :: next_var

        do while (associated(var))
            next_var => var%next
            deallocate(var)
            var => next_var
        end do
        nullify(var)
    end subroutine free_variables

end module fortbite_evaluator_m