! Fortran · 11559 bytes Raw Blame History  (code-viewer page header, not part of the program)
!> Abstract Syntax Tree module for FORTBITE
!>
!> Defines the AST node types and operations for representing parsed
!> mathematical expressions in tree form.
!>
!> Every node is created on the heap by one of the `create_*_node`
!> functions and is released, together with everything reachable from it,
!> by `free_ast`. A child node handed to a constructor is referenced (not
!> copied) by the new parent, except for `create_function_node`, which
!> copies the top-level argument nodes (see that routine).
module fortbite_ast_m
    use fortbite_types_m, only: value_t
    use, intrinsic :: iso_fortran_env, only: real64
    implicit none
    private

    public :: ast_node_t, ast_node_type_enum, operator_type_enum, ast_node_ptr_t
    public :: AST_LITERAL, AST_IDENTIFIER, AST_BINARY_OP, AST_UNARY_OP
    public :: AST_FUNCTION_CALL, AST_ASSIGNMENT, AST_PRECISION_SPEC, AST_MATRIX_LITERAL
    public :: OP_ADD, OP_SUB, OP_MUL, OP_DIV, OP_POW, OP_MOD
    public :: OP_UNARY_PLUS, OP_UNARY_MINUS
    public :: create_literal_node, create_identifier_node, create_binary_node
    public :: create_unary_node, create_function_node, create_assignment_node
    public :: create_precision_node, create_matrix_literal_node, free_ast, print_ast

    !> AST node types
    enum, bind(c)
        enumerator :: AST_LITERAL = 1
        enumerator :: AST_IDENTIFIER = 2
        enumerator :: AST_BINARY_OP = 3
        enumerator :: AST_UNARY_OP = 4
        enumerator :: AST_FUNCTION_CALL = 5
        enumerator :: AST_ASSIGNMENT = 6
        enumerator :: AST_PRECISION_SPEC = 7
        enumerator :: AST_MATRIX_LITERAL = 8
    end enum
    !> Integer kind to use when declaring a variable holding an AST_* value
    integer, parameter :: ast_node_type_enum = kind(AST_LITERAL)

    !> Operator types
    enum, bind(c)
        enumerator :: OP_ADD = 1
        enumerator :: OP_SUB = 2
        enumerator :: OP_MUL = 3
        enumerator :: OP_DIV = 4
        enumerator :: OP_POW = 5
        enumerator :: OP_MOD = 6
        enumerator :: OP_UNARY_PLUS = 7
        enumerator :: OP_UNARY_MINUS = 8
    end enum
    !> Integer kind to use when declaring a variable holding an OP_* value
    integer, parameter :: operator_type_enum = kind(OP_ADD)

    !> Pointer wrapper for AST nodes in arrays
    !> (Fortran has no arrays of pointers, only arrays of types holding one)
    type :: ast_node_ptr_t
        type(ast_node_t), pointer :: ptr => null()
    end type ast_node_ptr_t

    !> AST node type - can represent any expression component
    !>
    !> Only the components relevant to `node_type` are meaningful; the
    !> others keep their default values.
    type :: ast_node_t
        ! 0 matches none of the AST_* / OP_* enumerators, so a node whose
        ! fields were never set is reported as unknown rather than garbage.
        integer(ast_node_type_enum) :: node_type = 0

        ! Node-specific data
        type(value_t) :: literal_value                      ! For AST_LITERAL
        character(len=:), allocatable :: identifier         ! For AST_IDENTIFIER
        integer(operator_type_enum) :: operator = 0         ! For AST_BINARY_OP, AST_UNARY_OP
        character(len=:), allocatable :: function_name      ! For AST_FUNCTION_CALL
        integer :: precision_digits = 0                     ! For AST_PRECISION_SPEC

        ! Child node pointers
        type(ast_node_t), pointer :: left => null()         ! Left operand / assignment target
        type(ast_node_t), pointer :: right => null()        ! Right operand / assigned expression
        type(ast_node_t), pointer :: operand => null()      ! For unary operations
        type(ast_node_t), pointer :: expression => null()   ! For precision specs

        ! Function arguments (array of pointer wrappers)
        type(ast_node_ptr_t), allocatable :: arguments(:)
        integer :: arg_count = 0

        ! Matrix literal data
        real(real64), allocatable :: matrix_elements(:,:)
        integer :: matrix_rows = 0, matrix_cols = 0
    end type ast_node_t

contains

    !> Create a literal value node
    !>
    !> `value` is copied into the new node. The result is a newly
    !> allocated node of type AST_LITERAL.
    function create_literal_node(value) result(node)
        type(value_t), intent(in) :: value
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_LITERAL
        node%literal_value = value
    end function create_literal_node

    !> Create an identifier node
    !>
    !> Trailing blanks of `name` are dropped before it is stored.
    function create_identifier_node(name) result(node)
        character(len=*), intent(in) :: name
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_IDENTIFIER
        node%identifier = trim(name)
    end function create_identifier_node

    !> Create a binary operation node
    !>
    !> The new node references `left_node` and `right_node` directly; no
    !> copy is made.
    function create_binary_node(op, left_node, right_node) result(node)
        integer(operator_type_enum), intent(in) :: op
        type(ast_node_t), pointer, intent(in) :: left_node, right_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_BINARY_OP
        node%operator = op
        node%left => left_node
        node%right => right_node
    end function create_binary_node

    !> Create a unary operation node
    !>
    !> The new node references `operand_node` directly; no copy is made.
    function create_unary_node(op, operand_node) result(node)
        integer(operator_type_enum), intent(in) :: op
        type(ast_node_t), pointer, intent(in) :: operand_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_UNARY_OP
        node%operator = op
        node%operand => operand_node
    end function create_unary_node

    !> Create a function call node
    !>
    !> `func_name` is stored without trailing blanks. When `args` is
    !> present and associated, each element is copied into its own freshly
    !> allocated node. The copy is an intrinsic assignment: allocatable
    !> components are duplicated, child pointers are shared. The caller
    !> therefore keeps ownership of the `args` array itself, while the
    !> subtrees hanging off its elements now belong to the returned node
    !> and are released by `free_ast`.
    function create_function_node(func_name, args) result(node)
        character(len=*), intent(in) :: func_name
        type(ast_node_t), pointer, intent(in), optional :: args(:)
        type(ast_node_t), pointer :: node
        integer :: i

        allocate(node)
        node%node_type = AST_FUNCTION_CALL
        node%function_name = trim(func_name)
        node%arg_count = 0

        if (.not. present(args)) return
        ! size() of a disassociated pointer is undefined; treat it as "no arguments".
        if (.not. associated(args)) return

        node%arg_count = size(args)
        allocate(node%arguments(node%arg_count))
        do i = 1, node%arg_count
            ! Each argument needs its own allocation: free_ast deallocates
            ! every argument individually, which is only valid for a pointer
            ! to a whole allocated object, never for one element of an array.
            allocate(node%arguments(i)%ptr)
            node%arguments(i)%ptr = args(i)
        end do
    end function create_function_node

    !> Create an assignment node
    !>
    !> The left child is a new identifier node built from `var_name`; the
    !> right child references `expression_node` directly.
    function create_assignment_node(var_name, expression_node) result(node)
        character(len=*), intent(in) :: var_name
        type(ast_node_t), pointer, intent(in) :: expression_node
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_ASSIGNMENT
        node%operator = OP_ADD  ! Dummy value, not used for assignments

        node%left => create_identifier_node(var_name)
        node%right => expression_node
    end function create_assignment_node

    !> Create a precision specification node
    !>
    !> Stores `precision` in `precision_digits` and references `expr_node`
    !> directly.
    function create_precision_node(expr_node, precision) result(node)
        type(ast_node_t), pointer, intent(in) :: expr_node
        integer, intent(in) :: precision
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_PRECISION_SPEC
        node%precision_digits = precision
        node%expression => expr_node
    end function create_precision_node

    !> Create a matrix literal node
    !>
    !> `elements` is copied into the node; `rows` and `cols` are recorded
    !> as given.
    !> NOTE(review): nothing checks that shape(elements) equals
    !> [rows, cols]; callers are presumably expected to pass matching
    !> values - confirm.
    function create_matrix_literal_node(elements, rows, cols) result(node)
        real(real64), intent(in) :: elements(:,:)
        integer, intent(in) :: rows, cols
        type(ast_node_t), pointer :: node

        allocate(node)
        node%node_type = AST_MATRIX_LITERAL
        node%matrix_rows = rows
        node%matrix_cols = cols

        allocate(node%matrix_elements(rows, cols))
        node%matrix_elements = elements
    end function create_matrix_literal_node

    !> Free AST and all child nodes
    !>
    !> Releases every node reachable from `node` through the child
    !> pointers and the argument list, then `node` itself. On return
    !> `node` is disassociated. A disassociated `node` is a no-op.
    !>
    !> Each subtree must be reachable through exactly one pointer: a node
    !> shared between two parents would be deallocated twice.
    recursive subroutine free_ast(node)
        type(ast_node_t), pointer, intent(inout) :: node
        integer :: i

        if (.not. associated(node)) return

        ! The recursive call ignores disassociated pointers and leaves the
        ! pointer it was given disassociated, so no extra checks or
        ! nullify statements are needed around these calls.
        call free_ast(node%left)
        call free_ast(node%right)
        call free_ast(node%operand)
        call free_ast(node%expression)

        ! Walk the real extent of the array rather than arg_count so a
        ! stale count can neither overrun the array nor leave nodes behind.
        if (allocated(node%arguments)) then
            do i = 1, size(node%arguments)
                call free_ast(node%arguments(i)%ptr)
            end do
        end if

        ! Deallocating the node also releases its allocatable components
        ! (strings, argument wrappers, matrix data) and disassociates it.
        deallocate(node)
    end subroutine free_ast

    !> Print AST for debugging (recursive)
    !>
    !> Writes one line per node to standard output. `indent` is the
    !> number of leading blanks for this node (default 0, negative values
    !> are treated as 0); children are printed two columns deeper.
    recursive subroutine print_ast(node, indent)
        type(ast_node_t), pointer, intent(in) :: node
        integer, intent(in), optional :: indent

        integer, parameter :: indent_step = 2
        character(len=:), allocatable :: pad
        integer :: depth, child_depth, i

        if (.not. associated(node)) return

        depth = 0
        if (present(indent)) depth = max(indent, 0)
        child_depth = depth + indent_step

        ! The padding is written as-is. It must not go through trim():
        ! trimming an all-blank string yields an empty string and the
        ! whole indentation would vanish.
        pad = repeat(' ', depth)

        select case (node%node_type)
        case (AST_LITERAL)
            write(*, '(A,A)') pad, 'LITERAL: [value]'

        case (AST_IDENTIFIER)
            write(*, '(A,A,A)') pad, 'IDENTIFIER: ', node%identifier

        case (AST_BINARY_OP)
            write(*, '(A,A,A)') pad, 'BINARY_OP: ', get_operator_name(node%operator)
            call print_ast(node%left, child_depth)
            call print_ast(node%right, child_depth)

        case (AST_UNARY_OP)
            write(*, '(A,A,A)') pad, 'UNARY_OP: ', get_operator_name(node%operator)
            call print_ast(node%operand, child_depth)

        case (AST_FUNCTION_CALL)
            write(*, '(A,A,A,A,I0,A)') pad, 'FUNCTION: ', node%function_name, &
                ' (', node%arg_count, ' args)'
            if (allocated(node%arguments)) then
                ! Bound by the array size as well, in case arg_count is stale.
                do i = 1, min(node%arg_count, size(node%arguments))
                    call print_ast(node%arguments(i)%ptr, child_depth)
                end do
            end if

        case (AST_ASSIGNMENT)
            write(*, '(A,A)') pad, 'ASSIGNMENT:'
            call print_ast(node%left, child_depth)
            call print_ast(node%right, child_depth)

        case (AST_PRECISION_SPEC)
            write(*, '(A,A,I0)') pad, 'PRECISION: ', node%precision_digits
            call print_ast(node%expression, child_depth)

        case (AST_MATRIX_LITERAL)
            write(*, '(A,A,I0,A,I0,A)') pad, 'MATRIX_LITERAL: [', &
                node%matrix_rows, 'x', node%matrix_cols, ']'

        case default
            write(*, '(A,A)') pad, 'UNKNOWN NODE'
        end select
    end subroutine print_ast

    !> Get human-readable operator name
    !>
    !> Returns the symbol for `op` with no padding, or 'unknown' when
    !> `op` is none of the OP_* enumerators.
    pure function get_operator_name(op) result(name)
        integer(operator_type_enum), intent(in) :: op
        character(len=:), allocatable :: name

        select case (op)
        case (OP_ADD)
            name = '+'
        case (OP_SUB)
            name = '-'
        case (OP_MUL)
            name = '*'
        case (OP_DIV)
            name = '/'
        case (OP_POW)
            name = '**'
        case (OP_MOD)
            name = 'mod'
        case (OP_UNARY_PLUS)
            name = 'unary +'
        case (OP_UNARY_MINUS)
            name = 'unary -'
        case default
            name = 'unknown'
        end select
    end function get_operator_name

end module fortbite_ast_m