Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[f2dace/dev, fortran] Constant propagation in the array subscripts. #1878

Draft
wants to merge 13 commits into
base: f2dace/dev
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
Missed a case for keyword args.
pratyai committed Jan 18, 2025
commit 8a00fe9ca152f3d8c088c454dab41294b3e4670d
12 changes: 9 additions & 3 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
@@ -787,6 +787,9 @@ def _get_tempvar_name(self):
return tmpname

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
DIRECTLY_REFERNCEABLE = (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)

from dace.frontend.fortran.intrinsics import FortranIntrinsics
if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon",
*FortranIntrinsics.call_extraction_exemptions()]:
@@ -797,12 +800,15 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):

for i, arg in enumerate(node.args):
# Ensure we allow to extract function calls from arguments
if isinstance(arg, (ast_internal_classes.Name_Node, ast_internal_classes.Literal,
ast_internal_classes.Array_Subscript_Node, ast_internal_classes.Data_Ref_Node)):
if (isinstance(arg, DIRECTLY_REFERNCEABLE)
or (isinstance(arg, ast_internal_classes.Actual_Arg_Spec_Node)
and isinstance(arg.arg, DIRECTLY_REFERNCEABLE))):
# If it is a node type that's allowed to be directly referenced in a (possibly keyworded) function
# argument, then we keep the node as is.
result.args.append(arg)
continue

# These needs to be extracted, so register a temporary variable.s
# These needs to be extracted, so register a temporary variable.
tmpname = self._get_tempvar_name()
decl = ast_internal_classes.Decl_Stmt_Node(
vardecl=[ast_internal_classes.Var_Decl_Node(name=tmpname, type='VOID', sizes=None, init=None)])