Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Lowering to SDFG of index builtin #1743

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6b4957c
wip - draft but one gtir_dace test fails
edopao Nov 5, 2024
a2db9af
Merge remote-tracking branch 'origin/main' into dace-gtir-casts
edopao Nov 5, 2024
c66214e
fix - working solution
edopao Nov 6, 2024
6eeaf68
Merge remote-tracking branch 'origin/main' into dace-gtir-casts
edopao Nov 6, 2024
9fd01f8
fix previous commit
edopao Nov 6, 2024
3bee0d5
Merge remote-tracking branch 'origin/main' into dace-gtir-casts
edopao Nov 8, 2024
e932293
undo changes
edopao Nov 8, 2024
27434c3
apply pruning while lowering to SDFG
edopao Nov 8, 2024
893ead7
minor edit
edopao Nov 8, 2024
b02207e
fix docstring
edopao Nov 8, 2024
71859db
re-implement IR pass to include cast as_fieldop expressions
edopao Nov 11, 2024
370b090
treat cast_ as op_as_fieldop
edopao Nov 13, 2024
1356503
initial draft
edopao Nov 14, 2024
b202f51
Merge remote-tracking branch 'origin/main' into dace-gtir-index
edopao Nov 15, 2024
842d1e6
working draft
edopao Nov 15, 2024
cf41fb2
Merge remote-tracking branch 'origin/main' into dace-gtir-index
edopao Nov 18, 2024
caec101
Merge branch 'dace-gtir-casts' into dace-gtir-index
edopao Nov 18, 2024
b42a9fc
Merge remote-tracking branch 'origin/main' into dace-gtir-casts
edopao Nov 18, 2024
c9451c4
Merge remote-tracking branch 'origin/main' into dace-gtir-index
edopao Nov 19, 2024
2ed3b91
Merge remote-tracking branch 'origin/main' into dace-gtir-casts
edopao Nov 19, 2024
48466a2
working draft
edopao Nov 19, 2024
3ef17d8
review comments
edopao Nov 19, 2024
0127fa5
Merge remote-tracking branch 'origin/dace-gtir-casts' into dace-gtir-…
edopao Nov 19, 2024
4742eac
minor edit
edopao Nov 20, 2024
d7b0e78
minor edit
edopao Nov 20, 2024
4b2f00f
code refactoring
edopao Nov 20, 2024
65f4bc6
remove extra change
edopao Nov 20, 2024
c4b4e08
Merge remote-tracking branch 'origin/dace-gtir-casts' into dace-gtir-…
edopao Nov 20, 2024
abdf25a
Merge remote-tracking branch 'origin/main' into dace-gtir-index
edopao Nov 20, 2024
59ec5e7
edit comment
edopao Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,14 @@ def _impl(it: itir.Expr) -> itir.FunCall:
return _impl


def index(dim: common.Dimension) -> itir.FunCall:
"""
Create a call to the `index` builtin on a given dimension, shorthand for `call("index")(axis)`,
after converting the dimension to `itir.AxisLiteral`.
"""
return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind))


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand Down Expand Up @@ -224,20 +224,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain:
the corresponding lower and upper bounds. The returned lower bound is inclusive,
the upper bound is exclusive: [lower_bound, upper_bound[
"""
assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))

domain = []
for named_range in node.args:
assert cpm.is_call_to(named_range, "named_range")
assert len(named_range.args) == 3
axis = named_range.args[0]
assert isinstance(axis, gtir.AxisLiteral)
lower_bound, upper_bound = (
dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg))
for arg in named_range.args[1:3]
)
dim = gtx_common.Dimension(axis.value, axis.kind)
domain.append((dim, lower_bound, upper_bound))

def parse_range_boundary(expr: gtir.Expr) -> str:
return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr))

if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
for named_range in node.args:
assert cpm.is_call_to(named_range, "named_range")
assert len(named_range.args) == 3
axis = named_range.args[0]
assert isinstance(axis, gtir.AxisLiteral)
lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3])
dim = gtx_common.Dimension(axis.value, axis.kind)
domain.append((dim, lower_bound, upper_bound))

elif isinstance(node, domain_utils.SymbolicDomain):
assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"}
for dim, drange in node.ranges.items():
domain.append(
(dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop))
)

else:
raise ValueError(f"Invalid domain {node}.")

return domain

Expand Down Expand Up @@ -504,6 +515,45 @@ def construct_output(inner_data: FieldopData) -> FieldopData:
return result_temps


def translate_index(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_sdfg.SDFGBuilder,
) -> FieldopResult:
"""
Lowers the `index` builtin function to a mapped tasklet that writes the dimension
index values to a transient array. The extent of the index range is taken from
the domain information that should be present in the node annex.
"""
assert "domain" in node.annex
domain = extract_domain(node.annex.domain)
assert len(domain) == 1
dim, lower_bound, upper_bound = domain[0]
dim_index = dace_gtir_utils.get_map_variable(dim)

field_dims, field_shape = _get_field_shape(domain)

output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
"index",
state,
map_ranges={
dim_index: f"{lower_bound}:{upper_bound}",
},
inputs={},
code=f"__val = {dim_index}",
outputs={"__val": dace.Memlet(data=output_node.data, subset=dim_index)},
input_nodes={},
output_nodes={output_node.data: output_node},
external_edges=True,
)

return FieldopData(output_node, ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)))


def _get_data_nodes(
sdfg: dace.SDFG,
state: dace.SDFGState,
Expand Down Expand Up @@ -736,6 +786,7 @@ def translate_symbol_ref(
translate_as_fieldop,
translate_broadcast_scalar,
translate_if,
translate_index,
translate_literal,
translate_make_tuple,
translate_tuple_get,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils
from gt4py.next.iterator.transforms import (
infer_domain,
prune_casts as ir_prune_casts,
symbol_ref_utils,
)
from gt4py.next.iterator.type_system import inference as gtir_type_inference
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand Down Expand Up @@ -526,6 +530,8 @@ def visit_FunCall(
# use specialized dataflow builder classes for each builtin function
if cpm.is_call_to(node, "if_"):
return gtir_builtin_translators.translate_if(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "index"):
return gtir_builtin_translators.translate_index(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "make_tuple"):
return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "tuple_get"):
Expand Down Expand Up @@ -794,6 +800,14 @@ def build_sdfg_from_gtir(
An SDFG in the DaCe canonical form (simplified)
"""

if (
eve.walk_values(ir)
.filter(lambda node: cpm.is_call_to(node, "index") and "domain" not in node.annex)
.to_set()
):
# We need to rerun domain inference because the annex information is not preserved by IR passes,
# and the domain information is stored in annex for the nodes implementing the index builtin.
ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
ir = gtir_type_inference.infer(ir, offset_provider=offset_provider)
ir = ir_prune_casts.PruneCasts().visit(ir)
sdfg_genenerator = GTIRToSDFG(offset_provider)
Expand Down
1 change: 0 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
]
GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
(USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1886,3 +1886,39 @@ def test_gtir_if_values():

sdfg(a, b, c, **FSYMBOLS)
assert np.allclose(c, np.where(a < b, a, b))


def test_gtir_index():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, im.minus("size", 1))})

testee = gtir.Program(
id="gtir_cast",
function_definitions=[],
params=[
gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)),
gtir.Sym(id="size", type=SIZE_TYPE),
],
declarations=[],
body=[
gtir.SetAt(
expr=im.let("i", im.index(IDim))(
im.op_as_fieldop("plus", domain)(
"i",
im.as_fieldop(
im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain
)("i"),
)
),
domain=subdomain,
target=gtir.SymRef(id="x"),
)
],
)

v = np.empty(N, dtype=np.int32)

sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)

sdfg(v, **FSYMBOLS)
np.allclose(v[: N - 1], np.arange(N - 1, dtype=np.int32) + np.arange(1, N, dtype=np.int32))
Loading