Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next][dace]: Fix lowering of broadcast #1698

Merged
merged 93 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 71 commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
0272158
map-reduce for full connectivity
edopao Oct 3, 2024
e62bae9
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce
edopao Oct 3, 2024
43df0e8
Use correct map index variable
edopao Oct 3, 2024
4b35572
enable test case
edopao Oct 7, 2024
e8e37a4
add support for make_const_list
edopao Oct 7, 2024
a885de2
Misc fixes for GTIR lowering
edopao Oct 8, 2024
095e28e
add test case
edopao Oct 8, 2024
8fc95a8
minor edit
edopao Oct 8, 2024
cca4fcd
undo extra change
edopao Oct 8, 2024
1bb20cd
review comments
edopao Oct 8, 2024
be06e9f
Merge remote-tracking branch 'origin/dace-gtir-fixes' into dace-gtir-…
edopao Oct 8, 2024
049b044
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce
edopao Oct 8, 2024
4241b23
minor edit
edopao Oct 8, 2024
49f7797
minor edit (1)
edopao Oct 9, 2024
892bca4
remove extra change
edopao Oct 9, 2024
fa876c9
minor edit (2)
edopao Oct 9, 2024
ed7fd6e
fix return type of dace backend
edopao Oct 9, 2024
986e695
update test case with skip values
edopao Oct 9, 2024
f7d84ba
Improved handling of map input edges for fieldop
edopao Oct 9, 2024
c90b1cc
Improved handling of map output for fieldop
edopao Oct 9, 2024
7f247fe
review comments
edopao Oct 9, 2024
0f897de
refactoring
edopao Oct 10, 2024
397cede
refactoring
edopao Oct 10, 2024
06fd06f
refactoring (1)
edopao Oct 10, 2024
feb0951
review comments
edopao Oct 10, 2024
9c9e7f8
code-style fix
edopao Oct 10, 2024
92bc1a8
undo refactoring changes
edopao Oct 10, 2024
7b8155d
Merge remote-tracking branch 'origin/dace-gtir-refactoring' into dace…
edopao Oct 10, 2024
684e607
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce
edopao Oct 10, 2024
1265f93
improve code comments
edopao Oct 10, 2024
ee972a6
better lowering of fieldop domain
edopao Oct 11, 2024
171288b
rename field_op to fieldop
edopao Oct 11, 2024
59161c6
Merge remote-tracking branch 'origin/dace-gtir-map_reduce' into dace-…
edopao Oct 11, 2024
6103461
minor unrelated edit
edopao Oct 11, 2024
37d4d8a
working draft
edopao Oct 13, 2024
4334621
review comments
edopao Oct 14, 2024
3e88216
Merge remote-tracking branch 'origin/dace-gtir-map_reduce' into dace-…
edopao Oct 14, 2024
07b5e79
minor edit in preparation to skip_values
edopao Oct 14, 2024
d0345c9
rename mask_offset to local_offset
edopao Oct 14, 2024
f7cfb57
Merge branch 'dace-gtir-map_reduce' into dace-gtir-map_reduce_skip
edopao Oct 14, 2024
bc614fe
add type description
edopao Oct 14, 2024
b468e8a
minor code refactoring
edopao Oct 14, 2024
54f5f65
review comments
edopao Oct 15, 2024
874ae00
Merge remote-tracking branch 'origin/dace-gtir-map_reduce' into dace-…
edopao Oct 15, 2024
8b97053
add support for local dim on field arguments
edopao Oct 15, 2024
7061697
use output_connector on map tasklet
edopao Oct 16, 2024
e71c283
review comments
edopao Oct 16, 2024
083f6fe
edit code comments
edopao Oct 16, 2024
faa7025
Merge remote-tracking branch 'origin/dace-gtir-map_reduce' into dace-…
edopao Oct 16, 2024
d4ec7ef
enable testcase
edopao Oct 16, 2024
67ec6f9
Merge branch 'main' into dace-gtir-map_reduce
edopao Oct 17, 2024
2afd8ab
extract helper method _construct_local_view
edopao Oct 17, 2024
875a8fa
Merge remote-tracking branch 'origin/dace-gtir-map_reduce' into dace-…
edopao Oct 17, 2024
be7bfa4
code refactoring (use _construct_local_view)
edopao Oct 17, 2024
a426a38
make local_offset non-default in Field type
edopao Oct 17, 2024
38c6ee6
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce_…
edopao Oct 17, 2024
05dac84
edit code comments
edopao Oct 17, 2024
c8bb690
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce_…
edopao Oct 18, 2024
05250c8
Fix lowering of nested let-statements
edopao Oct 21, 2024
56d639d
Add lowering of zero-dimensional fields
edopao Oct 21, 2024
96efe23
Add test case for zero-dimensional fields
edopao Oct 21, 2024
55e394b
Review comments
edopao Oct 22, 2024
4afea74
remove extra file
edopao Oct 22, 2024
3a01738
Switch type names DataExpr <-> ValueExpr
edopao Oct 22, 2024
6fc64b1
Change signature of get_map_variable
edopao Oct 22, 2024
35a51a0
Change assert into exception
edopao Oct 22, 2024
b3ce505
cleanup prev commit
edopao Oct 22, 2024
13688b3
review comments (1)
edopao Oct 23, 2024
e5f3c48
Merge remote-tracking branch 'origin/main' into dace-gtir-map_reduce_…
edopao Oct 23, 2024
8829684
Merge remote-tracking branch 'origin/dace-gtir-map_reduce_skip' into …
edopao Oct 23, 2024
947e8fc
Merge remote-tracking branch 'origin/dace-gtir-map_reduce_skip' into …
edopao Oct 23, 2024
4894831
Merge remote-tracking branch 'origin/main' into dece-gtir-zerodim_fields
edopao Oct 23, 2024
4449f5e
improve parsing of zero-dim args
edopao Oct 23, 2024
e64f06a
Merge remote-tracking branch 'origin/main' into dace-gtir-fix_let_stmts
edopao Oct 23, 2024
f24778b
fix previous PR
edopao Oct 23, 2024
ee60e16
fix pre-commit (1)
edopao Oct 23, 2024
e79152a
fix pre-commit (2)
edopao Oct 24, 2024
38e23e2
Avoid bool cast in branch condition expressions
edopao Oct 24, 2024
1c9d193
Revert "improve parsing of zero-dim args"
edopao Oct 24, 2024
ddc9604
fix for broadcast expressions
edopao Oct 24, 2024
c514765
Merge remote-tracking branch 'origin/main' into dace-gtir-fix_let_stmts
edopao Oct 24, 2024
444bb05
Review comments
edopao Oct 24, 2024
b656256
Merge remote-tracking branch 'origin/main' into dece-gtir-zerodim_fields
edopao Oct 24, 2024
274736e
Revert "Revert "improve parsing of zero-dim args""
edopao Oct 24, 2024
ccf9406
fix broadcast of fields
edopao Oct 24, 2024
1aa4a72
fix previous commit
edopao Oct 24, 2024
0482c86
Merge remote-tracking branch 'origin/main' into dace-gtir-fix_let_stmts
edopao Oct 25, 2024
007f2ac
edit code comments
edopao Oct 25, 2024
b24b3f8
edit code comments
edopao Oct 25, 2024
267fbed
Merge remote-tracking branch 'origin/dace-gtir-fix_let_stmts' into de…
edopao Oct 25, 2024
19eb3c2
Review comments
edopao Oct 25, 2024
a532d68
Merge remote-tracking branch 'origin/main' into dece-gtir-zerodim_fields
edopao Oct 25, 2024
496f8f5
avoid type conversion when accessing 0-dim ndarray
edopao Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any:
if not isinstance(arg, gtx_common.Field):
return arg
if len(arg.domain.dims) == 0:
# pass zero-dimensional fields as scalars
return arg.asnumpy().item()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@egparedes What was the correct pattern for avoiding type conversions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type conversion is not affecting GT4Py. The GT4Py argument has still type ts.FieldType with empty dims list. This type conversion (zero-dim array to scalar) only affects the argument passed to the SDFG call, because in the SDFG I am representing the zero-dim array as a dace scalar object.

Copy link
Contributor

@egparedes egparedes Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't reviewed the PR so I don't have a lot of context here, but I think Hannes' comment is about extracting the scalar value from the 0d numpy array without changing its type. ndarray.item() (https://devdocs.io/numpy~2.0/reference/generated/numpy.ndarray.item) always transforms the numpy scalar to a python scalar, which may change its precision. To avoid this, you should use an empty tuple as index for .__getitem__() like this.

Suggested change
return arg.asnumpy().item()
return arg.asnumpy()[()]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, thank you.

# field domain offsets are not supported
non_zero_offsets = [
(dim, dim_range)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,26 @@ def translate_broadcast_scalar(
domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])

assert len(node.args) == 1
assert isinstance(node.args[0].type, ts.ScalarType)
scalar_expr = _parse_fieldop_arg(
node.args[0], sdfg, state, sdfg_builder, domain, reduce_identity=None
)
assert isinstance(scalar_expr, gtir_dataflow.MemletExpr)
assert scalar_expr.subset == sbs.Indices.from_string("0")
result = gtir_dataflow.DataflowOutputEdge(
state, gtir_dataflow.ValueExpr(scalar_expr.dc_node, node.args[0].type)
)

if isinstance(node.args[0].type, ts.ScalarType):
assert isinstance(scalar_expr, gtir_dataflow.MemletExpr)
assert scalar_expr.subset == sbs.Indices.from_string("0")
dc_node = scalar_expr.dc_node
gt_type = node.args[0].type
elif isinstance(node.args[0].type, ts.FieldType):
# zero-dimensional field
assert len(node.args[0].type.dims) == 0
assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr)
dc_node = scalar_expr.field
gt_type = node.args[0].type.dtype
else:
raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.")

assert isinstance(dc_node.desc(sdfg), dace.data.Scalar)
result = gtir_dataflow.DataflowOutputEdge(state, gtir_dataflow.ValueExpr(dc_node, gt_type))
result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result)

sdfg_builder.add_mapped_tasklet(
Expand All @@ -380,10 +391,10 @@ def translate_broadcast_scalar(
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
inputs={"__inp": dace.Memlet(data=scalar_expr.dc_node.data, subset="0")},
inputs={"__inp": dace.Memlet(data=dc_node.data, subset="0")},
code="__val = __inp",
outputs={"__val": dace.Memlet(data=result_field.dc_node.data, subset=domain_indices)},
input_nodes={scalar_expr.dc_node.data: scalar_expr.dc_node},
input_nodes={dc_node.data: dc_node},
output_nodes={result_field.dc_node.data: result_field.dc_node},
external_edges=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,82 +425,86 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
assert len(node.args) == 1
arg_expr = self.visit(node.args[0])

if isinstance(arg_expr, IteratorExpr):
field_desc = arg_expr.field.desc(self.sdfg)
assert len(field_desc.shape) == len(arg_expr.dimensions)
if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
# when all indices are symblic expressions, we can perform direct field access through a memlet
field_subset = sbs.Range(
(arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr]
if dim in arg_expr.indices
else (0, size - 1, 1)
for dim, size in zip(arg_expr.dimensions, field_desc.shape)
)
return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)

else:
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
if not isinstance(index, SymbolExpr)
]
# here `internals` refer to the names used as index in the tasklet code string:
# an index can be either a connector name (for dynamic/indirect indices)
# or a symbol value (for literal values and scalar arguments).
index_internals = ",".join(
str(index.value)
if isinstance(index, SymbolExpr)
else IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
)
deref_node = self._add_tasklet(
"runtime_deref",
{"field"} | set(index_connectors),
{"val"},
code=f"val = field[{index_internals}]",
)
# add new termination point for the field parameter
self._add_input_data_edge(
arg_expr.field,
sbs.Range.from_array(field_desc),
deref_node,
"field",
)
if not isinstance(arg_expr, IteratorExpr):
# dereferencing a scalar or a literal node results in the node itself
return arg_expr

for dim, index_expr in field_indices:
# add termination points for the dynamic iterator indices
deref_connector = IndexConnectorFmt.format(dim=dim.value)
if isinstance(index_expr, MemletExpr):
self._add_input_data_edge(
index_expr.dc_node,
index_expr.subset,
deref_node,
deref_connector,
)

elif isinstance(index_expr, ValueExpr):
self._add_edge(
index_expr.dc_node,
None,
deref_node,
deref_connector,
dace.Memlet(data=index_expr.dc_node.data, subset="0"),
)
else:
assert isinstance(index_expr, SymbolExpr)

dc_dtype = arg_expr.field.desc(self.sdfg).dtype
return self._construct_tasklet_result(
dc_dtype, deref_node, "val", arg_expr.local_offset
)
field_desc = arg_expr.field.desc(self.sdfg)
if isinstance(field_desc, dace.data.Scalar):
# deref a zero-dimensional field
assert len(arg_expr.dimensions) == 0
assert isinstance(node.type, ts.ScalarType)
return MemletExpr(arg_expr.field, subset="0")
# default case: deref a field with one or more dimensions
assert len(field_desc.shape) == len(arg_expr.dimensions)
if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
# when all indices are symblic expressions, we can perform direct field access through a memlet
field_subset = sbs.Range(
(arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr]
if dim in arg_expr.indices
else (0, size - 1, 1)
for dim, size in zip(arg_expr.dimensions, field_desc.shape)
)
return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)

else:
# dereferencing a scalar or a literal node results in the node itself
return arg_expr
# we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
# either indirection through connectivity table or dynamic cartesian offset.
assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
index_connectors = [
IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
if not isinstance(index, SymbolExpr)
]
# here `internals` refer to the names used as index in the tasklet code string:
# an index can be either a connector name (for dynamic/indirect indices)
# or a symbol value (for literal values and scalar arguments).
index_internals = ",".join(
str(index.value)
if isinstance(index, SymbolExpr)
else IndexConnectorFmt.format(dim=dim.value)
for dim, index in field_indices
)
deref_node = self._add_tasklet(
"runtime_deref",
{"field"} | set(index_connectors),
{"val"},
code=f"val = field[{index_internals}]",
)
# add new termination point for the field parameter
self._add_input_data_edge(
arg_expr.field,
sbs.Range.from_array(field_desc),
deref_node,
"field",
)

for dim, index_expr in field_indices:
# add termination points for the dynamic iterator indices
deref_connector = IndexConnectorFmt.format(dim=dim.value)
if isinstance(index_expr, MemletExpr):
self._add_input_data_edge(
index_expr.dc_node,
index_expr.subset,
deref_node,
deref_connector,
)

elif isinstance(index_expr, ValueExpr):
self._add_edge(
index_expr.dc_node,
None,
deref_node,
deref_connector,
dace.Memlet(data=index_expr.dc_node.data, subset="0"),
)
else:
assert isinstance(index_expr, SymbolExpr)

return self._construct_tasklet_result(
field_desc.dtype, deref_node, "val", arg_expr.local_offset
)

def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
assert len(node.args) == 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def _add_storage(
return tuple_fields

elif isinstance(gt_type, ts.FieldType):
if len(gt_type.dims) == 0:
# represent zero-dimensional fields as scalar arguments
return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient)
# handle default case: field with one or more dimensions
dc_dtype = dace_utils.as_dace_type(gt_type.dtype)
# use symbolic shape, which allows to invoke the program with fields of different size;
# and symbolic strides, which enables decoupling the memory layout from generated code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def testee(a: tuple[int32, tuple[int32, int32]]) -> cases.VField:


@pytest.mark.uses_tuple_args
@pytest.mark.uses_zero_dimensional_fields
def test_zero_dim_tuple_arg(unstructured_case):
@gtx.field_operator
def testee(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,35 @@ def test_gtir_tuple_broadcast_scalar():
assert np.allclose(d, a + 2 * b + 3 * c)


def test_gtir_zero_dim_fields():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
testee = gtir.Program(
id="gtir_zero_dim_fields",
function_definitions=[],
params=[
gtir.Sym(id="x", type=ts.FieldType(dims=[], dtype=IFTYPE.dtype)),
gtir.Sym(id="y", type=IFTYPE),
gtir.Sym(id="size", type=SIZE_TYPE),
],
declarations=[],
body=[
gtir.SetAt(
expr=im.as_fieldop("deref", domain)("x"),
domain=domain,
target=gtir.SymRef(id="y"),
)
],
)

a = np.asarray(np.random.rand())
b = np.empty(N)

sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)

sdfg(a.item(), b, **FSYMBOLS)
assert np.allclose(a, b)


def test_gtir_tuple_return():
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
testee = gtir.Program(
Expand Down
Loading