@@ -295,6 +295,169 @@ contains |
| 295 | 295 | end select |
| 296 | 296 | end function evaluate_unary_op |
| 297 | 297 | |
| 298 | + !> Validate function argument count and return error if invalid |
| 299 | + logical function validate_function_args(node, expected_count, error) result(is_valid) |
| 300 | + type(ast_node_t), pointer, intent(in) :: node |
| 301 | + integer, intent(in) :: expected_count |
| 302 | + type(evaluation_error_t), intent(out) :: error |
| 303 | + |
| 304 | + character(len=100) :: error_msg |
| 305 | + |
| 306 | + error%has_error = .false. |
| 307 | + |
| 308 | + if (node%arg_count /= expected_count) then |
| 309 | + if (expected_count == 1) then |
| 310 | + write(error_msg, '(A,A,A)') trim(node%function_name), '() expects 1 argument' |
| 311 | + else |
| 312 | + write(error_msg, '(A,A,A,I0,A)') trim(node%function_name), '() expects ', '', expected_count, ' arguments' |
| 313 | + end if |
| 314 | + call set_eval_error(error, error_msg) |
| 315 | + is_valid = .false. |
| 316 | + else |
| 317 | + is_valid = .true. |
| 318 | + end if |
| 319 | + end function validate_function_args |
| 320 | + |
| 321 | + !> Validate matrix dimension arguments (positive integers) |
| 322 | + logical function validate_matrix_dims(arg1, arg2, func_name, error) result(is_valid) |
| 323 | + type(value_t), intent(in) :: arg1 |
| 324 | + type(value_t), intent(in), optional :: arg2 |
| 325 | + character(len=*), intent(in) :: func_name |
| 326 | + type(evaluation_error_t), intent(out) :: error |
| 327 | + |
| 328 | + character(len=100) :: error_msg |
| 329 | + |
| 330 | + error%has_error = .false. |
| 331 | + is_valid = .false. |
| 332 | + |
| 333 | + ! Check first argument |
| 334 | + if (arg1%value_type /= VALUE_SCALAR) then |
| 335 | + write(error_msg, '(A,A)') trim(func_name), '() expects numeric size argument' |
| 336 | + call set_eval_error(error, error_msg) |
| 337 | + return |
| 338 | + end if |
| 339 | + |
| 340 | + if (arg1%scalar_val <= 0 .or. arg1%scalar_val /= int(arg1%scalar_val)) then |
| 341 | + write(error_msg, '(A,A)') trim(func_name), '() size must be a positive integer' |
| 342 | + call set_eval_error(error, error_msg) |
| 343 | + return |
| 344 | + end if |
| 345 | + |
| 346 | + ! Check second argument if present |
| 347 | + if (present(arg2)) then |
| 348 | + if (arg2%value_type /= VALUE_SCALAR) then |
| 349 | + write(error_msg, '(A,A)') trim(func_name), '() expects numeric size arguments' |
| 350 | + call set_eval_error(error, error_msg) |
| 351 | + return |
| 352 | + end if |
| 353 | + |
| 354 | + if (arg2%scalar_val <= 0 .or. arg2%scalar_val /= int(arg2%scalar_val)) then |
| 355 | + write(error_msg, '(A,A)') trim(func_name), '() sizes must be positive integers' |
| 356 | + call set_eval_error(error, error_msg) |
| 357 | + return |
| 358 | + end if |
| 359 | + end if |
| 360 | + |
| 361 | + is_valid = .true. |
| 362 | + end function validate_matrix_dims |
| 363 | + |
| 364 | + !> Handle matrix creation functions (zeros, ones, eye) |
| 365 | + function eval_matrix_creation(func_name, node, args, error) result(value) |
| 366 | + character(len=*), intent(in) :: func_name |
| 367 | + type(ast_node_t), pointer, intent(in) :: node |
| 368 | + type(value_t), intent(in) :: args(:) |
| 369 | + type(evaluation_error_t), intent(out) :: error |
| 370 | + type(value_t) :: value |
| 371 | + |
| 372 | + integer :: rows, cols |
| 373 | + character(len=100) :: error_msg |
| 374 | + |
| 375 | + select case (trim(func_name)) |
| 376 | + case ('zeros', 'ones') |
| 377 | + if (node%arg_count == 1) then |
| 378 | + ! Square matrix |
| 379 | + if (validate_matrix_dims(args(1), func_name=func_name, error=error)) then |
| 380 | + rows = int(args(1)%scalar_val) |
| 381 | + if (trim(func_name) == 'zeros') then |
| 382 | + value = create_zeros_matrix(rows, rows) |
| 383 | + else |
| 384 | + value = create_ones_matrix(rows, rows) |
| 385 | + end if |
| 386 | + else |
| 387 | + value = create_scalar(0.0_real64) |
| 388 | + end if |
| 389 | + else if (node%arg_count == 2) then |
| 390 | + ! Rectangular matrix |
| 391 | + if (validate_matrix_dims(args(1), args(2), func_name, error)) then |
| 392 | + rows = int(args(1)%scalar_val) |
| 393 | + cols = int(args(2)%scalar_val) |
| 394 | + if (trim(func_name) == 'zeros') then |
| 395 | + value = create_zeros_matrix(rows, cols) |
| 396 | + else |
| 397 | + value = create_ones_matrix(rows, cols) |
| 398 | + end if |
| 399 | + else |
| 400 | + value = create_scalar(0.0_real64) |
| 401 | + end if |
| 402 | + else |
| 403 | + write(error_msg, '(A,A)') trim(func_name), '() expects 1 or 2 arguments' |
| 404 | + call set_eval_error(error, error_msg) |
| 405 | + value = create_scalar(0.0_real64) |
| 406 | + end if |
| 407 | + |
| 408 | + case ('eye') |
| 409 | + if (node%arg_count == 1) then |
| 410 | + if (validate_matrix_dims(args(1), func_name=func_name, error=error)) then |
| 411 | + rows = int(args(1)%scalar_val) |
| 412 | + value = create_eye_matrix(rows) |
| 413 | + else |
| 414 | + value = create_scalar(0.0_real64) |
| 415 | + end if |
| 416 | + else |
| 417 | + call set_eval_error(error, 'eye() expects 1 argument') |
| 418 | + value = create_scalar(0.0_real64) |
| 419 | + end if |
| 420 | + end select |
| 421 | + end function eval_matrix_creation |
| 422 | + |
| 423 | + !> Handle single-argument matrix operation functions |
| 424 | + function eval_matrix_operation(func_name, node, args, error) result(value) |
| 425 | + character(len=*), intent(in) :: func_name |
| 426 | + type(ast_node_t), pointer, intent(in) :: node |
| 427 | + type(value_t), intent(in) :: args(:) |
| 428 | + type(evaluation_error_t), intent(out) :: error |
| 429 | + type(value_t) :: value |
| 430 | + |
| 431 | + character(len=100) :: error_msg |
| 432 | + |
| 433 | + ! Validate single argument |
| 434 | + if (.not. validate_function_args(node, 1, error)) then |
| 435 | + value = create_scalar(0.0_real64) |
| 436 | + return |
| 437 | + end if |
| 438 | + |
| 439 | + ! Validate matrix argument |
| 440 | + if (args(1)%value_type /= VALUE_MATRIX) then |
| 441 | + write(error_msg, '(A,A)') trim(func_name), '() expects a matrix argument' |
| 442 | + call set_eval_error(error, error_msg) |
| 443 | + value = create_scalar(0.0_real64) |
| 444 | + return |
| 445 | + end if |
| 446 | + |
| 447 | + ! Dispatch to appropriate function |
| 448 | + select case (trim(func_name)) |
| 449 | + case ('transpose', 'trans') |
| 450 | + value = matrix_transpose(args(1)) |
| 451 | + case ('det', 'determinant') |
| 452 | + value = create_scalar(matrix_determinant(args(1))) |
| 453 | + case ('inv', 'inverse') |
| 454 | + value = matrix_inverse(args(1)) |
| 455 | + case default |
| 456 | + call set_eval_error(error, 'Unknown matrix operation function') |
| 457 | + value = create_scalar(0.0_real64) |
| 458 | + end select |
| 459 | + end function eval_matrix_operation |
| 460 | + |
| 298 | 461 | !> Evaluate a function call |
| 299 | 462 | recursive function evaluate_function_call(node, context, error) result(value) |
| 300 | 463 | type(ast_node_t), pointer, intent(in) :: node |
@@ -326,8 +489,7 @@ contains |
| 326 | 489 | ! Trigonometric functions |
| 327 | 490 | case ('sin', 'cos', 'tan', 'asin', 'arcsin', 'acos', 'arccos', 'atan', 'arctan', & |
| 328 | 491 | 'sec', 'csc', 'cot') |
| 329 | | - if (node%arg_count /= 1) then |
| 330 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 492 | + if (.not. validate_function_args(node, 1, error)) then |
| 331 | 493 | value = create_scalar(0.0_real64) |
| 332 | 494 | return |
| 333 | 495 | end if |
@@ -335,8 +497,7 @@ contains |
| 335 | 497 | |
| 336 | 498 | ! Hyperbolic functions |
| 337 | 499 | case ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', 'sech', 'csch', 'coth') |
| 338 | | - if (node%arg_count /= 1) then |
| 339 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 500 | + if (.not. validate_function_args(node, 1, error)) then |
| 340 | 501 | value = create_scalar(0.0_real64) |
| 341 | 502 | return |
| 342 | 503 | end if |
@@ -344,8 +505,7 @@ contains |
| 344 | 505 | |
| 345 | 506 | ! Logarithmic functions |
| 346 | 507 | case ('log', 'ln', 'log10', 'lg', 'log2') |
| 347 | | - if (node%arg_count /= 1) then |
| 348 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 508 | + if (.not. validate_function_args(node, 1, error)) then |
| 349 | 509 | value = create_scalar(0.0_real64) |
| 350 | 510 | return |
| 351 | 511 | end if |
@@ -353,8 +513,7 @@ contains |
| 353 | 513 | |
| 354 | 514 | ! Exponential functions |
| 355 | 515 | case ('exp', 'exp2', 'exp10', 'expm1') |
| 356 | | - if (node%arg_count /= 1) then |
| 357 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 516 | + if (.not. validate_function_args(node, 1, error)) then |
| 358 | 517 | value = create_scalar(0.0_real64) |
| 359 | 518 | return |
| 360 | 519 | end if |
@@ -362,8 +521,7 @@ contains |
| 362 | 521 | |
| 363 | 522 | ! Statistical functions |
| 364 | 523 | case ('mean', 'average', 'sum', 'std', 'stddev') |
| 365 | | - if (node%arg_count /= 1) then |
| 366 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 524 | + if (.not. validate_function_args(node, 1, error)) then |
| 367 | 525 | value = create_scalar(0.0_real64) |
| 368 | 526 | return |
| 369 | 527 | end if |
@@ -372,8 +530,7 @@ contains |
| 372 | 530 | ! Special functions |
| 373 | 531 | case ('gamma', 'lgamma', 'loggamma', 'factorial', 'fact', 'erf', 'erfc', & |
| 374 | 532 | 'ceil', 'ceiling', 'floor', 'round', 'nint', 'frac', 'fraction') |
| 375 | | - if (node%arg_count /= 1) then |
| 376 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 533 | + if (.not. validate_function_args(node, 1, error)) then |
| 377 | 534 | value = create_scalar(0.0_real64) |
| 378 | 535 | return |
| 379 | 536 | end if |
@@ -381,16 +538,14 @@ contains |
| 381 | 538 | |
| 382 | 539 | ! Complex functions |
| 383 | 540 | case ('real', 're', 'imag', 'im', 'conj', 'conjugate', 'arg', 'phase', 'angle', 'cabs', 'modulus') |
| 384 | | - if (node%arg_count /= 1) then |
| 385 | | - call set_eval_error(error, trim(node%function_name) // '() expects 1 argument') |
| 541 | + if (.not. validate_function_args(node, 1, error)) then |
| 386 | 542 | value = create_scalar(0.0_real64) |
| 387 | 543 | return |
| 388 | 544 | end if |
| 389 | 545 | value = eval_complex_functions(node%function_name, args(1)) |
| 390 | 546 | |
| 391 | 547 | case ('sqrt') |
| 392 | | - if (node%arg_count /= 1) then |
| 393 | | - call set_eval_error(error, 'sqrt() expects 1 argument') |
| 548 | + if (.not. validate_function_args(node, 1, error)) then |
| 394 | 549 | value = create_scalar(0.0_real64) |
| 395 | 550 | return |
| 396 | 551 | end if |
@@ -408,137 +563,19 @@ contains |
| 408 | 563 | end if |
| 409 | 564 | |
| 410 | 565 | case ('abs') |
| 411 | | - if (node%arg_count /= 1) then |
| 412 | | - call set_eval_error(error, 'abs() expects 1 argument') |
| 566 | + if (.not. validate_function_args(node, 1, error)) then |
| 413 | 567 | value = create_scalar(0.0_real64) |
| 414 | 568 | return |
| 415 | 569 | end if |
| 416 | 570 | value = abs_value(args(1)) |
| 417 | 571 | |
| 418 | 572 | ! Matrix creation functions |
| 419 | | - case ('zeros') |
| 420 | | - if (node%arg_count == 1) then |
| 421 | | - ! zeros(n) - square matrix |
| 422 | | - if (args(1)%value_type == VALUE_SCALAR) then |
| 423 | | - if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then |
| 424 | | - value = create_zeros_matrix(int(args(1)%scalar_val), int(args(1)%scalar_val)) |
| 425 | | - else |
| 426 | | - call set_eval_error(error, 'zeros() size must be a positive integer') |
| 427 | | - value = create_scalar(0.0_real64) |
| 428 | | - end if |
| 429 | | - else |
| 430 | | - call set_eval_error(error, 'zeros() expects numeric size argument') |
| 431 | | - value = create_scalar(0.0_real64) |
| 432 | | - end if |
| 433 | | - else if (node%arg_count == 2) then |
| 434 | | - ! zeros(m,n) - rectangular matrix |
| 435 | | - if (args(1)%value_type == VALUE_SCALAR .and. args(2)%value_type == VALUE_SCALAR) then |
| 436 | | - if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val) .and. & |
| 437 | | - args(2)%scalar_val > 0 .and. args(2)%scalar_val == int(args(2)%scalar_val)) then |
| 438 | | - value = create_zeros_matrix(int(args(1)%scalar_val), int(args(2)%scalar_val)) |
| 439 | | - else |
| 440 | | - call set_eval_error(error, 'zeros() sizes must be positive integers') |
| 441 | | - value = create_scalar(0.0_real64) |
| 442 | | - end if |
| 443 | | - else |
| 444 | | - call set_eval_error(error, 'zeros() expects numeric size arguments') |
| 445 | | - value = create_scalar(0.0_real64) |
| 446 | | - end if |
| 447 | | - else |
| 448 | | - call set_eval_error(error, 'zeros() expects 1 or 2 arguments') |
| 449 | | - value = create_scalar(0.0_real64) |
| 450 | | - end if |
| 451 | | - |
| 452 | | - case ('ones') |
| 453 | | - if (node%arg_count == 1) then |
| 454 | | - ! ones(n) - square matrix |
| 455 | | - if (args(1)%value_type == VALUE_SCALAR) then |
| 456 | | - if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then |
| 457 | | - value = create_ones_matrix(int(args(1)%scalar_val), int(args(1)%scalar_val)) |
| 458 | | - else |
| 459 | | - call set_eval_error(error, 'ones() size must be a positive integer') |
| 460 | | - value = create_scalar(0.0_real64) |
| 461 | | - end if |
| 462 | | - else |
| 463 | | - call set_eval_error(error, 'ones() expects numeric size argument') |
| 464 | | - value = create_scalar(0.0_real64) |
| 465 | | - end if |
| 466 | | - else if (node%arg_count == 2) then |
| 467 | | - ! ones(m,n) - rectangular matrix |
| 468 | | - if (args(1)%value_type == VALUE_SCALAR .and. args(2)%value_type == VALUE_SCALAR) then |
| 469 | | - if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val) .and. & |
| 470 | | - args(2)%scalar_val > 0 .and. args(2)%scalar_val == int(args(2)%scalar_val)) then |
| 471 | | - value = create_ones_matrix(int(args(1)%scalar_val), int(args(2)%scalar_val)) |
| 472 | | - else |
| 473 | | - call set_eval_error(error, 'ones() sizes must be positive integers') |
| 474 | | - value = create_scalar(0.0_real64) |
| 475 | | - end if |
| 476 | | - else |
| 477 | | - call set_eval_error(error, 'ones() expects numeric size arguments') |
| 478 | | - value = create_scalar(0.0_real64) |
| 479 | | - end if |
| 480 | | - else |
| 481 | | - call set_eval_error(error, 'ones() expects 1 or 2 arguments') |
| 482 | | - value = create_scalar(0.0_real64) |
| 483 | | - end if |
| 484 | | - |
| 485 | | - case ('eye') |
| 486 | | - if (node%arg_count /= 1) then |
| 487 | | - call set_eval_error(error, 'eye() expects 1 argument') |
| 488 | | - value = create_scalar(0.0_real64) |
| 489 | | - return |
| 490 | | - end if |
| 491 | | - if (args(1)%value_type == VALUE_SCALAR) then |
| 492 | | - if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then |
| 493 | | - value = create_eye_matrix(int(args(1)%scalar_val)) |
| 494 | | - else |
| 495 | | - call set_eval_error(error, 'eye() size must be a positive integer') |
| 496 | | - value = create_scalar(0.0_real64) |
| 497 | | - end if |
| 498 | | - else |
| 499 | | - call set_eval_error(error, 'eye() expects numeric size argument') |
| 500 | | - value = create_scalar(0.0_real64) |
| 501 | | - end if |
| 573 | + case ('zeros', 'ones', 'eye') |
| 574 | + value = eval_matrix_creation(node%function_name, node, args, error) |
| 502 | 575 | |
| 503 | 576 | ! Matrix functions |
| 504 | | - case ('transpose', 'trans') |
| 505 | | - if (node%arg_count /= 1) then |
| 506 | | - call set_eval_error(error, 'transpose() expects 1 argument') |
| 507 | | - value = create_scalar(0.0_real64) |
| 508 | | - return |
| 509 | | - end if |
| 510 | | - if (args(1)%value_type == VALUE_MATRIX) then |
| 511 | | - value = matrix_transpose(args(1)) |
| 512 | | - else |
| 513 | | - call set_eval_error(error, 'transpose() expects a matrix argument') |
| 514 | | - value = create_scalar(0.0_real64) |
| 515 | | - end if |
| 516 | | - |
| 517 | | - case ('det', 'determinant') |
| 518 | | - if (node%arg_count /= 1) then |
| 519 | | - call set_eval_error(error, 'det() expects 1 argument') |
| 520 | | - value = create_scalar(0.0_real64) |
| 521 | | - return |
| 522 | | - end if |
| 523 | | - if (args(1)%value_type == VALUE_MATRIX) then |
| 524 | | - value = create_scalar(matrix_determinant(args(1))) |
| 525 | | - else |
| 526 | | - call set_eval_error(error, 'det() expects a matrix argument') |
| 527 | | - value = create_scalar(0.0_real64) |
| 528 | | - end if |
| 529 | | - |
| 530 | | - case ('inv', 'inverse') |
| 531 | | - if (node%arg_count /= 1) then |
| 532 | | - call set_eval_error(error, 'inv() expects 1 argument') |
| 533 | | - value = create_scalar(0.0_real64) |
| 534 | | - return |
| 535 | | - end if |
| 536 | | - if (args(1)%value_type == VALUE_MATRIX) then |
| 537 | | - value = matrix_inverse(args(1)) |
| 538 | | - else |
| 539 | | - call set_eval_error(error, 'inv() expects a matrix argument') |
| 540 | | - value = create_scalar(0.0_real64) |
| 541 | | - end if |
| 577 | + case ('transpose', 'trans', 'det', 'determinant', 'inv', 'inverse') |
| 578 | + value = eval_matrix_operation(node%function_name, node, args, error) |
| 542 | 579 | |
| 543 | 580 | case ('solve') |
| 544 | 581 | ! Solve linear system Ax = b |