diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index a9f26d544a..ae0a24605d 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -40,7 +40,12 @@ def __init__(self, program: dace.CompiledSDFG): ] def __call__(self, *args: Any, **kwargs: Any) -> None: - self.sdfg_program(*args, **kwargs) + result = self.sdfg_program(*args, **kwargs) + assert result is None + + def fast_call(self) -> None: + result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs) + assert result is None @dataclasses.dataclass(frozen=True) @@ -104,7 +109,7 @@ def decorated_program( *args: Any, offset_provider: common.OffsetProvider, out: Any = None, - ) -> Any: + ) -> None: if out is not None: args = (*args, out) if len(sdfg.arg_names) > len(args): @@ -141,7 +146,7 @@ def decorated_program( ), f"argument '{arg_name}' not found." if use_fast_call: - return sdfg_program.fast_call(*sdfg_program._lastargs) + return inp.fast_call() sdfg_args = dace_backend.get_sdfg_args( sdfg, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py index 54bffcb607..602453fc5a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py @@ -8,7 +8,7 @@ from gt4py.next.program_processors.runners.dace_common.dace_backend import get_sdfg_args -from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir +from gt4py.next.program_processors.runners.dace_fieldview.gtir_sdfg import build_sdfg_from_gtir __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 51778e28c7..e91bd880c6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,18 +18,18 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import type_specifications as gtir_ts +from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( + gtir_dataflow, gtir_python_codegen, - gtir_to_tasklet, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_specifications as ts if TYPE_CHECKING: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg + from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes @@ -51,8 +51,8 @@ def __call__( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Creates the dataflow subgraph representing a GTIR primitive function. @@ -80,12 +80,12 @@ def _parse_fieldop_arg( node: gtir.Expr, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.SDFGBuilder, domain: list[ tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] ], - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], -) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr: + reduce_identity: Optional[gtir_dataflow.SymbolExpr], +) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: arg = sdfg_builder.visit( node, sdfg=sdfg, @@ -98,10 +98,10 @@ def _parse_fieldop_arg( raise ValueError(f"Received {node} as argument to field operator, expected a field.") if isinstance(arg.data_type, ts.ScalarType): - return gtir_to_tasklet.MemletExpr(arg.data_node, sbs.Indices([0])) + return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0])) elif isinstance(arg.data_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_to_tasklet.ValueExpr] = { - dim: gtir_to_tasklet.SymbolExpr( + indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = { + dim: gtir_dataflow.SymbolExpr( dace_gtir_utils.get_map_variable(dim), IteratorIndexDType, ) @@ -110,9 +110,9 @@ def _parse_fieldop_arg( dims = arg.data_type.dims + ( # we add an extra anonymous dimension in the iterator definition to enable # dereferencing elements in `ListType` - [gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, gtir_ts.ListType) else [] + [gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, itir_ts.ListType) else [] ) - return gtir_to_tasklet.IteratorExpr(arg.data_node, dims, indices) + return gtir_dataflow.IteratorExpr(arg.data_node, dims, indices) else: raise NotImplementedError(f"Node type {type(arg.data_type)} not supported.") @@ -139,7 +139,7 @@ def _create_temporary_field( field_shape = list(domain_ubs) if isinstance(output_desc, dace.data.Array): - assert isinstance(node_type.dtype, gtir_ts.ListType) + assert isinstance(node_type.dtype, itir_ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) field_dtype = node_type.dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) @@ -161,10 +161,22 @@ def translate_as_field_op( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: - """Generates the dataflow subgraph for the `as_fieldop` builtin function.""" + """ + Generates the dataflow subgraph for the `as_fieldop` builtin function. + + Expects a `FunCall` node with two arguments: + 1. a lambda function representing the stencil, which is lowered to a dataflow subgraph + 2. the domain of the field operator, which is used as map range + + The dataflow can be as simple as a single tasklet, or implement a local computation + as a composition of tasklets and even include a map to range on local dimensions (e.g. + neighbors and map builtins). + The stencil dataflow is instantiated inside a map scope, which apply the stencil over + the field domain. + """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") assert isinstance(node.type, ts.FieldType) @@ -172,12 +184,10 @@ def translate_as_field_op( fun_node = node.fun assert len(fun_node.args) == 2 stencil_expr, domain_expr = fun_node.args - # expect stencil (represented as a lambda function) as first argument assert isinstance(stencil_expr, gtir.Lambda) - # the domain of the field operator is passed as second argument assert isinstance(domain_expr, gtir.FunCall) - # add local storage to compute the field operator over the given domain + # parse the domain of the field operator domain = dace_gtir_utils.get_domain(domain_expr) if cpm.is_applied_reduce(stencil_expr.expr): @@ -185,7 +195,7 @@ def translate_as_field_op( raise NotImplementedError("nested reductions not supported.") # the reduce identity value is used to fill the skip values in neighbors list - _, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr) + _, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr) # visit the list of arguments to be passed to the lambda expression stencil_args = [ @@ -194,55 +204,34 @@ def translate_as_field_op( ] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity) - input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args) - assert isinstance(output_expr, gtir_to_tasklet.DataExpr) - output_desc = output_expr.node.desc(sdfg) - - # retrieve the tasklet node which writes the result - last_node = state.in_edges(output_expr.node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - last_node_connector = state.in_edges(output_expr.node)[0].src_conn - state.remove_node(output_expr.node) - else: - last_node = output_expr.node - last_node_connector = None - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc) + taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) + input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) + output_desc = output.expr.node.desc(sdfg) - # assume tasklet with single output - output_subset = [dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain] - if isinstance(output_desc, dace.data.Array): - # additional local dimension for neighbors + domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + if isinstance(node.type.dtype, itir_ts.ListType): + assert isinstance(output_desc, dace.data.Array) assert set(output_desc.offset) == {0} - output_subset.extend(f"0:{size}" for size in output_desc.shape) + # additional local dimension for neighbors + # TODO(phimuell): Investigate if we should swap the two. + output_subset = sbs.Range.from_indices(domain_index) + sbs.Range.from_array(output_desc) + else: + assert isinstance(output_desc, dace.data.Scalar) + output_subset = sbs.Range.from_indices(domain_index) # create map range corresponding to the field operator domain map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} me, mx = sdfg_builder.add_map("field_op", state, map_ranges) - if len(input_connections) == 0: - # dace requires an empty edge from map entry node to tasklet node, in case there no input memlets - state.add_nedge(me, last_node, dace.Memlet()) - else: - for data_node, data_subset, lambda_node, lambda_connector in input_connections: - memlet = dace.Memlet(data=data_node.data, subset=data_subset) - state.add_memlet_path( - data_node, - me, - lambda_node, - dst_conn=lambda_connector, - memlet=memlet, - ) - state.add_memlet_path( - last_node, - mx, - result_field.data_node, - src_conn=last_node_connector, - memlet=dace.Memlet(data=result_field.data_node.data, subset=",".join(output_subset)), - ) + # allocate local temporary storage for the result field + result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc) + + # here we setup the edges from the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the result data through the map exit node + output.connect(mx, result_field.data_node, output_subset) return result_field @@ -251,8 +240,8 @@ def translate_if( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for the `if_` builtin function.""" assert cpm.is_call_to(node, "if_") @@ -345,7 +334,7 @@ def make_temps(x: Field) -> Field: def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.SDFGBuilder, sym_name: str, sym_type: ts.DataType, ) -> FieldopResult: @@ -374,7 +363,7 @@ def _get_data_nodes( def _get_symbolic_value( sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, + sdfg_builder: gtir_sdfg.SDFGBuilder, symbolic_expr: dace.symbolic.SymExpr, scalar_type: ts.ScalarType, temp_name: Optional[str] = None, @@ -407,8 +396,8 @@ def translate_literal( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for a `ir.Literal` node.""" assert isinstance(node, gtir.Literal) @@ -423,8 +412,8 @@ def translate_make_tuple( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert cpm.is_call_to(node, "make_tuple") return tuple( @@ -442,8 +431,8 @@ def translate_tuple_get( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert cpm.is_call_to(node, "tuple_get") assert len(node.args) == 2 @@ -474,8 +463,8 @@ def translate_scalar_expr( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: assert isinstance(node, gtir.FunCall) assert isinstance(node.type, ts.ScalarType) @@ -558,8 +547,8 @@ def translate_symbol_ref( node: gtir.Node, sdfg: dace.SDFG, state: dace.SDFGState, - sdfg_builder: gtir_to_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + sdfg_builder: gtir_sdfg.SDFGBuilder, + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> FieldopResult: """Generates the dataflow subgraph for a `ir.SymRef` node.""" assert isinstance(node, gtir.SymRef) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py similarity index 81% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index bfddd6f723..9739d7927a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -8,8 +8,9 @@ from __future__ import annotations +import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Set, Tuple, TypeAlias, Union +from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union import dace import dace.subsets as sbs @@ -18,11 +19,11 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import type_specifications as gtir_ts +from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, - gtir_to_sdfg, + gtir_sdfg, utility as dace_gtir_utils, ) from gt4py.next.type_system import type_specifications as ts @@ -33,7 +34,7 @@ class DataExpr: """Local storage for the computation result returned by a tasklet node.""" node: dace.nodes.AccessNode - dtype: gtir_ts.ListType | ts.ScalarType + dtype: itir_ts.ListType | ts.ScalarType @dataclasses.dataclass(frozen=True) @@ -52,14 +53,6 @@ class SymbolExpr: dtype: dace.typeclass -# Define alias for the elements needed to setup input connections to a map scope -InputConnection: TypeAlias = tuple[ - dace.nodes.AccessNode, - sbs.Range, - dace.nodes.Node, - Optional[str], -] - ValueExpr: TypeAlias = DataExpr | MemletExpr | SymbolExpr @@ -83,6 +76,102 @@ class IteratorExpr: indices: dict[gtx_common.Dimension, ValueExpr] +class DataflowInputEdge(Protocol): + """ + This protocol represents an open connection into the dataflow. + + It provides the `connect` method to setup an input edge from an external data source. + Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope + and connect its inputs and outputs to external data nodes by means of memlets that + traverse the map entry and exit nodes. + """ + + @abc.abstractmethod + def connect(self, me: dace.nodes.MapEntry) -> None: ... + + +@dataclasses.dataclass(frozen=True) +class MemletInputEdge(DataflowInputEdge): + """ + Allows to setup an input memlet through a map entry node. + + The edge source has to be a data access node, while the destination node can either + be a tasklet, in which case the connector name is also required, or an access node. + """ + + state: dace.SDFGState + source: dace.nodes.AccessNode + subset: sbs.Range + dest: dace.nodes.AccessNode | dace.nodes.Tasklet + dest_conn: Optional[str] + + def connect(self, me: dace.nodes.MapEntry) -> None: + memlet = dace.Memlet(data=self.source.data, subset=self.subset) + self.state.add_memlet_path( + self.source, + me, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) + + +@dataclasses.dataclass(frozen=True) +class EmptyInputEdge(DataflowInputEdge): + """ + Allows to setup an edge from a map entry node to a tasklet with no arguements. + + The reason behind this kind of connection is that all nodes inside a map scope + must have an in/out path that traverses the entry and exit nodes. + """ + + state: dace.SDFGState + node: dace.nodes.Tasklet + + def connect(self, me: dace.nodes.MapEntry) -> None: + self.state.add_nedge(me, self.node, dace.Memlet()) + + +@dataclasses.dataclass(frozen=True) +class DataflowOutputEdge: + """ + Allows to setup an output memlet through a map exit node. + + The result of a dataflow subgraph needs to be written to an external data node. + Since the dataflow represents a stencil and the dataflow is computed over + a field domain, the dataflow is instatiated inside a map scope. The `connect` + method creates a memlet that writes the dataflow result to the external array + passing through the map exit node. + """ + + state: dace.SDFGState + expr: DataExpr + + def connect( + self, + mx: dace.nodes.MapExit, + result_node: dace.nodes.AccessNode, + subset: sbs.Range, + ) -> None: + # retrieve the node which writes the result + last_node = self.state.in_edges(self.expr.node)[0].src + if isinstance(last_node, dace.nodes.Tasklet): + # the last transient node can be deleted + last_node_connector = self.state.in_edges(self.expr.node)[0].src_conn + self.state.remove_node(self.expr.node) + else: + last_node = self.expr.node + last_node_connector = None + + self.state.add_memlet_path( + last_node, + mx, + result_node, + src_conn=last_node_connector, + memlet=dace.Memlet(data=result_node.data, subset=subset), + ) + + DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { "minimum": dace.dtypes.ReductionType.Min, "maximum": dace.dtypes.ReductionType.Max, @@ -116,43 +205,52 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity -class LambdaToTasklet(eve.NodeVisitor): - """Translates an `ir.Lambda` expression to a dataflow graph. - - Lambda functions should only be encountered as argument to the `as_fieldop` - builtin function, therefore the dataflow graph generated here typically - represents the stencil function of a field operator. +class LambdaToDataflow(eve.NodeVisitor): + """ + Translates an `ir.Lambda` expression to a dataflow graph. + + The dataflow graph generated here typically represents the stencil function + of a field operator. It only computes single elements or pure local fields, + in case of neighbor values. In case of local fields, the dataflow contains + inner maps with fixed literal size (max number of neighbors). + Once the lambda expression has been lowered to a dataflow, the dataflow graph + needs to be instantiated, that is we have to connect all in/out edges to + external source/destination data nodes. Since the lambda expression is used + in GTIR as argument to a field operator, the dataflow is instatiated inside + a map scope and applied on the field domain. Therefore, all in/out edges + must traverse the entry/exit map nodes. """ sdfg: dace.SDFG state: dace.SDFGState - subgraph_builder: gtir_to_sdfg.DataflowBuilder + subgraph_builder: gtir_sdfg.DataflowBuilder reduce_identity: Optional[SymbolExpr] - input_connections: list[InputConnection] + input_edges: list[DataflowInputEdge] symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] def __init__( self, sdfg: dace.SDFG, state: dace.SDFGState, - subgraph_builder: gtir_to_sdfg.DataflowBuilder, + subgraph_builder: gtir_sdfg.DataflowBuilder, reduce_identity: Optional[SymbolExpr], ): self.sdfg = sdfg self.state = state self.subgraph_builder = subgraph_builder self.reduce_identity = reduce_identity - self.input_connections = [] + self.input_edges = [] self.symbol_map = {} - def _add_entry_memlet_path( + def _add_input_data_edge( self, src: dace.nodes.AccessNode, src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, ) -> None: - self.input_connections.append((src, src_subset, dst_node, dst_conn)) + edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + self.input_edges.append(edge) def _add_edge( self, @@ -186,9 +284,18 @@ def _add_tasklet( **kwargs: Any, ) -> dace.nodes.Tasklet: """Helper method to add a tasklet with unique name in current state.""" - return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs) - - def _get_tasklet_result( + tasklet_node = self.subgraph_builder.add_tasklet( + name, self.state, inputs, outputs, code, **kwargs + ) + if len(inputs) == 0: + # All nodes inside a map scope must have an in/out path that traverses + # the entry and exit nodes. Therefore, a tasklet node with no arguments + # still needs an (empty) input edge from map entry node. + edge = EmptyInputEdge(self.state, tasklet_node) + self.input_edges.append(edge) + return tasklet_node + + def _construct_tasklet_result( self, dtype: dace.typeclass, src_node: dace.nodes.Tasklet, @@ -267,7 +374,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: code=f"val = field[{index_internals}]", ) # add new termination point for the field parameter - self._add_entry_memlet_path( + self._add_input_data_edge( arg_expr.field, sbs.Range.from_array(field_desc), deref_node, @@ -278,7 +385,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: # add termination points for the dynamic iterator indices deref_connector = IndexConnectorFmt.format(dim=dim.value) if isinstance(index_expr, MemletExpr): - self._add_entry_memlet_path( + self._add_input_data_edge( index_expr.node, index_expr.subset, deref_node, @@ -297,7 +404,7 @@ def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(index_expr, SymbolExpr) dtype = arg_expr.field.desc(self.sdfg).dtype - return self._get_tasklet_result(dtype, deref_node, "val") + return self._construct_tasklet_result(dtype, deref_node, "val") else: # dereferencing a scalar or a literal node results in the node itself @@ -353,7 +460,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: else f"0:{size}" for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ) - self._add_entry_memlet_path( + self._add_input_data_edge( it.field, sbs.Range.from_string(field_subset), field_slice_node, @@ -367,7 +474,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: find_new_name=True, ) connectivity_slice_node = self.state.add_access(connectivity_slice_view) - self._add_entry_memlet_path( + self._add_input_data_edge( self.state.add_access(connectivity), sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), connectivity_slice_node, @@ -378,7 +485,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: ) neighbors_node = self.state.add_access(neighbors_temp) - offset_dim = gtx_common.Dimension(offset) + offset_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) neighbor_idx = dace_gtir_utils.get_map_variable(offset_dim) me, mx = self._add_map( f"{offset}_neighbors", @@ -428,7 +535,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: memlet=dace.Memlet(data=neighbors_temp, subset=neighbor_idx), ) - assert isinstance(node.type, gtir_ts.ListType) + assert isinstance(node.type, itir_ts.ListType) return DataExpr(neighbors_node, node.type) def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: @@ -465,7 +572,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_init.value) if isinstance(input_expr, MemletExpr): - self._add_entry_memlet_path( + self._add_input_data_edge( input_expr.node, input_expr.subset, reduce_node, @@ -557,7 +664,7 @@ def _make_cartesian_shift( ) for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: if isinstance(input_expr, MemletExpr): - self._add_entry_memlet_path( + self._add_input_data_edge( input_expr.node, input_expr.subset, dynamic_offset_tasklet, @@ -577,7 +684,9 @@ def _make_cartesian_shift( else: dtype = index_expr.node.desc(self.sdfg).dtype - new_index = self._get_tasklet_result(dtype, dynamic_offset_tasklet, new_index_connector) + new_index = self._construct_tasklet_result( + dtype, dynamic_offset_tasklet, new_index_connector + ) # a new iterator with a shifted index along one dimension return IteratorExpr( @@ -605,14 +714,14 @@ def _make_dynamic_neighbor_offset( {new_index_connector}, f"{new_index_connector} = table[{origin_index.value}, offset]", ) - self._add_entry_memlet_path( + self._add_input_data_edge( offset_table_node, sbs.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) if isinstance(offset_expr, MemletExpr): - self._add_entry_memlet_path( + self._add_input_data_edge( offset_expr.node, offset_expr.subset, tasklet_node, @@ -628,7 +737,7 @@ def _make_dynamic_neighbor_offset( ) dtype = offset_table_node.desc(self.sdfg).dtype - return self._get_tasklet_result(dtype, tasklet_node, new_index_connector) + return self._construct_tasklet_result(dtype, tasklet_node, new_index_connector) def _make_unstructured_shift( self, @@ -702,21 +811,13 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: it, offset_provider, offset_table_node, offset_expr ) - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - - elif cpm.is_call_to(node, "neighbors"): - return self._visit_neighbors(node) - - elif cpm.is_applied_reduce(node): - return self._visit_reduce(node) - - elif cpm.is_applied_shift(node): - return self._visit_shift(node) - - else: - assert isinstance(node.fun, gtir.SymRef) + def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: + """ + Generic handler called by `visit_FunCall()` when it encounters + a builtin function that does not match any other specific handler. + """ + assert isinstance(node.type, ts.ScalarType) + dtype = dace_utils.as_dace_type(node.type) node_internals = [] node_connections: dict[str, MemletExpr | DataExpr] = {} @@ -732,8 +833,9 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: # use the argument value without adding any connector node_internals.append(arg_expr.value) - # use tasklet connectors as expression arguments + assert isinstance(node.fun, gtir.SymRef) builtin_name = str(node.fun.id) + # use tasklet connectors as expression arguments code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) out_connector = "result" @@ -754,42 +856,61 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: dace.Memlet(data=arg_expr.node.data, subset="0"), ) else: - self._add_entry_memlet_path( + self._add_input_data_edge( arg_expr.node, arg_expr.subset, tasklet_node, connector, ) - assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) + return self._construct_tasklet_result(dtype, tasklet_node, "result") - return self._get_tasklet_result(dtype, tasklet_node, "result") + def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + + elif cpm.is_call_to(node, "neighbors"): + return self._visit_neighbors(node) + + elif cpm.is_applied_reduce(node): + return self._visit_reduce(node) + + elif cpm.is_applied_shift(node): + return self._visit_shift(node) + + elif isinstance(node.fun, gtir.SymRef): + return self._visit_generic_builtin(node) + + else: + raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") def visit_Lambda( self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[InputConnection], DataExpr]: + ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: for p, arg in zip(node.params, args, strict=True): self.symbol_map[str(p.id)] = arg - output_expr: MemletExpr | SymbolExpr | DataExpr = self.visit(node.expr) + output_expr: ValueExpr = self.visit(node.expr) if isinstance(output_expr, DataExpr): - return self.input_connections, output_expr + return self.input_edges, DataflowOutputEdge(self.state, output_expr) if isinstance(output_expr, MemletExpr): # special case where the field operator is simply copying data from source to destination node output_dtype = output_expr.node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_entry_memlet_path( + self._add_input_data_edge( output_expr.node, output_expr.subset, tasklet_node, "__inp", ) else: + assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node output_dtype = output_expr.dtype tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") - return self.input_connections, self._get_tasklet_result(output_dtype, tasklet_node, "__out") + + output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") + return self.input_edges, DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dtype = dace_utils.as_dace_type(node.type) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py similarity index 98% rename from src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 7de1192829..7d878dde99 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -30,7 +30,7 @@ from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_builtin_translators, - gtir_to_tasklet, + gtir_dataflow, transformations as gtx_transformations, utility as dace_gtir_utils, ) @@ -41,16 +41,13 @@ class DataflowBuilder(Protocol): """Visitor interface to build a dataflow subgraph.""" @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - pass + def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... @abc.abstractmethod - def unique_map_name(self, name: str) -> str: - pass + def unique_map_name(self, name: str) -> str: ... @abc.abstractmethod - def unique_tasklet_name(self, name: str) -> str: - pass + def unique_tasklet_name(self, name: str) -> str: ... def add_map( self, @@ -86,12 +83,12 @@ class SDFGBuilder(DataflowBuilder, Protocol): @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: """Retrieve the GT4Py type of a symbol used in the program.""" - pass + ... @abc.abstractmethod def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: """Visit a node of the GT4Py IR.""" - pass + ... @dataclasses.dataclass(frozen=True) @@ -408,7 +405,7 @@ def visit_FunCall( node: gtir.FunCall, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): @@ -457,7 +454,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + reduce_identity: Optional[gtir_dataflow.SymbolExpr], args: list[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ @@ -623,7 +620,7 @@ def visit_Literal( node: gtir.Literal, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: return gtir_builtin_translators.translate_literal( node, sdfg, head_state, self, reduce_identity=None @@ -634,7 +631,7 @@ def visit_SymRef( node: gtir.SymRef, sdfg: dace.SDFG, head_state: dace.SDFGState, - reduce_identity: Optional[gtir_to_tasklet.SymbolExpr], + reduce_identity: Optional[gtir_dataflow.SymbolExpr], ) -> gtir_builtin_translators.FieldopResult: return gtir_builtin_translators.translate_symbol_ref( node, sdfg, head_state, self, reduce_identity=None diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 28d518dd34..ffc33a9f25 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -22,7 +22,7 @@ from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg +from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg from gt4py.next.type_system import type_translation as tt @@ -49,7 +49,7 @@ def generate_sdfg( # TODO(edopao): Call IR transformations and domain inference, finally lower IR to SDFG raise NotImplementedError - return gtir_to_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) + return gtir_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) def __call__( self, inp: stages.CompilableProgram diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 5f1c24bcc4..e819cdcd8c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -395,10 +395,10 @@ def test_gtir_tuple_target(): def test_gtir_update(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) stencil1 = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref("a"), 0 - 1.0)), + im.lambda_("a")(im.plus(im.deref("a"), im.plus(im.minus(0.0, 2.0), 1.0))), domain, )("x") - stencil2 = im.op_as_fieldop("plus", domain)("x", 0 - 1.0) + stencil2 = im.op_as_fieldop("plus", domain)("x", im.plus(im.minus(0.0, 2.0), 1.0)) for i, stencil in enumerate([stencil1, stencil2]): testee = gtir.Program(