!> Abstract Syntax Tree module for FORTBITE
!>
!> Defines the AST node types and operations for representing parsed
!> mathematical expressions in tree form.
!>
!> Every node is heap-allocated by one of the `create_*` constructors and
!> owns the child nodes it points to: calling `free_ast` on a root releases
!> the whole tree, so a subtree must not be attached to more than one parent.
module fortbite_ast_m
    use fortbite_types_m, only: value_t
    use iso_fortran_env, only: real64
    implicit none
    private

    public :: ast_node_t, ast_node_type_enum, operator_type_enum, ast_node_ptr_t
    public :: AST_LITERAL, AST_IDENTIFIER, AST_BINARY_OP, AST_UNARY_OP
    public :: AST_FUNCTION_CALL, AST_ASSIGNMENT, AST_PRECISION_SPEC, AST_MATRIX_LITERAL
    public :: OP_ADD, OP_SUB, OP_MUL, OP_DIV, OP_POW, OP_MOD
    public :: OP_UNARY_PLUS, OP_UNARY_MINUS
    public :: create_literal_node, create_identifier_node, create_binary_node
    public :: create_unary_node, create_function_node, create_assignment_node
    public :: create_precision_node, create_matrix_literal_node, free_ast, print_ast

    !> AST node types
    enum, bind(c)
        enumerator :: AST_LITERAL = 1
        enumerator :: AST_IDENTIFIER = 2
        enumerator :: AST_BINARY_OP = 3
        enumerator :: AST_UNARY_OP = 4
        enumerator :: AST_FUNCTION_CALL = 5
        enumerator :: AST_ASSIGNMENT = 6
        enumerator :: AST_PRECISION_SPEC = 7
        enumerator :: AST_MATRIX_LITERAL = 8
    end enum
    !> Integer kind used to store an AST node type code.
    integer, parameter :: ast_node_type_enum = kind(AST_LITERAL)

    !> Operator types
    enum, bind(c)
        enumerator :: OP_ADD = 1
        enumerator :: OP_SUB = 2
        enumerator :: OP_MUL = 3
        enumerator :: OP_DIV = 4
        enumerator :: OP_POW = 5
        enumerator :: OP_MOD = 6
        enumerator :: OP_UNARY_PLUS = 7
        enumerator :: OP_UNARY_MINUS = 8
    end enum
    !> Integer kind used to store an operator code.
    integer, parameter :: operator_type_enum = kind(OP_ADD)

    !> Pointer wrapper for AST nodes in arrays.
    !> Fortran has no array-of-pointers, so function arguments are stored as an
    !> array of these wrappers.
    type :: ast_node_ptr_t
        type(ast_node_t), pointer :: ptr => null()
    end type ast_node_ptr_t

    !> AST node type - can represent any expression component.
    !> Only the components relevant to `node_type` are meaningful. Scalar
    !> components are default-initialized so that unused ones are never undefined.
    type :: ast_node_t
        integer(ast_node_type_enum) :: node_type = 0      ! One of the AST_* values (0 = unset)

        ! Node-specific data
        type(value_t) :: literal_value                    ! For AST_LITERAL
        character(len=:), allocatable :: identifier       ! For AST_IDENTIFIER
        integer(operator_type_enum) :: operator = 0       ! For AST_BINARY_OP, AST_UNARY_OP
        character(len=:), allocatable :: function_name    ! For AST_FUNCTION_CALL
        integer :: precision_digits = 0                   ! For AST_PRECISION_SPEC

        ! Child node pointers (owned by this node and released by free_ast)
        type(ast_node_t), pointer :: left => null()       ! Left operand / assignment target
        type(ast_node_t), pointer :: right => null()      ! Right operand / assigned expression
        type(ast_node_t), pointer :: operand => null()    ! For unary operations
        type(ast_node_t), pointer :: expression => null() ! For precision specs

        ! Function arguments (array of pointer wrappers, each individually allocated)
        type(ast_node_ptr_t), allocatable :: arguments(:)
        integer :: arg_count = 0

        ! Matrix literal data
        real(real64), allocatable :: matrix_elements(:,:)
        integer :: matrix_rows = 0, matrix_cols = 0
    end type ast_node_t

contains

    !> Create a literal value node holding a copy of `value`.
    function create_literal_node(value) result(node)
        type(value_t), intent(in) :: value
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_LITERAL
        node%literal_value = value
    end function create_literal_node

    !> Create an identifier node; trailing blanks of `name` are dropped.
    function create_identifier_node(name) result(node)
        character(len=*), intent(in) :: name
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_IDENTIFIER
        node%identifier = trim(name)
    end function create_identifier_node

    !> Create a binary operation node `left_node <op> right_node`.
    !> The new node takes ownership of both operand subtrees.
    function create_binary_node(op, left_node, right_node) result(node)
        integer(operator_type_enum), intent(in) :: op
        type(ast_node_t), pointer, intent(in) :: left_node, right_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_BINARY_OP
        node%operator = op
        node%left => left_node
        node%right => right_node
    end function create_binary_node

    !> Create a unary operation node `<op> operand_node`.
    !> The new node takes ownership of the operand subtree.
    function create_unary_node(op, operand_node) result(node)
        integer(operator_type_enum), intent(in) :: op
        type(ast_node_t), pointer, intent(in) :: operand_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_UNARY_OP
        node%operator = op
        node%operand => operand_node
    end function create_unary_node

    !> Create a function call node.
    !>
    !> @param func_name  function name (trailing blanks dropped)
    !> @param args       optional array of argument nodes
    !>
    !> Each element of `args` is copied into its own allocation. This lets
    !> `free_ast` release the arguments one by one, and the call node stays valid
    !> if the caller later deallocates the `args` array itself. The copy is
    !> shallow for pointer components, so the children of each `args(i)` become
    !> owned by the call node.
    function create_function_node(func_name, args) result(node)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in), optional :: args(:)
        type(ast_node_t), pointer :: node
        integer :: i, first

        allocate(node)
        node%node_type = AST_FUNCTION_CALL
        node%function_name = trim(func_name)
        node%arg_count = 0

        ! Fortran does not short-circuit, so present/associated are tested separately.
        if (present(args)) then
            if (associated(args)) then
                node%arg_count = size(args)
                allocate(node%arguments(node%arg_count))
                ! Pointer arrays keep their declared bounds; do not assume they start at 1.
                first = lbound(args, 1)
                do i = 1, node%arg_count
                    allocate(node%arguments(i)%ptr, source=args(first + i - 1))
                end do
            end if
        end if
    end function create_function_node

    !> Create an assignment node `var_name = expression_node`.
    !> `left` holds a fresh identifier node for the variable and `right` holds
    !> the expression, whose ownership passes to the new node.
    function create_assignment_node(var_name, expression_node) result(node)
        character(len=*), intent(in) :: var_name
        type(ast_node_t), pointer, intent(in) :: expression_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_ASSIGNMENT
        node%operator = OP_ADD  ! Dummy value, not used for assignments

        node%left => create_identifier_node(var_name)
        node%right => expression_node
    end function create_assignment_node

    !> Create a precision specification node wrapping `expr_node`.
    function create_precision_node(expr_node, precision) result(node)
        type(ast_node_t), pointer, intent(in) :: expr_node
        integer, intent(in) :: precision
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_PRECISION_SPEC
        node%precision_digits = precision
        node%expression => expr_node
    end function create_precision_node

    !> Create a matrix literal node holding a copy of `elements`.
    !> NOTE(review): `rows`/`cols` are stored as given and are expected to equal
    !> `shape(elements)`. The stored array always takes the shape of `elements`.
    function create_matrix_literal_node(elements, rows, cols) result(node)
        real(real64), intent(in) :: elements(:,:)
        integer, intent(in) :: rows, cols
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_MATRIX_LITERAL
        node%matrix_rows = rows
        node%matrix_cols = cols
        allocate(node%matrix_elements, source=elements)
    end function create_matrix_literal_node

    !> Free an AST and all of its child nodes; `node` is disassociated on return.
    !> Safe to call with a disassociated pointer (no-op).
    recursive subroutine free_ast(node)
        type(ast_node_t), pointer, intent(inout) :: node
        integer :: i

        if (.not. associated(node)) return

        ! Each recursive call disassociates the pointer it is given, so these
        ! child pointers are left null and cannot be freed twice.
        call free_ast(node%left)
        call free_ast(node%right)
        call free_ast(node%operand)
        call free_ast(node%expression)

        if (allocated(node%arguments)) then
            do i = 1, size(node%arguments)
                call free_ast(node%arguments(i)%ptr)
            end do
        end if

        ! Deallocating the node also deallocates its allocatable components
        ! (identifier, function_name, arguments, matrix_elements) and
        ! disassociates `node`.
        deallocate(node)
    end subroutine free_ast

    !> Print AST for debugging (recursive).
    !> Each nesting level is indented two further spaces than its parent.
    recursive subroutine print_ast(node, indent)
        type(ast_node_t), pointer, intent(in) :: node
        integer, intent(in), optional :: indent

        integer :: ind, i
        ! Deferred length: trim() on a blank-only fixed-length buffer would
        ! collapse the indentation to an empty string.
        character(len=:), allocatable :: spaces

        if (.not. associated(node)) return

        ind = 0
        if (present(indent)) ind = max(indent, 0)
        spaces = repeat(' ', ind)

        select case (node%node_type)
        case (AST_LITERAL)
            write(*, '(A,A)') spaces, 'LITERAL: [value]'

        case (AST_IDENTIFIER)
            write(*, '(A,A,A)') spaces, 'IDENTIFIER: ', node%identifier

        case (AST_BINARY_OP)
            write(*, '(A,A,A)') spaces, 'BINARY_OP: ', get_operator_name(node%operator)
            call print_ast(node%left, ind + 2)
            call print_ast(node%right, ind + 2)

        case (AST_UNARY_OP)
            write(*, '(A,A,A)') spaces, 'UNARY_OP: ', get_operator_name(node%operator)
            call print_ast(node%operand, ind + 2)

        case (AST_FUNCTION_CALL)
            write(*, '(A,A,A,A,I0,A)') spaces, 'FUNCTION: ', node%function_name, &
                ' (', node%arg_count, ' args)'
            if (allocated(node%arguments)) then
                do i = 1, size(node%arguments)
                    call print_ast(node%arguments(i)%ptr, ind + 2)
                end do
            end if

        case (AST_ASSIGNMENT)
            write(*, '(A,A)') spaces, 'ASSIGNMENT:'
            call print_ast(node%left, ind + 2)
            call print_ast(node%right, ind + 2)

        case (AST_PRECISION_SPEC)
            write(*, '(A,A,I0)') spaces, 'PRECISION: ', node%precision_digits
            call print_ast(node%expression, ind + 2)

        case (AST_MATRIX_LITERAL)
            write(*, '(A,A,I0,A,I0,A)') spaces, 'MATRIX_LITERAL: [', &
                node%matrix_rows, 'x', node%matrix_cols, ']'

        case default
            write(*, '(A,A)') spaces, 'UNKNOWN NODE'
        end select
    end subroutine print_ast

    !> Get the human-readable operator name, with no blank padding.
    pure function get_operator_name(op) result(name)
        integer(operator_type_enum), intent(in) :: op
        character(len=:), allocatable :: name

        select case (op)
        case (OP_ADD)
            name = '+'
        case (OP_SUB)
            name = '-'
        case (OP_MUL)
            name = '*'
        case (OP_DIV)
            name = '/'
        case (OP_POW)
            name = '**'
        case (OP_MOD)
            name = 'mod'
        case (OP_UNARY_PLUS)
            name = 'unary +'
        case (OP_UNARY_MINUS)
            name = 'unary -'
        case default
            name = 'unknown'
        end select
    end function get_operator_name

end module fortbite_ast_m