Skip to content

Commit

Permalink
Fix the arg-spec handling of the call extractor. In fact, we can rewrite
Browse files Browse the repository at this point in the history
all of it more concisely and more generally.
  • Loading branch information
pratyai committed Jan 18, 2025
1 parent 6afc0fc commit bf089c6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 135 deletions.
152 changes: 44 additions & 108 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,133 +768,73 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
line_number=node.line_number)


class ArgumentExtractorNodeLister(NodeVisitor):
"""
Finds all arguments in function calls in the AST node and its children that have to be extracted into independent expressions
"""

def __init__(self):
self.nodes: List[ast_internal_classes.Call_Expr_Node] = []

def visit_For_Stmt_Node(self, node: ast_internal_classes.For_Stmt_Node):
return

def visit_If_Then_Stmt_Node(self, node: ast_internal_classes.If_Stmt_Node):
return

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
stop = False
# if hasattr(node, "subroutine"):
# if node.subroutine is True:
# stop = True

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if not stop and node.name.name not in [
"malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()
]:
for i in node.args:
if isinstance(i, (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node,
ast_internal_classes.Actual_Arg_Spec_Node)):
continue
else:
self.nodes.append(i)
return self.generic_visit(node)

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
return


class ArgumentExtractor(NodeTransformer):
"""
Uses the ArgumentExtractorNodeLister to find all function calls
in the AST node and its children that have to be extracted into independent expressions
It then creates a new temporary variable for each of them and replaces the call with the variable.
"""

def __init__(self, program, count=0):
self.count = count
self.program = program

def __init__(self, program):
self._count = 0
ParentScopeAssigner().visit(program)
self.scope_vars = ScopeVarsDeclarations(program)
self.scope_vars.visit(program)
# For a nesting of execution parts (rare, but in case it happens), after visiting each direct child of it,
# `self.execution_preludes[-1]` will contain all the temporary variable assignments necessary for that node.
self.execution_preludes: List[List[ast_internal_classes.BinOp_Node]] = []

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
def _get_tempvar_name(self):
tmpname, self._count = f"tmp_arg_{self._count}", self._count + 1
return tmpname

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon",
*FortranIntrinsics.call_extraction_exemptions()]:
return self.generic_visit(node)
# if node.subroutine:
# return self.generic_visit(node)
if not hasattr(self, "count"):
self.count = 0
tmp = self.count
result = ast_internal_classes.Call_Expr_Node(type=node.type, subroutine=node.subroutine,
name=node.name, args=[], line_number=node.line_number,
parent=node.parent)
result = ast_internal_classes.Call_Expr_Node(
name=node.name, args=[], line_number=node.line_number,
type=node.type, subroutine=node.subroutine, parent=node.parent)

for i, arg in enumerate(node.args):
# Ensure we allow to extract function calls from arguments
if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node,
ast_internal_classes.Actual_Arg_Spec_Node)):
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)):
result.args.append(arg)
continue

# These needs to be extracted, so register a temporary variable.s
tmpname = self._get_tempvar_name()
decl = ast_internal_classes.Decl_Stmt_Node(
vardecl=[ast_internal_classes.Var_Decl_Node(name=tmpname, type='VOID', sizes=None, init=None)])
node.parent.specification_part.specifications.append(decl)

if isinstance(arg, ast_internal_classes.Actual_Arg_Spec_Node):
self.generic_visit(arg.arg)
result.args.append(ast_internal_classes.Actual_Arg_Spec_Node(
arg_name=arg.arg_name, arg=ast_internal_classes.Name_Node(name=tmpname, type=arg.arg.type)))
asgn = ast_internal_classes.BinOp_Node(
op="=", lval=ast_internal_classes.Name_Node(name=tmpname, type=arg.arg.type),
rval=arg.arg, line_number=node.line_number, parent=node.parent)
else:
result.args.append(ast_internal_classes.Name_Node(name="tmp_arg_" + str(tmp), type='VOID'))
tmp = tmp + 1
self.count = tmp
self.generic_visit(arg)
result.args.append(ast_internal_classes.Name_Node(name=tmpname, type=arg.type))
asgn = ast_internal_classes.BinOp_Node(
op="=", lval=ast_internal_classes.Name_Node(name=tmpname, type=arg.type),
rval=arg, line_number=node.line_number, parent=node.parent)

self.execution_preludes[-1].append(asgn)
return result

def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
newbody = []

for child in node.execution:
lister = ArgumentExtractorNodeLister()
lister.visit(child)
res = lister.nodes
for i in res:
if i == child:
res.pop(res.index(i))

if res is not None:

# Variables are counted from 0...end, starting from main node, to all calls nested
# in main node arguments.
# However, we need to define nested ones first.
# We go in reverse order, counting from end-1 to 0.
temp = self.count + len(res) - 1
for i in reversed(range(0, len(res))):

if isinstance(res[i], ast_internal_classes.Data_Ref_Node):
struct_def, cur_var, _ = self.program.structures.find_definition(self.scope_vars, res[i])

var_type = cur_var.type
else:
var_type = res[i].type

node.parent.specification_part.specifications.append(
ast_internal_classes.Decl_Stmt_Node(vardecl=[
ast_internal_classes.Var_Decl_Node(
name="tmp_arg_" + str(temp),
type='VOID',
sizes=None,
init=None,
)
])
)
newbody.append(
ast_internal_classes.BinOp_Node(op="=",
lval=ast_internal_classes.Name_Node(name="tmp_arg_" +
str(temp),
type=res[i].type),
rval=res[i],
line_number=child.line_number, parent=child.parent))
temp = temp - 1

newbody.append(self.visit(child))

return ast_internal_classes.Execution_Part_Node(execution=newbody)
self.execution_preludes.append([])
for ex in node.execution:
ex = self.visit(ex)
newbody.extend(reversed(self.execution_preludes[-1]))
newbody.append(ex)
self.execution_preludes[-1].clear()
self.execution_preludes.pop()
return ast_internal_classes.Execution_Part_Node(execution = newbody)


class FunctionCallTransformer(NodeTransformer):
Expand Down Expand Up @@ -2816,10 +2756,6 @@ def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node):
return node

def visit_Actual_Arg_Spec_Node(self, node: ast_internal_classes.Actual_Arg_Spec_Node):

if node.type != 'VOID':
return node

node.arg = self.visit(node.arg)

func_arg_name_type = self._get_type(node.arg)
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,6 +2341,7 @@ def create_internal_ast(cfg: ParseConfig) -> Tuple[ast_components.InternalFortra
ast = correct_for_function_calls(ast)
ast = deconstruct_procedure_calls(ast)
ast = deconstruct_interface_calls(ast)
ast = correct_for_function_calls(ast)

if not cfg.entry_points:
# Keep all the possible entry points.
Expand Down Expand Up @@ -2814,6 +2815,7 @@ def create_sdfg_from_fortran_file_with_options(
ast = deconstruct_interface_calls(ast)
ast = make_argument_mapping_explicit(ast)
ast = convert_data_statements_into_assignments(ast)
ast = correct_for_function_calls(ast)

print("FParser Op: Inject configs & prune...")
ast = inject_const_evals(ast, cfg.config_injections)
Expand Down
46 changes: 20 additions & 26 deletions tests/fortran/arg_extract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np

from dace.frontend.fortran import fortran_parser
from dace.frontend.fortran.fortran_parser import create_singular_sdfg_from_string
from tests.fortran.fortran_test_helper import SourceCodeBuilder


def test_fortran_frontend_arg_extract():
test_string = """
Expand Down Expand Up @@ -42,32 +45,23 @@ def test_fortran_frontend_arg_extract():


def test_fortran_frontend_arg_extract3():
test_string = """
PROGRAM arg_extract3
implicit none
real, dimension(2) :: d
real, dimension(2) :: res
CALL arg_extract3_test_function(d,res)
end
SUBROUTINE arg_extract3_test_function(d,res)
real, dimension(2) :: d
real, dimension(2) :: res
integer :: jg
logical, dimension(2) :: is_cloud
jg = 1
is_cloud(1) = .true.
d(1)=10
d(2)=20
res(1) = MERGE(MERGE(d(1), d(2), d(1) < d(2) .AND. is_cloud(jg)), 0.0D0, is_cloud(jg))
res(2) = 52
END SUBROUTINE arg_extract3_test_function
"""

sdfg = fortran_parser.create_sdfg_from_string(test_string, "arg_extract3_test", normalize_offsets=True)
sources, main = SourceCodeBuilder().add_file("""
subroutine main(d, res)
implicit none
real, dimension(2) :: d
real, dimension(2) :: res
integer :: jg
logical, dimension(2) :: is_cloud
jg = 1
is_cloud(1) = .true.
d(1) = 10
d(2) = 20
res(1) = merge(merge(d(1), d(2), d(1) < d(2) .and. is_cloud(jg)), 0.0, is_cloud(jg))
res(2) = 52
end subroutine main
""").check_with_gfortran().get()

sdfg = create_singular_sdfg_from_string(sources, 'main', True)
#sdfg.simplify(verbose=True)
sdfg.compile()

Expand Down
5 changes: 4 additions & 1 deletion tests/fortran/multisdfg_construction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,13 @@ def test_subroutine_with_differnt_ways_to_specificy_arg():
real, intent(out) :: z(1)
z = fun(1.2)
z = fun(x=1.2)
z = fun(x=1.2+2.1)
z = fun(x)
z = fun(x=x)
z = fun(y)
z = fun(x=y)
z = fun(fun(x))
z = fun(x=fun(x))
end subroutine main
""").check_with_gfortran().get()
# Construct
Expand All @@ -201,4 +204,4 @@ def test_subroutine_with_differnt_ways_to_specificy_arg():

main = gmap['main'].compile()
main(z=z)
assert np.allclose(z, [4.2])
assert np.allclose(z, [4.4])

0 comments on commit bf089c6

Please sign in to comment.