Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Oct 10, 2024
1 parent d4f066a commit f0331cb
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 265 deletions.
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type",)
PRESERVED_ANNEX_ATTRS = ("type", "domain")

expr_map: dict[int, itir.SymRef]

Expand All @@ -51,7 +51,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.Node:
if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda):
eligible_params = []
for arg in node.args:
eligible_params.append(isinstance(arg, itir.SymRef)) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let
eligible_params.append(
isinstance(arg, itir.SymRef)
) # and arg.id.startswith("_cs")) # TODO: document? this is for lets in the global tmp pass, e.g. test_trivial_let
if any(eligible_params):
# note: the inline is opcount preserving anyway so avoid the additional
# effort in the inliner by disabling opcount preservation.
Expand Down
5 changes: 1 addition & 4 deletions src/gt4py/next/iterator/transforms/fencil_to_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import global_tmps


class FencilToProgram(eve.NodeTranslator):
@classmethod
def apply(
cls, node: itir.FencilDefinition | itir.Program
) -> itir.Program:
def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program:
return cls().visit(node)

def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt:
Expand Down
206 changes: 117 additions & 89 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,119 +8,145 @@

from __future__ import annotations

import copy
import dataclasses
import functools
from collections.abc import Mapping
from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence
from typing import Callable, Optional

from gt4py import eve
from gt4py.eve import utils as eve_utils
from gt4py.next import common, utils as next_utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.type_system import type_info
from gt4py.next.iterator.transforms import inline_lambdas

# TODO: remove
SimpleTemporaryExtractionHeuristics = None
CreateGlobalTmps = None

from gt4py.next.iterator.transforms import cse

class IncompleteTemporary:
expr: itir.Expr
target: itir.Expr

def get_expr_domain(expr: itir.Expr, ctx=None):
ctx = ctx or {}

if cpm.is_applied_as_fieldop(expr):
_, domain = expr.fun.args
return domain
elif cpm.is_call_to(expr, "tuple_get"):
idx_expr, tuple_expr = expr.args
assert isinstance(idx_expr, itir.Literal) and type_info.is_integer(idx_expr.type)
idx = int(idx_expr.value)
tuple_expr_domain = get_expr_domain(tuple_expr, ctx)
assert isinstance(tuple_expr_domain, tuple) and idx < len(tuple_expr_domain)
return tuple_expr_domain[idx]
elif cpm.is_call_to(expr, "make_tuple"):
return tuple(get_expr_domain(el, ctx) for el in expr.args)
elif cpm.is_call_to(expr, "if_"):
cond, true_val, false_val = expr.args
true_domain, false_domain = get_expr_domain(true_val, ctx), get_expr_domain(false_val, ctx)
assert true_domain == false_domain
return true_domain
elif cpm.is_let(expr):
new_ctx = {}
for var_name, var_value in zip(expr.fun.params, expr.args, strict=True):
new_ctx[var_name.id] = get_expr_domain(var_value, ctx)
return get_expr_domain(expr.fun.expr, ctx={**ctx, **new_ctx})
raise ValueError()


def transform_if(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator):
from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas
from gt4py.next.iterator.type_system import inference as type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


def transform_if(
stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator
) -> Optional[list[itir.Stmt]]:
if not isinstance(stmt, itir.SetAt):
return None

if cpm.is_call_to(stmt.expr, "if_"):
cond, true_val, false_val = stmt.expr.args
return [itir.IfStmt(
cond=cond,
# recursively transform
true_branch=transform(itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), declarations, uids),
false_branch=transform(itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), declarations, uids),
)]
return [
itir.IfStmt(
cond=cond,
# recursively transform
true_branch=transform(
itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain),
declarations,
uids,
),
false_branch=transform(
itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain),
declarations,
uids,
),
)
]
return None

def transform_by_pattern(stmt: itir.SetAt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator):

def transform_by_pattern(
stmt: itir.Stmt, predicate, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator
) -> Optional[list[itir.Stmt]]:
if not isinstance(stmt, itir.SetAt):
return None

new_expr, extracted_fields, _ = cse.extract_subexpression(
stmt.expr,
predicate=predicate,
uid_generator=uids,
# allows better fusing later on
#deepest_expr_first=True # TODO: better, but not supported right now
uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"),
# TODO(tehrengruber): extracting the deepest expression first would allow us to fuse
# the extracted expressions resulting in fewer kernel calls, better data-locality.
# Extracting the multiple expressions deepest-first is however not supported right now.
# deepest_expr_first=True # noqa: ERA001
)

if extracted_fields:
new_stmts = []
tmp_stmts: list[itir.Stmt] = []

# for each extracted expression generate:
# - one or more `Temporary` declarations (depending on whether the expression is a field
# or a tuple thereof)
# - one `SetAt` statement that materializes the expression into the temporary
for tmp_sym, tmp_expr in extracted_fields.items():
# TODO: expr domain can not be a tuple here
domain = get_expr_domain(tmp_expr)
domain = tmp_expr.annex.domain

# TODO(tehrengruber): Implement. This happens when the expression for a combination
# of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are
# able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level
# of a SetAt, the CollapseTuple pass will eliminate most of this cases.
if isinstance(domain, tuple):
flattened_domains: tuple[itir.Expr] = next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough
if not all(d == flattened_domains[0] for d in flattened_domains):
raise NotImplementedError(
"Tuple expressions with different domains is not " "supported yet."
)
domain = flattened_domains[0]

assert isinstance(tmp_expr.type, ts.TypeSpec)
tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents(
lambda x: uids.sequential_id(),
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = (
type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
)

scalar_type = type_info.apply_to_primitive_constituents(
type_info.extract_dtype, tmp_expr.type
# allocate temporary for all tuple elements
def allocate_temporary(tmp_name: str, dtype: ts.ScalarType, domain: itir.Expr):
declarations.append(itir.Temporary(id=tmp_name, domain=domain, dtype=dtype))

next_utils.tree_map(functools.partial(allocate_temporary, domain=domain))(
tmp_names, tmp_dtypes
)
declarations.append(itir.Temporary(id=tmp_sym.id, domain=domain, dtype=scalar_type))

# TODO: transform not needed if deepest_expr_first=True
new_stmts.extend(transform(itir.SetAt(target=im.ref(tmp_sym.id), domain=domain, expr=tmp_expr), declarations, uids))
# if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a
# `make_tuple` expression.
target_expr: itir.Expr = next_utils.tree_map(
lambda x: im.ref(x), result_collection_constructor=lambda els: im.make_tuple(*els)
)(tmp_names) # type: ignore[assignment] # typing of tree_map does not reflect action of `result_collection_constructor` yet

return [
*new_stmts,
itir.SetAt(
target=stmt.target,
domain=stmt.domain,
expr=new_expr
# note: the let would be removed automatically by the `cse.extract_subexpression`, but
# we remove it here for readability & debuggability.
new_expr = inline_lambdas.inline_lambda(
im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False
)
]

# TODO: transform not needed if deepest_expr_first=True
tmp_stmts.extend(
transform(
itir.SetAt(target=target_expr, domain=domain, expr=tmp_expr), declarations, uids
)
)

return [*tmp_stmts, itir.SetAt(target=stmt.target, domain=stmt.domain, expr=new_expr)]
return None

def transform(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator):
# TODO: what happens for a trivial let, e.g `let a=as_fieldop() in a end`?
unprocessed_stmts = [stmt]
stmts = []

transforms = [
def transform(
stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator
) -> list[itir.Stmt]:
unprocessed_stmts: list[itir.Stmt] = [stmt]
stmts: list[itir.Stmt] = []

transforms: list[Callable] = [
# transform functional if_ into if-stmt
transform_if,
# extract applied `as_fieldop` to top-level
functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr)),
functools.partial(
transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr)
),
# extract functional if_ to the top-level
functools.partial(transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_")),
functools.partial(
transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_")
),
]

while unprocessed_stmts:
Expand All @@ -140,23 +166,25 @@ def transform(stmt: itir.SetAt, declarations: list[itir.Temporary], uids: eve_ut

return stmts

def create_global_tmps(program: itir.Program):

def create_global_tmps(
program: itir.Program, offset_provider: common.OffsetProvider
) -> itir.Program:
program = infer_domain.infer_program(program, offset_provider)
program = type_inference.infer(program, offset_provider=offset_provider)

uids = eve_utils.UIDGenerator(prefix="__tmp")
declarations = program.declarations
declarations = program.declarations.copy()
new_body = []

for stmt in program.body:
if isinstance(stmt, (itir.SetAt, itir.IfStmt)):
new_body.extend(
transform(stmt, uids=uids, declarations=declarations)
)
else:
raise NotImplementedError()
assert isinstance(stmt, itir.SetAt)
new_body.extend(transform(stmt, uids=uids, declarations=declarations))

return itir.Program(
id=program.id,
function_definitions=program.function_definitions,
params=program.params,
declarations=declarations,
body=new_body
)
body=new_body,
)
Loading

0 comments on commit f0331cb

Please sign in to comment.