Skip to content

Commit

Permalink
feat[next][dace]: refactoring (#1685)
Browse files Browse the repository at this point in the history
Implements some refactoring changes discussed during review of #1683.
  • Loading branch information
edopao authored Oct 10, 2024
1 parent 9bb6c94 commit 92e83f3
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def __init__(self, program: dace.CompiledSDFG):
]

def __call__(self, *args: Any, **kwargs: Any) -> None:
self.sdfg_program(*args, **kwargs)
result = self.sdfg_program(*args, **kwargs)
assert result is None

def fast_call(self) -> None:
result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs)
assert result is None


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -104,7 +109,7 @@ def decorated_program(
*args: Any,
offset_provider: common.OffsetProvider,
out: Any = None,
) -> Any:
) -> None:
if out is not None:
args = (*args, out)
if len(sdfg.arg_names) > len(args):
Expand Down Expand Up @@ -141,7 +146,7 @@ def decorated_program(
), f"argument '{arg_name}' not found."

if use_fast_call:
return sdfg_program.fast_call(*sdfg_program._lastargs)
return inp.fast_call()

sdfg_args = dace_backend.get_sdfg_args(
sdfg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


from gt4py.next.program_processors.runners.dace_common.dace_backend import get_sdfg_args
from gt4py.next.program_processors.runners.dace_fieldview.gtir_to_sdfg import build_sdfg_from_gtir
from gt4py.next.program_processors.runners.dace_fieldview.gtir_sdfg import build_sdfg_from_gtir


__all__ = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.type_system import type_specifications as gtir_ts
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
gtir_dataflow,
gtir_python_codegen,
gtir_to_tasklet,
utility as dace_gtir_utils,
)
from gt4py.next.type_system import type_specifications as ts


if TYPE_CHECKING:
from gt4py.next.program_processors.runners.dace_fieldview import gtir_to_sdfg
from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
Expand All @@ -51,8 +51,8 @@ def __call__(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
"""Creates the dataflow subgraph representing a GTIR primitive function.
Expand Down Expand Up @@ -80,12 +80,12 @@ def _parse_fieldop_arg(
node: gtir.Expr,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
sdfg_builder: gtir_sdfg.SDFGBuilder,
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
arg = sdfg_builder.visit(
node,
sdfg=sdfg,
Expand All @@ -98,10 +98,10 @@ def _parse_fieldop_arg(
raise ValueError(f"Received {node} as argument to field operator, expected a field.")

if isinstance(arg.data_type, ts.ScalarType):
return gtir_to_tasklet.MemletExpr(arg.data_node, sbs.Indices([0]))
return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0]))
elif isinstance(arg.data_type, ts.FieldType):
indices: dict[gtx_common.Dimension, gtir_to_tasklet.ValueExpr] = {
dim: gtir_to_tasklet.SymbolExpr(
indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = {
dim: gtir_dataflow.SymbolExpr(
dace_gtir_utils.get_map_variable(dim),
IteratorIndexDType,
)
Expand All @@ -110,9 +110,9 @@ def _parse_fieldop_arg(
dims = arg.data_type.dims + (
# we add an extra anonymous dimension in the iterator definition to enable
# dereferencing elements in `ListType`
[gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, gtir_ts.ListType) else []
[gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, itir_ts.ListType) else []
)
return gtir_to_tasklet.IteratorExpr(arg.data_node, dims, indices)
return gtir_dataflow.IteratorExpr(arg.data_node, dims, indices)
else:
raise NotImplementedError(f"Node type {type(arg.data_type)} not supported.")

Expand All @@ -139,7 +139,7 @@ def _create_temporary_field(
field_shape = list(domain_ubs)

if isinstance(output_desc, dace.data.Array):
assert isinstance(node_type.dtype, gtir_ts.ListType)
assert isinstance(node_type.dtype, itir_ts.ListType)
assert isinstance(node_type.dtype.element_type, ts.ScalarType)
field_dtype = node_type.dtype.element_type
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
Expand All @@ -161,31 +161,41 @@ def translate_as_field_op(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
"""Generates the dataflow subgraph for the `as_fieldop` builtin function."""
"""
Generates the dataflow subgraph for the `as_fieldop` builtin function.
Expects a `FunCall` node with two arguments:
1. a lambda function representing the stencil, which is lowered to a dataflow subgraph
2. the domain of the field operator, which is used as map range
The dataflow can be as simple as a single tasklet, or implement a local computation
as a composition of tasklets and even include a map to range on local dimensions (e.g.
neighbors and map builtins).
The stencil dataflow is instantiated inside a map scope, which applies the stencil over
the field domain.
"""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
stencil_expr, domain_expr = fun_node.args
# expect stencil (represented as a lambda function) as first argument
assert isinstance(stencil_expr, gtir.Lambda)
# the domain of the field operator is passed as second argument
assert isinstance(domain_expr, gtir.FunCall)

# add local storage to compute the field operator over the given domain
# parse the domain of the field operator
domain = dace_gtir_utils.get_domain(domain_expr)

if cpm.is_applied_reduce(stencil_expr.expr):
if reduce_identity is not None:
raise NotImplementedError("nested reductions not supported.")

# the reduce identity value is used to fill the skip values in neighbors list
_, _, reduce_identity = gtir_to_tasklet.get_reduce_params(stencil_expr.expr)
_, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr)

# visit the list of arguments to be passed to the lambda expression
stencil_args = [
Expand All @@ -194,55 +204,34 @@ def translate_as_field_op(
]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder, reduce_identity)
input_connections, output_expr = taskgen.visit(stencil_expr, args=stencil_args)
assert isinstance(output_expr, gtir_to_tasklet.DataExpr)
output_desc = output_expr.node.desc(sdfg)

# retrieve the tasklet node which writes the result
last_node = state.in_edges(output_expr.node)[0].src
if isinstance(last_node, dace.nodes.Tasklet):
# the last transient node can be deleted
last_node_connector = state.in_edges(output_expr.node)[0].src_conn
state.remove_node(output_expr.node)
else:
last_node = output_expr.node
last_node_connector = None

# allocate local temporary storage for the result field
result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc)
taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity)
input_edges, output = taskgen.visit(stencil_expr, args=stencil_args)
output_desc = output.expr.node.desc(sdfg)

# assume tasklet with single output
output_subset = [dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]
if isinstance(output_desc, dace.data.Array):
# additional local dimension for neighbors
domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])
if isinstance(node.type.dtype, itir_ts.ListType):
assert isinstance(output_desc, dace.data.Array)
assert set(output_desc.offset) == {0}
output_subset.extend(f"0:{size}" for size in output_desc.shape)
# additional local dimension for neighbors
# TODO(phimuell): Investigate if we should swap the two.
output_subset = sbs.Range.from_indices(domain_index) + sbs.Range.from_array(output_desc)
else:
assert isinstance(output_desc, dace.data.Scalar)
output_subset = sbs.Range.from_indices(domain_index)

# create map range corresponding to the field operator domain
map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain}
me, mx = sdfg_builder.add_map("field_op", state, map_ranges)

if len(input_connections) == 0:
# dace requires an empty edge from map entry node to tasklet node, in case there are no input memlets
state.add_nedge(me, last_node, dace.Memlet())
else:
for data_node, data_subset, lambda_node, lambda_connector in input_connections:
memlet = dace.Memlet(data=data_node.data, subset=data_subset)
state.add_memlet_path(
data_node,
me,
lambda_node,
dst_conn=lambda_connector,
memlet=memlet,
)
state.add_memlet_path(
last_node,
mx,
result_field.data_node,
src_conn=last_node_connector,
memlet=dace.Memlet(data=result_field.data_node.data, subset=",".join(output_subset)),
)
# allocate local temporary storage for the result field
result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc)

# here we setup the edges from the map entry node
for edge in input_edges:
edge.connect(me)

# and here the edge writing the result data through the map exit node
output.connect(mx, result_field.data_node, output_subset)

return result_field

Expand All @@ -251,8 +240,8 @@ def translate_if(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
"""Generates the dataflow subgraph for the `if_` builtin function."""
assert cpm.is_call_to(node, "if_")
Expand Down Expand Up @@ -345,7 +334,7 @@ def make_temps(x: Field) -> Field:
def _get_data_nodes(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
sdfg_builder: gtir_sdfg.SDFGBuilder,
sym_name: str,
sym_type: ts.DataType,
) -> FieldopResult:
Expand Down Expand Up @@ -374,7 +363,7 @@ def _get_data_nodes(
def _get_symbolic_value(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
sdfg_builder: gtir_sdfg.SDFGBuilder,
symbolic_expr: dace.symbolic.SymExpr,
scalar_type: ts.ScalarType,
temp_name: Optional[str] = None,
Expand Down Expand Up @@ -407,8 +396,8 @@ def translate_literal(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
"""Generates the dataflow subgraph for a `ir.Literal` node."""
assert isinstance(node, gtir.Literal)
Expand All @@ -423,8 +412,8 @@ def translate_make_tuple(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
assert cpm.is_call_to(node, "make_tuple")
return tuple(
Expand All @@ -442,8 +431,8 @@ def translate_tuple_get(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
assert cpm.is_call_to(node, "tuple_get")
assert len(node.args) == 2
Expand Down Expand Up @@ -474,8 +463,8 @@ def translate_scalar_expr(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
assert isinstance(node, gtir.FunCall)
assert isinstance(node.type, ts.ScalarType)
Expand Down Expand Up @@ -558,8 +547,8 @@ def translate_symbol_ref(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_to_tasklet.SymbolExpr],
sdfg_builder: gtir_sdfg.SDFGBuilder,
reduce_identity: Optional[gtir_dataflow.SymbolExpr],
) -> FieldopResult:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, gtir.SymRef)
Expand Down
Loading

0 comments on commit 92e83f3

Please sign in to comment.