fortrangoingonforty/fortbite / d6fd7dd

Browse files

fixes, cleaning

Authored by espadonne
SHA
d6fd7ddf5ce52b4de2e98a9d8007d296d6169c8d
Parents
871d980
Tree
8c082ce

2 changed files

StatusFile+-
M src/fortbite_evaluator_m.f90 176 139
M src/fortbite_types_m.f90 18 31
src/fortbite_evaluator_m.f90modified
@@ -295,6 +295,169 @@ contains
295295
         end select
296296
     end function evaluate_unary_op
297297
     
298
+    !> Validate function argument count and return error if invalid
299
+    logical function validate_function_args(node, expected_count, error) result(is_valid)
300
+        type(ast_node_t), pointer, intent(in) :: node
301
+        integer, intent(in) :: expected_count
302
+        type(evaluation_error_t), intent(out) :: error
303
+        
304
+        character(len=100) :: error_msg
305
+        
306
+        error%has_error = .false.
307
+        
308
+        if (node%arg_count /= expected_count) then
309
+            if (expected_count == 1) then
310
+                write(error_msg, '(A,A,A)') trim(node%function_name), '() expects 1 argument'
311
+            else
312
+                write(error_msg, '(A,A,A,I0,A)') trim(node%function_name), '() expects ', '', expected_count, ' arguments'
313
+            end if
314
+            call set_eval_error(error, error_msg)
315
+            is_valid = .false.
316
+        else
317
+            is_valid = .true.
318
+        end if
319
+    end function validate_function_args
320
+    
321
+    !> Validate matrix dimension arguments (positive integers)
322
+    logical function validate_matrix_dims(arg1, arg2, func_name, error) result(is_valid)
323
+        type(value_t), intent(in) :: arg1
324
+        type(value_t), intent(in), optional :: arg2
325
+        character(len=*), intent(in) :: func_name
326
+        type(evaluation_error_t), intent(out) :: error
327
+        
328
+        character(len=100) :: error_msg
329
+        
330
+        error%has_error = .false.
331
+        is_valid = .false.
332
+        
333
+        ! Check first argument
334
+        if (arg1%value_type /= VALUE_SCALAR) then
335
+            write(error_msg, '(A,A)') trim(func_name), '() expects numeric size argument'
336
+            call set_eval_error(error, error_msg)
337
+            return
338
+        end if
339
+        
340
+        if (arg1%scalar_val <= 0 .or. arg1%scalar_val /= int(arg1%scalar_val)) then
341
+            write(error_msg, '(A,A)') trim(func_name), '() size must be a positive integer'
342
+            call set_eval_error(error, error_msg)
343
+            return
344
+        end if
345
+        
346
+        ! Check second argument if present
347
+        if (present(arg2)) then
348
+            if (arg2%value_type /= VALUE_SCALAR) then
349
+                write(error_msg, '(A,A)') trim(func_name), '() expects numeric size arguments'
350
+                call set_eval_error(error, error_msg)
351
+                return
352
+            end if
353
+            
354
+            if (arg2%scalar_val <= 0 .or. arg2%scalar_val /= int(arg2%scalar_val)) then
355
+                write(error_msg, '(A,A)') trim(func_name), '() sizes must be positive integers'
356
+                call set_eval_error(error, error_msg)
357
+                return
358
+            end if
359
+        end if
360
+        
361
+        is_valid = .true.
362
+    end function validate_matrix_dims
363
+    
364
+    !> Handle matrix creation functions (zeros, ones, eye)
365
+    function eval_matrix_creation(func_name, node, args, error) result(value)
366
+        character(len=*), intent(in) :: func_name
367
+        type(ast_node_t), pointer, intent(in) :: node
368
+        type(value_t), intent(in) :: args(:)
369
+        type(evaluation_error_t), intent(out) :: error
370
+        type(value_t) :: value
371
+        
372
+        integer :: rows, cols
373
+        character(len=100) :: error_msg
374
+        
375
+        select case (trim(func_name))
376
+        case ('zeros', 'ones')
377
+            if (node%arg_count == 1) then
378
+                ! Square matrix
379
+                if (validate_matrix_dims(args(1), func_name=func_name, error=error)) then
380
+                    rows = int(args(1)%scalar_val)
381
+                    if (trim(func_name) == 'zeros') then
382
+                        value = create_zeros_matrix(rows, rows)
383
+                    else
384
+                        value = create_ones_matrix(rows, rows)
385
+                    end if
386
+                else
387
+                    value = create_scalar(0.0_real64)
388
+                end if
389
+            else if (node%arg_count == 2) then
390
+                ! Rectangular matrix
391
+                if (validate_matrix_dims(args(1), args(2), func_name, error)) then
392
+                    rows = int(args(1)%scalar_val)
393
+                    cols = int(args(2)%scalar_val)
394
+                    if (trim(func_name) == 'zeros') then
395
+                        value = create_zeros_matrix(rows, cols)
396
+                    else
397
+                        value = create_ones_matrix(rows, cols)
398
+                    end if
399
+                else
400
+                    value = create_scalar(0.0_real64)
401
+                end if
402
+            else
403
+                write(error_msg, '(A,A)') trim(func_name), '() expects 1 or 2 arguments'
404
+                call set_eval_error(error, error_msg)
405
+                value = create_scalar(0.0_real64)
406
+            end if
407
+            
408
+        case ('eye')
409
+            if (node%arg_count == 1) then
410
+                if (validate_matrix_dims(args(1), func_name=func_name, error=error)) then
411
+                    rows = int(args(1)%scalar_val)
412
+                    value = create_eye_matrix(rows)
413
+                else
414
+                    value = create_scalar(0.0_real64)
415
+                end if
416
+            else
417
+                call set_eval_error(error, 'eye() expects 1 argument')
418
+                value = create_scalar(0.0_real64)
419
+            end if
420
+        end select
421
+    end function eval_matrix_creation
422
+    
423
+    !> Handle single-argument matrix operation functions
424
+    function eval_matrix_operation(func_name, node, args, error) result(value)
425
+        character(len=*), intent(in) :: func_name
426
+        type(ast_node_t), pointer, intent(in) :: node
427
+        type(value_t), intent(in) :: args(:)
428
+        type(evaluation_error_t), intent(out) :: error
429
+        type(value_t) :: value
430
+        
431
+        character(len=100) :: error_msg
432
+        
433
+        ! Validate single argument
434
+        if (.not. validate_function_args(node, 1, error)) then
435
+            value = create_scalar(0.0_real64)
436
+            return
437
+        end if
438
+        
439
+        ! Validate matrix argument
440
+        if (args(1)%value_type /= VALUE_MATRIX) then
441
+            write(error_msg, '(A,A)') trim(func_name), '() expects a matrix argument'
442
+            call set_eval_error(error, error_msg)
443
+            value = create_scalar(0.0_real64)
444
+            return
445
+        end if
446
+        
447
+        ! Dispatch to appropriate function
448
+        select case (trim(func_name))
449
+        case ('transpose', 'trans')
450
+            value = matrix_transpose(args(1))
451
+        case ('det', 'determinant')
452
+            value = create_scalar(matrix_determinant(args(1)))
453
+        case ('inv', 'inverse')
454
+            value = matrix_inverse(args(1))
455
+        case default
456
+            call set_eval_error(error, 'Unknown matrix operation function')
457
+            value = create_scalar(0.0_real64)
458
+        end select
459
+    end function eval_matrix_operation
460
+    
298461
     !> Evaluate a function call
299462
     recursive function evaluate_function_call(node, context, error) result(value)
300463
         type(ast_node_t), pointer, intent(in) :: node
@@ -326,8 +489,7 @@ contains
326489
         ! Trigonometric functions
327490
         case ('sin', 'cos', 'tan', 'asin', 'arcsin', 'acos', 'arccos', 'atan', 'arctan', &
328491
               'sec', 'csc', 'cot')
329
-            if (node%arg_count /= 1) then
330
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
492
+            if (.not. validate_function_args(node, 1, error)) then
331493
                 value = create_scalar(0.0_real64)
332494
                 return
333495
             end if
@@ -335,8 +497,7 @@ contains
335497
             
336498
         ! Hyperbolic functions  
337499
         case ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', 'sech', 'csch', 'coth')
338
-            if (node%arg_count /= 1) then
339
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
500
+            if (.not. validate_function_args(node, 1, error)) then
340501
                 value = create_scalar(0.0_real64)
341502
                 return
342503
             end if
@@ -344,8 +505,7 @@ contains
344505
             
345506
         ! Logarithmic functions
346507
         case ('log', 'ln', 'log10', 'lg', 'log2')
347
-            if (node%arg_count /= 1) then
348
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
508
+            if (.not. validate_function_args(node, 1, error)) then
349509
                 value = create_scalar(0.0_real64)
350510
                 return
351511
             end if
@@ -353,8 +513,7 @@ contains
353513
             
354514
         ! Exponential functions
355515
         case ('exp', 'exp2', 'exp10', 'expm1')
356
-            if (node%arg_count /= 1) then
357
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
516
+            if (.not. validate_function_args(node, 1, error)) then
358517
                 value = create_scalar(0.0_real64)
359518
                 return
360519
             end if
@@ -362,8 +521,7 @@ contains
362521
             
363522
         ! Statistical functions
364523
         case ('mean', 'average', 'sum', 'std', 'stddev')
365
-            if (node%arg_count /= 1) then
366
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
524
+            if (.not. validate_function_args(node, 1, error)) then
367525
                 value = create_scalar(0.0_real64)
368526
                 return
369527
             end if
@@ -372,8 +530,7 @@ contains
372530
         ! Special functions
373531
         case ('gamma', 'lgamma', 'loggamma', 'factorial', 'fact', 'erf', 'erfc', &
374532
               'ceil', 'ceiling', 'floor', 'round', 'nint', 'frac', 'fraction')
375
-            if (node%arg_count /= 1) then
376
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
533
+            if (.not. validate_function_args(node, 1, error)) then
377534
                 value = create_scalar(0.0_real64)
378535
                 return
379536
             end if
@@ -381,16 +538,14 @@ contains
381538
             
382539
         ! Complex functions
383540
         case ('real', 're', 'imag', 'im', 'conj', 'conjugate', 'arg', 'phase', 'angle', 'cabs', 'modulus')
384
-            if (node%arg_count /= 1) then
385
-                call set_eval_error(error, trim(node%function_name) // '() expects 1 argument')
541
+            if (.not. validate_function_args(node, 1, error)) then
386542
                 value = create_scalar(0.0_real64)
387543
                 return
388544
             end if
389545
             value = eval_complex_functions(node%function_name, args(1))
390546
             
391547
         case ('sqrt')
392
-            if (node%arg_count /= 1) then
393
-                call set_eval_error(error, 'sqrt() expects 1 argument')
548
+            if (.not. validate_function_args(node, 1, error)) then
394549
                 value = create_scalar(0.0_real64)
395550
                 return
396551
             end if
@@ -408,137 +563,19 @@ contains
408563
             end if
409564
             
410565
         case ('abs')
411
-            if (node%arg_count /= 1) then
412
-                call set_eval_error(error, 'abs() expects 1 argument')
566
+            if (.not. validate_function_args(node, 1, error)) then
413567
                 value = create_scalar(0.0_real64)
414568
                 return
415569
             end if
416570
             value = abs_value(args(1))
417571
             
418572
         ! Matrix creation functions
419
-        case ('zeros')
420
-            if (node%arg_count == 1) then
421
-                ! zeros(n) - square matrix
422
-                if (args(1)%value_type == VALUE_SCALAR) then
423
-                    if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then
424
-                        value = create_zeros_matrix(int(args(1)%scalar_val), int(args(1)%scalar_val))
425
-                    else
426
-                        call set_eval_error(error, 'zeros() size must be a positive integer')
427
-                        value = create_scalar(0.0_real64)
428
-                    end if
429
-                else
430
-                    call set_eval_error(error, 'zeros() expects numeric size argument')
431
-                    value = create_scalar(0.0_real64)
432
-                end if
433
-            else if (node%arg_count == 2) then
434
-                ! zeros(m,n) - rectangular matrix
435
-                if (args(1)%value_type == VALUE_SCALAR .and. args(2)%value_type == VALUE_SCALAR) then
436
-                    if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val) .and. &
437
-                        args(2)%scalar_val > 0 .and. args(2)%scalar_val == int(args(2)%scalar_val)) then
438
-                        value = create_zeros_matrix(int(args(1)%scalar_val), int(args(2)%scalar_val))
439
-                    else
440
-                        call set_eval_error(error, 'zeros() sizes must be positive integers')
441
-                        value = create_scalar(0.0_real64)
442
-                    end if
443
-                else
444
-                    call set_eval_error(error, 'zeros() expects numeric size arguments')
445
-                    value = create_scalar(0.0_real64)
446
-                end if
447
-            else
448
-                call set_eval_error(error, 'zeros() expects 1 or 2 arguments')
449
-                value = create_scalar(0.0_real64)
450
-            end if
451
-            
452
-        case ('ones')
453
-            if (node%arg_count == 1) then
454
-                ! ones(n) - square matrix
455
-                if (args(1)%value_type == VALUE_SCALAR) then
456
-                    if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then
457
-                        value = create_ones_matrix(int(args(1)%scalar_val), int(args(1)%scalar_val))
458
-                    else
459
-                        call set_eval_error(error, 'ones() size must be a positive integer')
460
-                        value = create_scalar(0.0_real64)
461
-                    end if
462
-                else
463
-                    call set_eval_error(error, 'ones() expects numeric size argument')
464
-                    value = create_scalar(0.0_real64)
465
-                end if
466
-            else if (node%arg_count == 2) then
467
-                ! ones(m,n) - rectangular matrix
468
-                if (args(1)%value_type == VALUE_SCALAR .and. args(2)%value_type == VALUE_SCALAR) then
469
-                    if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val) .and. &
470
-                        args(2)%scalar_val > 0 .and. args(2)%scalar_val == int(args(2)%scalar_val)) then
471
-                        value = create_ones_matrix(int(args(1)%scalar_val), int(args(2)%scalar_val))
472
-                    else
473
-                        call set_eval_error(error, 'ones() sizes must be positive integers')
474
-                        value = create_scalar(0.0_real64)
475
-                    end if
476
-                else
477
-                    call set_eval_error(error, 'ones() expects numeric size arguments')
478
-                    value = create_scalar(0.0_real64)
479
-                end if
480
-            else
481
-                call set_eval_error(error, 'ones() expects 1 or 2 arguments')
482
-                value = create_scalar(0.0_real64)
483
-            end if
484
-            
485
-        case ('eye')
486
-            if (node%arg_count /= 1) then
487
-                call set_eval_error(error, 'eye() expects 1 argument')
488
-                value = create_scalar(0.0_real64)
489
-                return
490
-            end if
491
-            if (args(1)%value_type == VALUE_SCALAR) then
492
-                if (args(1)%scalar_val > 0 .and. args(1)%scalar_val == int(args(1)%scalar_val)) then
493
-                    value = create_eye_matrix(int(args(1)%scalar_val))
494
-                else
495
-                    call set_eval_error(error, 'eye() size must be a positive integer')
496
-                    value = create_scalar(0.0_real64)
497
-                end if
498
-            else
499
-                call set_eval_error(error, 'eye() expects numeric size argument')
500
-                value = create_scalar(0.0_real64)
501
-            end if
573
+        case ('zeros', 'ones', 'eye')
574
+            value = eval_matrix_creation(node%function_name, node, args, error)
502575
             
503576
         ! Matrix functions
504
-        case ('transpose', 'trans')
505
-            if (node%arg_count /= 1) then
506
-                call set_eval_error(error, 'transpose() expects 1 argument')
507
-                value = create_scalar(0.0_real64)
508
-                return
509
-            end if
510
-            if (args(1)%value_type == VALUE_MATRIX) then
511
-                value = matrix_transpose(args(1))
512
-            else
513
-                call set_eval_error(error, 'transpose() expects a matrix argument')
514
-                value = create_scalar(0.0_real64)
515
-            end if
516
-            
517
-        case ('det', 'determinant')
518
-            if (node%arg_count /= 1) then
519
-                call set_eval_error(error, 'det() expects 1 argument')
520
-                value = create_scalar(0.0_real64)
521
-                return
522
-            end if
523
-            if (args(1)%value_type == VALUE_MATRIX) then
524
-                value = create_scalar(matrix_determinant(args(1)))
525
-            else
526
-                call set_eval_error(error, 'det() expects a matrix argument')
527
-                value = create_scalar(0.0_real64)
528
-            end if
529
-            
530
-        case ('inv', 'inverse')
531
-            if (node%arg_count /= 1) then
532
-                call set_eval_error(error, 'inv() expects 1 argument')
533
-                value = create_scalar(0.0_real64)
534
-                return
535
-            end if
536
-            if (args(1)%value_type == VALUE_MATRIX) then
537
-                value = matrix_inverse(args(1))
538
-            else
539
-                call set_eval_error(error, 'inv() expects a matrix argument')
540
-                value = create_scalar(0.0_real64)
541
-            end if
577
+        case ('transpose', 'trans', 'det', 'determinant', 'inv', 'inverse')
578
+            value = eval_matrix_operation(node%function_name, node, args, error)
542579
             
543580
         case ('solve')
544581
             ! Solve linear system Ax = b
src/fortbite_types_m.f90modified
@@ -35,6 +35,7 @@ module fortbite_types_m
3535
     end enum
3636
     integer, parameter :: value_type_enum = kind(VALUE_UNDEFINED)
3737
     
38
+    
3839
     !> Enumeration for token types
3940
     enum, bind(c)
4041
         enumerator :: TOKEN_EOF = 0
@@ -150,11 +151,11 @@ contains
150151
         value%complex_matrix_val = matrix_data
151152
     end function create_complex_matrix
152153
     
153
-    !> Create a zeros matrix
154
-    function create_zeros_matrix(rows, cols, precision_kind) result(value)
154
+    !> Internal helper to set up basic matrix properties
155
+    subroutine setup_matrix_base(value, rows, cols, precision_kind)
156
+        type(value_t), intent(inout) :: value
155157
         integer, intent(in) :: rows, cols
156158
         integer, intent(in), optional :: precision_kind
157
-        type(value_t) :: value
158159
         
159160
         value%value_type = VALUE_MATRIX
160161
         value%precision_kind = real64
@@ -165,6 +166,15 @@ contains
165166
         value%is_complex_matrix = .false.
166167
         
167168
         allocate(value%matrix_val(rows, cols))
169
+    end subroutine setup_matrix_base
170
+    
171
+    !> Create a zeros matrix
172
+    function create_zeros_matrix(rows, cols, precision_kind) result(value)
173
+        integer, intent(in) :: rows, cols
174
+        integer, intent(in), optional :: precision_kind
175
+        type(value_t) :: value
176
+        
177
+        call setup_matrix_base(value, rows, cols, precision_kind)
168178
         value%matrix_val = 0.0_real64
169179
     end function create_zeros_matrix
170180
     
@@ -174,15 +184,7 @@ contains
174184
         integer, intent(in), optional :: precision_kind
175185
         type(value_t) :: value
176186
         
177
-        value%value_type = VALUE_MATRIX
178
-        value%precision_kind = real64
179
-        if (present(precision_kind)) value%precision_kind = precision_kind
180
-        
181
-        value%rows = rows
182
-        value%cols = cols
183
-        value%is_complex_matrix = .false.
184
-        
185
-        allocate(value%matrix_val(rows, cols))
187
+        call setup_matrix_base(value, rows, cols, precision_kind)
186188
         value%matrix_val = 1.0_real64
187189
     end function create_ones_matrix
188190
     
@@ -191,17 +193,10 @@ contains
191193
         integer, intent(in) :: size
192194
         integer, intent(in), optional :: precision_kind
193195
         type(value_t) :: value
194
-        integer :: i
195
-        
196
-        value%value_type = VALUE_MATRIX
197
-        value%precision_kind = real64
198
-        if (present(precision_kind)) value%precision_kind = precision_kind
199196
         
200
-        value%rows = size
201
-        value%cols = size
202
-        value%is_complex_matrix = .false.
197
+        integer :: i
203198
         
204
-        allocate(value%matrix_val(size, size))
199
+        call setup_matrix_base(value, size, size, precision_kind)
205200
         value%matrix_val = 0.0_real64
206201
         
207202
         ! Set diagonal elements to 1
@@ -215,19 +210,11 @@ contains
215210
         real(real64), intent(in) :: diagonal_elements(:)
216211
         integer, intent(in), optional :: precision_kind
217212
         type(value_t) :: value
218
-        integer :: i, n
219213
         
214
+        integer :: n, i
220215
         n = size(diagonal_elements)
221216
         
222
-        value%value_type = VALUE_MATRIX
223
-        value%precision_kind = real64
224
-        if (present(precision_kind)) value%precision_kind = precision_kind
225
-        
226
-        value%rows = n
227
-        value%cols = n
228
-        value%is_complex_matrix = .false.
229
-        
230
-        allocate(value%matrix_val(n, n))
217
+        call setup_matrix_base(value, n, n, precision_kind)
231218
         value%matrix_val = 0.0_real64
232219
         
233220
         ! Set diagonal elements