diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b2662fa278..19e26f24b6 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -446,7 +446,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: +def as_fieldop(expr: itir.Expr, domain: Optional[itir.Expr] = None) -> call: """ Create an `as_fieldop` call. diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py new file mode 100644 index 0000000000..51bbd91d83 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -0,0 +1,204 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import inline_lambdas, inline_lifts, trace_shifts +from gt4py.next.iterator.type_system import ( + inference as type_inference, + type_specifications as it_ts, +) +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _merge_arguments( + args1: dict[str, itir.Expr], arg2: dict[str, itir.Expr] +) -> dict[str, itir.Expr]: + new_args = {**args1} + for stencil_param, stencil_arg in arg2.items(): + if stencil_param not in new_args: + new_args[stencil_param] = stencil_arg + else: + assert new_args[stencil_param] == stencil_arg + return new_args + + +def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. 
+ + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + type_inference.copy_type(from_=expr, to=new_expr) + + return new_expr + + return expr + + +@dataclasses.dataclass +class FuseAsFieldOp(eve.NodeTranslator): + """ + Merge multiple `as_fieldop` calls into one. + + >>> from gt4py import next as gtx + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = gtx.Dimension("IDim") + >>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + >>> d = im.domain("cartesian_domain", {IDim: (0, 1)}) + >>> nested_as_fieldop = im.op_as_fieldop("plus", d)( + ... im.op_as_fieldop("multiplies", d)( + ... im.ref("inp1", field_type), im.ref("inp2", field_type) + ... ), + ... im.ref("inp3", field_type), + ... ) + >>> print(nested_as_fieldop) + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + ) + >>> print( + ... FuseAsFieldOp.apply( + ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... ) + ... 
) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + """ # noqa: RUF002 # ignore ambiguous multiplication character + + uids: eve_utils.UIDGenerator + + def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. 
containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + @classmethod + def apply( + cls, + node: itir.Program, + *, + offset_provider, + uids: Optional[eve_utils.UIDGenerator] = None, + allow_undeclared_symbols=False, + ): + node = type_inference.infer( + node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + ) + + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "as_fieldop"): + node = _canonicalize_as_fieldop(node) + + if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): + stencil: itir.Lambda = node.fun.args[0] + domain = node.fun.args[1] if len(node.fun.args) > 1 else None + + shifts = trace_shifts.trace_stencil(stencil) + + args: list[itir.Expr] = node.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.extract_dtype(arg.type) + # TODO(tehrengruber): make this configurable + should_inline = isinstance(arg, itir.Literal) or ( + isinstance(arg, itir.FunCall) + and (cpm.is_call_to(arg.fun, "as_fieldop") or cpm.is_call_to(arg, "if_")) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) + ) + if should_inline: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif isinstance(arg, itir.Literal): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: 
+ raise NotImplementedError() + + inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + # simplify stencil directly to keep the tree small + new_stencil_body = inline_lambdas.InlineLambdas.apply( + new_stencil_body, opcount_preserving=True + ) + new_stencil_body = inline_lifts.InlineLifts().visit(new_stencil_body) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + type_inference.copy_type(from_=node, to=new_node) + + return new_node + return node diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index bc1095dfb8..fccaa56232 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -96,6 +96,16 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ +def copy_type(from_: itir.Node, to: itir.Node) -> None: + """ + Copy type from one node to another. + + This function mainly exists for readability reasons. + """ + assert isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) + + def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: """ Execute `callback` as soon as all `args` have a type. 
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py new file mode 100644 index 0000000000..da2c16336e --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -0,0 +1,112 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.ref("inp3", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2", "inp3")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp3")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_trivial_literal(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) + expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def 
test_symref_used_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( + im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)( + im.ref("inp1", field_type), im.ref("inp2", field_type) + ), + im.ref("inp1", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp1")) + ), + d, + )("inp1", "inp2") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + im.lambda_("a")( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) + ), + d1, + )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == testee + + +def test_partial_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + # first argument read at multiple locations -> not inlined + # second argument only read at a single location -> inlined + im.lambda_("a", "b")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("b"), + ) + ), + d1, + )( + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a", "inp1")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("inp1"), + ) + ), + d1, +
)(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == expected