From d02eef2199c8fb23a4bb412781f6da36bb34f725 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Sun, 26 Oct 2025 18:24:00 +0100 Subject: [PATCH] fix: refactor codegen grouping helpers (fixes #1948) --- .../codegen_character_normalization.f90 | 114 ++++--- src/codegen/codegen_declaration_grouping.f90 | 192 ++++++++---- src/codegen/codegen_grouped_body.f90 | 291 ++++++++++++------ 3 files changed, 401 insertions(+), 196 deletions(-) diff --git a/src/codegen/codegen_character_normalization.f90 b/src/codegen/codegen_character_normalization.f90 index 16507822..2513441c 100644 --- a/src/codegen/codegen_character_normalization.f90 +++ b/src/codegen/codegen_character_normalization.f90 @@ -11,16 +11,81 @@ module codegen_character_normalization contains + pure subroutine try_extract_length_from_star(trimmed_str, open_paren, has_length, & + length_spec) + character(len=*), intent(in) :: trimmed_str + integer, intent(in) :: open_paren + logical, intent(inout) :: has_length + character(len=:), allocatable, intent(inout) :: length_spec + integer :: star_pos + integer :: trimmed_len + character(len=:), allocatable :: candidate + + if (has_length) return + + star_pos = index(trimmed_str, "*") + if (star_pos <= 0) return + if (open_paren /= 0) then + if (star_pos > open_paren) return + end if + + trimmed_len = len_trim(trimmed_str) + if (star_pos >= trimmed_len) return + + candidate = trim(trimmed_str(star_pos + 1:trimmed_len)) + if (len_trim(candidate) == 0) return + + length_spec = candidate + has_length = .true. + end subroutine try_extract_length_from_star + + pure subroutine try_extract_length_from_parentheses(trimmed_str, open_paren, & + has_length, length_spec) + character(len=*), intent(in) :: trimmed_str + integer, intent(in) :: open_paren + logical, intent(inout) :: has_length + character(len=:), allocatable, intent(inout) :: length_spec + integer :: close_paren + integer :: depth + integer :: last_char + integer :: idx + character(len=:), allocatable :: candidate + + if (has_length) return + if (open_paren <= 0) return + + depth = 0 + close_paren = 0 + last_char = len_trim(trimmed_str) + + do idx = open_paren + 1, last_char + select case (trimmed_str(idx:idx)) + case ("(") + depth = depth + 1 + case (")") + if (depth == 0) then + close_paren = idx + exit + else + depth = depth - 1 + end if + end select + end do + + if (close_paren <= open_paren + 1) return + + candidate = trim(trimmed_str(open_paren + 1:close_paren - 1)) + if (len_trim(candidate) == 0) return + + length_spec = candidate + has_length = .true. + end subroutine try_extract_length_from_parentheses + subroutine extract_character_length(type_str, has_length, length_spec) character(len=*), intent(in) :: type_str logical, intent(out) :: has_length character(len=:), allocatable, intent(out) :: length_spec - integer :: star_pos integer :: open_paren - integer :: close_paren - integer :: depth - integer :: i - integer :: last_char character(len=:), allocatable :: trimmed_str has_length = .false. @@ -29,41 +94,12 @@ subroutine extract_character_length(type_str, has_length, length_spec) trimmed_str = trim(type_str) open_paren = index(trimmed_str, "(") - star_pos = index(trimmed_str, "*") - if (star_pos > 0) then - if (open_paren == 0 .or. star_pos < open_paren) then - if (star_pos < len_trim(trimmed_str)) then - length_spec = trim(trimmed_str(star_pos + 1:)) - if (len_trim(length_spec) > 0) then - has_length = .true. - return - end if - end if - end if - end if + call try_extract_length_from_star(trimmed_str, open_paren, has_length, & + length_spec) + if (has_length) return - if (open_paren > 0) then - depth = 0 - close_paren = 0 - last_char = len_trim(trimmed_str) - do i = open_paren + 1, last_char - select case (trimmed_str(i:i)) - case ("(") - depth = depth + 1 - case (")") - if (depth == 0) then - close_paren = i - exit - else - depth = depth - 1 - end if - end select - end do - if (close_paren > open_paren + 1) then - length_spec = trim(trimmed_str(open_paren + 1:close_paren - 1)) - if (len_trim(length_spec) > 0) has_length = .true. - end if - end if + call try_extract_length_from_parentheses(trimmed_str, open_paren, & + has_length, length_spec) end subroutine extract_character_length subroutine preprocess_character_type(raw_type, trimmed, has_length, length_spec, & diff --git a/src/codegen/codegen_declaration_grouping.f90 b/src/codegen/codegen_declaration_grouping.f90 index a529c687..29bd638c 100644 --- a/src/codegen/codegen_declaration_grouping.f90 +++ b/src/codegen/codegen_declaration_grouping.f90 @@ -18,58 +18,119 @@ module codegen_declaration_grouping contains - function can_group_declarations(node1, node2) result(can_group) + pure logical function declarations_share_basic_flags(node1, node2) result(match) type(declaration_node), intent(in) :: node1 type(declaration_node), intent(in) :: node2 - logical :: can_group - logical :: types_match - if (node1%initializer_index > 0 .or. node2%initializer_index > 0) then - can_group = .false. + if (node1%initializer_index > 0) then + match = .false. + return + end if + if (node2%initializer_index > 0) then + match = .false. return end if - if (node1%is_array .or. node2%is_array) then - can_group = .false. + match = .false. return end if - if (node1%is_allocatable .neqv. node2%is_allocatable) then - can_group = .false. + match = .false. return end if if (node1%is_pointer .neqv. node2%is_pointer) then - can_group = .false. + match = .false. return end if if (node1%is_target .neqv. node2%is_target) then - can_group = .false. + match = .false. return end if if (node1%is_external .neqv. node2%is_external) then - can_group = .false. + match = .false. return end if if (node1%is_parameter .neqv. node2%is_parameter) then - can_group = .false. + match = .false. + return + end if + + match = .true. + end function declarations_share_basic_flags + + pure logical function declarations_have_matching_types(node1, node2) result(match) + type(declaration_node), intent(in) :: node1 + type(declaration_node), intent(in) :: node2 + logical :: both_have_names + + both_have_names = len_trim(node1%type_name) > 0 .and. & + len_trim(node2%type_name) > 0 + if (both_have_names) then + match = trim(node1%type_name) == trim(node2%type_name) return end if - if (len_trim(node1%type_name) > 0 .and. len_trim(node2%type_name) > 0) then - types_match = trim(node1%type_name) == trim(node2%type_name) - else if (node1%inferred_type%kind > 0 .and. node2%inferred_type%kind > 0) then - types_match = node1%inferred_type%kind == node2%inferred_type%kind + if (node1%inferred_type%kind > 0 .and. node2%inferred_type%kind > 0) then + match = node1%inferred_type%kind == node2%inferred_type%kind else - types_match = .false. + match = .false. + end if + end function declarations_have_matching_types + + pure logical function declarations_match_attributes(node1, node2) result(match) + type(declaration_node), intent(in) :: node1 + type(declaration_node), intent(in) :: node2 + logical :: intents_match + + if (node1%kind_value /= node2%kind_value) then + match = .false. + return + end if + if (node1%has_kind .neqv. node2%has_kind) then + match = .false. + return end if - can_group = types_match .and. & - (node1%kind_value == node2%kind_value) .and. & - (node1%has_kind .eqv. node2%has_kind) .and. & - ((node1%has_intent .and. node2%has_intent .and. & - trim(node1%intent) == trim(node2%intent)) .or. & - (.not. node1%has_intent .and. .not. node2%has_intent)) .and. & - (node1%is_optional .eqv. node2%is_optional) + if (node1%has_intent .and. node2%has_intent) then + intents_match = trim(node1%intent) == trim(node2%intent) + else + intents_match = (.not. node1%has_intent) .and. (.not. node2%has_intent) + end if + if (.not. intents_match) then + match = .false. + return + end if + + if (node1%is_optional .neqv. node2%is_optional) then + match = .false. + return + end if + if (node1%is_target .neqv. node2%is_target) then + match = .false. + return + end if + + match = .true. + end function declarations_match_attributes + + function can_group_declarations(node1, node2) result(can_group) + type(declaration_node), intent(in) :: node1 + type(declaration_node), intent(in) :: node2 + logical :: can_group + logical :: types_match + + if (.not. declarations_share_basic_flags(node1, node2)) then + can_group = .false. + return + end if + + types_match = declarations_have_matching_types(node1, node2) + if (.not. types_match) then + can_group = .false. + return + end if + + can_group = declarations_match_attributes(node1, node2) end function can_group_declarations function can_group_parameters(node1, node2) result(can_group) @@ -87,14 +148,39 @@ function can_group_parameters(node1, node2) result(can_group) (node1%is_target .eqv. node2%is_target) end function can_group_parameters + subroutine resolve_parameter_metadata(node, param_map, intent_text, & + optional_flag, & + target_flag) + type(declaration_node), intent(in) :: node + type(parameter_info_t), intent(in) :: param_map(:) + character(len=:), allocatable, intent(out) :: intent_text + logical, intent(out) :: optional_flag + logical, intent(out) :: target_flag + integer :: idx + + idx = find_parameter_info(param_map, node%var_name) + if (idx > 0) then + intent_text = param_map(idx)%intent_str + optional_flag = param_map(idx)%is_optional + target_flag = param_map(idx)%is_target + return + end if + + if (node%has_intent) then + intent_text = node%intent + else + intent_text = "" + end if + optional_flag = node%is_optional + target_flag = node%is_target + end subroutine resolve_parameter_metadata + function can_group_declarations_with_params(node1, node2, param_map) & result(can_group) type(declaration_node), intent(in) :: node1 type(declaration_node), intent(in) :: node2 type(parameter_info_t), intent(in) :: param_map(:) logical :: can_group - integer :: idx1 - integer :: idx2 character(len=:), allocatable :: intent1 character(len=:), allocatable :: intent2 logical :: optional1 @@ -102,46 +188,28 @@ function can_group_declarations_with_params(node1, node2, param_map) & logical :: target1 logical :: target2 - if (node1%initializer_index > 0 .or. node2%initializer_index > 0) then + if (.not. declarations_share_basic_flags(node1, node2)) then + can_group = .false. + return + end if + + if (trim(node1%type_name) /= trim(node2%type_name)) then + can_group = .false. + return + end if + if (node1%kind_value /= node2%kind_value) then + can_group = .false. + return + end if + if (node1%has_kind .neqv. node2%has_kind) then can_group = .false. return end if - idx1 = find_parameter_info(param_map, node1%var_name) - idx2 = find_parameter_info(param_map, node2%var_name) + call resolve_parameter_metadata(node1, param_map, intent1, optional1, target1) + call resolve_parameter_metadata(node2, param_map, intent2, optional2, target2) - if (idx1 > 0) then - intent1 = param_map(idx1)%intent_str - optional1 = param_map(idx1)%is_optional - target1 = param_map(idx1)%is_target - else - if (node1%has_intent) then - intent1 = node1%intent - else - intent1 = "" - end if - optional1 = node1%is_optional - target1 = node1%is_target - end if - - if (idx2 > 0) then - intent2 = param_map(idx2)%intent_str - optional2 = param_map(idx2)%is_optional - target2 = param_map(idx2)%is_target - else - if (node2%has_intent) then - intent2 = node2%intent - else - intent2 = "" - end if - optional2 = node2%is_optional - target2 = node2%is_target - end if - - can_group = trim(node1%type_name) == trim(node2%type_name) .and. & - node1%kind_value == node2%kind_value .and. & - node1%has_kind .eqv. node2%has_kind .and. & - trim(intent1) == trim(intent2) .and. & + can_group = trim(intent1) == trim(intent2) .and. & optional1 .eqv. optional2 .and. & target1 .eqv. target2 end function can_group_declarations_with_params diff --git a/src/codegen/codegen_grouped_body.f90 b/src/codegen/codegen_grouped_body.f90 index 76552552..c68f5375 100644 --- a/src/codegen/codegen_grouped_body.f90 +++ b/src/codegen/codegen_grouped_body.f90 @@ -149,6 +149,183 @@ function generate_grouped_body_context(arena, body_indices, indent, & code = generate_grouped_body(arena, body_indices, indent) end function generate_grouped_body_context + subroutine emit_declaration_statement(arena, idx, indent_str, code) + type(ast_arena_t), intent(in) :: arena + integer, intent(in) :: idx + character(len=*), intent(in) :: indent_str + character(len=:), allocatable, intent(inout) :: code + character(len=:), allocatable :: stmt_code + + stmt_code = generate_code_from_arena(arena, idx) + code = code // indent_str // stmt_code // new_line('A') + end subroutine emit_declaration_statement + + pure logical function is_groupable_declaration(node) result(can_group) + type(declaration_node), intent(in) :: node + + if (node%is_multi_declaration) then + can_group = .false. + return + end if + if (node%is_array) then + can_group = .false. + return + end if + if (node%is_allocatable) then + can_group = .false. + return + end if + if (node%is_pointer) then + can_group = .false. + return + end if + if (node%is_target) then + can_group = .false. + return + end if + if (node%is_external) then + can_group = .false. + return + end if + if (node%is_parameter) then + can_group = .false. + return + end if + if (node%initializer_index > 0) then + can_group = .false. + return + end if + + can_group = .true. + end function is_groupable_declaration + + subroutine extend_declaration_group(arena, body_indices, start_pos, first_node, & + grouped_names, group_count, next_index) + type(ast_arena_t), intent(in) :: arena + integer, intent(in) :: body_indices(:) + integer, intent(in) :: start_pos + type(declaration_node), intent(in) :: first_node + character(len=64), allocatable, intent(inout) :: grouped_names(:) + integer, intent(inout) :: group_count + integer, intent(out) :: next_index + integer :: j + + j = start_pos + do while (j <= size(body_indices)) + if (body_indices(j) <= 0) exit + if (body_indices(j) > arena%size) exit + if (.not. allocated(arena%entries(body_indices(j))%node)) exit + select type (next_node => arena%entries(body_indices(j))%node) + type is (declaration_node) + if (can_group_declarations(first_node, next_node)) then + group_count = group_count + 1 + call append_name(grouped_names, group_count, & + trim(next_node%var_name)) + j = j + 1 + else + exit + end if + class default + exit + end select + end do + next_index = j + end subroutine extend_declaration_group + + function build_grouped_statement(first_node, grouped_names, group_count) & + result(stmt) + type(declaration_node), intent(in) :: first_node + character(len=64), allocatable, intent(in) :: grouped_names(:) + integer, intent(in) :: group_count + character(len=:), allocatable :: stmt + character(len=:), allocatable :: intent_text + character(len=:), allocatable :: var_list + character(len=64), allocatable :: sorted_names(:) + + sorted_names = grouped_names + call sort_names(sorted_names, group_count) + var_list = build_var_list(sorted_names, group_count) + + if (first_node%has_intent) then + intent_text = first_node%intent + else + intent_text = "" + end if + + stmt = generate_grouped_declaration(first_node%type_name, & + first_node%kind_value, & + first_node%has_kind, & + intent_text, & + var_list, & + first_node%is_optional, & + first_node%is_target) + end function build_grouped_statement + + subroutine extend_parameter_group(arena, body_indices, start_pos, var_list, & + next_index, first_node) + type(ast_arena_t), intent(in) :: arena + integer, intent(in) :: body_indices(:) + integer, intent(in) :: start_pos + character(len=:), allocatable, intent(inout) :: var_list + integer, intent(out) :: next_index + type(parameter_declaration_node), intent(in) :: first_node + integer :: j + + j = start_pos + do while (j <= size(body_indices)) + if (body_indices(j) <= 0) exit + if (body_indices(j) > arena%size) exit + if (.not. allocated(arena%entries(body_indices(j))%node)) exit + select type (next_node => arena%entries(body_indices(j))%node) + type is (parameter_declaration_node) + if (can_group_parameters(first_node, next_node)) then + var_list = var_list // ", " // trim(next_node%name) + j = j + 1 + else + exit + end if + class default + exit + end select + end do + next_index = j + end subroutine extend_parameter_group + + function build_parameter_statement(first_node, var_list) result(stmt) + type(parameter_declaration_node), intent(in) :: first_node + character(len=*), intent(in) :: var_list + character(len=:), allocatable :: stmt + character(len=:), allocatable :: base_type + + if (allocated(first_node%type_name)) then + base_type = first_node%type_name + else + base_type = "real" + end if + + if (is_character_type_string(base_type)) then + stmt = normalize_character_type_param(base_type, & + first_node%has_kind, & + first_node%kind_value) + else + stmt = base_type + if (first_node%has_kind .and. first_node%kind_value > 0) then + stmt = stmt // "(" // & + trim(adjustl(int_to_string(first_node%kind_value))) // ")" + end if + end if + + if (first_node%intent_type /= INTENT_NONE) then + stmt = stmt // ", intent(" // & + intent_type_to_string(first_node%intent_type) // ")" + end if + if (first_node%is_optional) then + stmt = stmt // ", optional" + end if + + stmt = stmt // " :: " // var_list + end function build_parameter_statement + subroutine process_grouped_declarations(arena, body_indices, i, indent_str, code) type(ast_arena_t), intent(in) :: arena integer, intent(in) :: body_indices(:) @@ -157,18 +334,16 @@ subroutine process_grouped_declarations(arena, body_indices, i, indent_str, code character(len=:), allocatable, intent(inout) :: code type(declaration_node) :: first_node - character(len=:), allocatable :: var_list character(len=:), allocatable :: stmt_code character(len=64), allocatable :: grouped_names(:) integer :: group_count - integer :: j - integer :: k + integer :: next_index select type (node => arena%entries(body_indices(i))%node) type is (declaration_node) - if (node%is_multi_declaration) then - stmt_code = generate_code_from_arena(arena, body_indices(i)) - code = code // indent_str // stmt_code // new_line('A') + if (.not. is_groupable_declaration(node)) then + call emit_declaration_statement(arena, body_indices(i), & + indent_str, code) i = i + 1 return end if @@ -178,59 +353,19 @@ subroutine process_grouped_declarations(arena, body_indices, i, indent_str, code allocate (grouped_names(group_count)) grouped_names(1) = trim(node%var_name) - if (node%is_array .or. node%is_allocatable .or. node%is_pointer .or. & - node%is_target .or. node%is_external .or. node%is_parameter .or. & - node%initializer_index > 0) then - stmt_code = generate_code_from_arena(arena, body_indices(i)) - code = code // indent_str // stmt_code // new_line('A') - i = i + 1 - return - end if - - j = i + 1 - do while (j <= size(body_indices)) - if (body_indices(j) <= 0 .or. body_indices(j) > arena%size) exit - if (.not. allocated(arena%entries(body_indices(j))%node)) exit - select type (next_node => arena%entries(body_indices(j))%node) - type is (declaration_node) - if (can_group_declarations(first_node, next_node)) then - group_count = group_count + 1 - call append_name(grouped_names, group_count, & - trim(next_node%var_name)) - j = j + 1 - else - exit - end if - class default - exit - end select - end do + call extend_declaration_group(arena, body_indices, i + 1, first_node, & + grouped_names, group_count, next_index) if (group_count == 1) then - stmt_code = generate_code_from_arena(arena, body_indices(i)) - code = code // indent_str // stmt_code // new_line('A') - i = j - else - call sort_names(grouped_names, group_count) - var_list = build_var_list(grouped_names, group_count) - block - character(len=:), allocatable :: intent_text - if (first_node%has_intent) then - intent_text = first_node%intent - else - intent_text = "" - end if - stmt_code = generate_grouped_declaration(first_node%type_name, & - first_node%kind_value, & - first_node%has_kind, & - intent_text, & - var_list, & - first_node%is_optional, & - first_node%is_target) - end block - code = code // indent_str // stmt_code // new_line('A') - i = j + call emit_declaration_statement(arena, body_indices(i), & + indent_str, code) + i = next_index + return end if + + stmt_code = build_grouped_statement(first_node, grouped_names, group_count) + code = code // indent_str // stmt_code // new_line('A') + i = next_index end select end subroutine process_grouped_declarations @@ -244,53 +379,19 @@ subroutine process_grouped_parameters(arena, body_indices, i, indent_str, code) type(parameter_declaration_node) :: first_node character(len=:), allocatable :: var_list character(len=:), allocatable :: stmt_code - integer :: j + integer :: next_index select type (node => arena%entries(body_indices(i))%node) type is (parameter_declaration_node) first_node = node var_list = trim(node%name) - j = i + 1 - do while (j <= size(body_indices)) - if (body_indices(j) <= 0 .or. body_indices(j) > arena%size) exit - if (.not. allocated(arena%entries(body_indices(j))%node)) exit - select type (next_node => arena%entries(body_indices(j))%node) - type is (parameter_declaration_node) - if (can_group_parameters(first_node, next_node)) then - var_list = var_list // ", " // trim(next_node%name) - j = j + 1 - else - exit - end if - class default - exit - end select - end do + call extend_parameter_group(arena, body_indices, i + 1, var_list, & + next_index, first_node) - if (allocated(first_node%type_name)) then - stmt_code = first_node%type_name - else - stmt_code = "real" - end if - if (is_character_type_string(stmt_code)) then - stmt_code = normalize_character_type_param(stmt_code, & - first_node%has_kind, & - first_node%kind_value) - else if (first_node%has_kind .and. first_node%kind_value > 0) then - stmt_code = stmt_code // "(" // & - trim(adjustl(int_to_string(first_node%kind_value))) // ")" - end if - if (first_node%intent_type /= INTENT_NONE) then - stmt_code = stmt_code // ", intent(" // & - intent_type_to_string(first_node%intent_type) // ")" - end if - if (first_node%is_optional) then - stmt_code = stmt_code // ", optional" - end if - stmt_code = stmt_code // " :: " // var_list + stmt_code = build_parameter_statement(first_node, var_list) code = code // indent_str // stmt_code // new_line('A') - i = j + i = next_index end select end subroutine process_grouped_parameters