diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py
index 6039c82fdb..5d3cc7a358 100644
--- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py
+++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py
@@ -26,6 +26,13 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any:
     if not isinstance(arg, gtx_common.Field):
         return arg
+    if len(arg.domain.dims) == 0:
+        # Pass zero-dimensional fields as scalars.
+        # We need to extract the scalar value from the 0d numpy array without changing its type.
+        # Note that 'ndarray.item()' always converts the numpy scalar to a Python scalar,
+        # which may change its precision. To avoid this, we use the empty tuple as the index
+        # for 'ndarray.__getitem__()'.
+        return arg.ndarray[()]
     # field domain offsets are not supported
     non_zero_offsets = [
         (dim, dim_range)
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
index b60c86b349..bb37440fe2 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py
@@ -138,6 +138,32 @@ def _parse_fieldop_arg(
     raise NotImplementedError(f"Node type {type(arg.gt_dtype)} not supported.")
 
 
+def _get_field_shape(
+    domain: FieldopDomain,
+) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]:
+    """
+    Parse the field operator domain and generate the shape of the result field.
+
+    It should be enough to allocate an array with shape (upper_bound - lower_bound),
+    but this would require using an array offset to compensate for the start index.
+    Suppose that a field operator executes on domain [2,N-2]: the dace array to store
+    the result only needs size (N-4), but this would require compensating all array
+    accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose
+    to allocate (N-2), leaving positions [0:2] unused. The reason is that array offsets
+    are known to cause issues with SDFG inlining. Besides, map fusion will in any case
+    eliminate most of the transient arrays.
+
+    Args:
+        domain: The field operator domain.
+
+    Returns:
+        A tuple of two lists: the list of field dimensions and the list of dace
+        array sizes in each dimension.
+    """
+    domain_dims, _, domain_ubs = zip(*domain)
+    return list(domain_dims), list(domain_ubs)
+
+
 def _create_temporary_field(
     sdfg: dace.SDFG,
     state: dace.SDFGState,
@@ -146,17 +172,7 @@
     dataflow_output: gtir_dataflow.DataflowOutputEdge,
 ) -> FieldopData:
     """Helper method to allocate a temporary field where to write the output of a field operator."""
-    domain_dims, _, domain_ubs = zip(*domain)
-    field_dims = list(domain_dims)
-    # It should be enough to allocate an array with shape (upper_bound - lower_bound)
-    # but this would require to use array offset for compensate for the start index.
-    # Suppose that a field operator executes on domain [2,N-2], the dace array to store
-    # the result only needs size (N-4), but this would require to compensate all array
-    # accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose
-    # to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset
-    # is known to cause issues to SDFG inlining. Besides, map fusion will in any case
-    # eliminate most of transient arrays.
-    field_shape = list(domain_ubs)
+    field_dims, field_shape = _get_field_shape(domain)
 
     output_desc = dataflow_output.result.dc_node.desc(sdfg)
     if isinstance(output_desc, dace.data.Array):
@@ -311,17 +327,46 @@ def translate_broadcast_scalar(
 
     assert cpm.is_ref_to(stencil_expr, "deref")
     domain = extract_domain(domain_expr)
-    domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain])
+    field_dims, field_shape = _get_field_shape(domain)
+    field_subset = sbs.Range.from_string(
+        ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims)
+    )
 
     assert len(node.args) == 1
-    assert isinstance(node.args[0].type, ts.ScalarType)
     scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain)
-    assert isinstance(scalar_expr, gtir_dataflow.MemletExpr)
-    assert scalar_expr.subset == sbs.Indices.from_string("0")
-    result = gtir_dataflow.DataflowOutputEdge(
-        state, gtir_dataflow.ValueExpr(scalar_expr.dc_node, node.args[0].type)
-    )
-    result_field = _create_temporary_field(sdfg, state, domain, node.type, dataflow_output=result)
+
+    if isinstance(node.args[0].type, ts.ScalarType):
+        assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr))
+        input_subset = (
+            str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0"
+        )
+        input_node = scalar_expr.dc_node
+        gt_dtype = node.args[0].type
+    elif isinstance(node.args[0].type, ts.FieldType):
+        assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr)
+        if len(node.args[0].type.dims) == 0:  # zero-dimensional field
+            input_subset = "0"
+        elif all(
+            isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr)
+            for dim in scalar_expr.dimensions
+            if dim not in field_dims
+        ):
+            input_subset = ",".join(
+                dace_gtir_utils.get_map_variable(dim)
+                if dim in field_dims
+                else scalar_expr.indices[dim].value  # type: ignore[union-attr] # caught by exception above
+                for dim in scalar_expr.dimensions
+            )
+        else:
+            raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.")
+
+        input_node = scalar_expr.field
+        gt_dtype = node.args[0].type.dtype
+    else:
+        raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.")
+
+    output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype)
+    output_node = state.add_access(output)
 
     sdfg_builder.add_mapped_tasklet(
         "broadcast",
@@ -330,15 +375,15 @@ def translate_broadcast_scalar(
             dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
             for dim, lower_bound, upper_bound in domain
         },
-        inputs={"__inp": dace.Memlet(data=scalar_expr.dc_node.data, subset="0")},
+        inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)},
         code="__val = __inp",
-        outputs={"__val": dace.Memlet(data=result_field.dc_node.data, subset=domain_indices)},
-        input_nodes={scalar_expr.dc_node.data: scalar_expr.dc_node},
-        output_nodes={result_field.dc_node.data: result_field.dc_node},
+        outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)},
+        input_nodes={input_node.data: input_node},
+        output_nodes={output_node.data: output_node},
         external_edges=True,
     )
-    return result_field
+
+    return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype), local_offset=None)
 
 
 def translate_if(
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py
index 416321c038..cf91d15aba 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py
@@ -426,82 +426,86 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr:
         assert len(node.args) == 1
         arg_expr = self.visit(node.args[0])
 
-        if isinstance(arg_expr, IteratorExpr):
-            field_desc = arg_expr.field.desc(self.sdfg)
-            assert len(field_desc.shape) == len(arg_expr.dimensions)
-            if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
-                # when all indices are symblic expressions, we can perform direct field access through a memlet
-                field_subset = sbs.Range(
-                    (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1)  # type: ignore[union-attr]
-                    if dim in arg_expr.indices
-                    else (0, size - 1, 1)
-                    for dim, size in zip(arg_expr.dimensions, field_desc.shape)
-                )
-                return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)
-
-            else:
-                # we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
-                # either indirection through connectivity table or dynamic cartesian offset.
-                assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
-                field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
-                index_connectors = [
-                    IndexConnectorFmt.format(dim=dim.value)
-                    for dim, index in field_indices
-                    if not isinstance(index, SymbolExpr)
-                ]
-                # here `internals` refer to the names used as index in the tasklet code string:
-                # an index can be either a connector name (for dynamic/indirect indices)
-                # or a symbol value (for literal values and scalar arguments).
-                index_internals = ",".join(
-                    str(index.value)
-                    if isinstance(index, SymbolExpr)
-                    else IndexConnectorFmt.format(dim=dim.value)
-                    for dim, index in field_indices
-                )
-                deref_node = self._add_tasklet(
-                    "runtime_deref",
-                    {"field"} | set(index_connectors),
-                    {"val"},
-                    code=f"val = field[{index_internals}]",
-                )
-                # add new termination point for the field parameter
-                self._add_input_data_edge(
-                    arg_expr.field,
-                    sbs.Range.from_array(field_desc),
-                    deref_node,
-                    "field",
-                )
+        if not isinstance(arg_expr, IteratorExpr):
+            # dereferencing a scalar or a literal node results in the node itself
+            return arg_expr
 
-                for dim, index_expr in field_indices:
-                    # add termination points for the dynamic iterator indices
-                    deref_connector = IndexConnectorFmt.format(dim=dim.value)
-                    if isinstance(index_expr, MemletExpr):
-                        self._add_input_data_edge(
-                            index_expr.dc_node,
-                            index_expr.subset,
-                            deref_node,
-                            deref_connector,
-                        )
-
-                    elif isinstance(index_expr, ValueExpr):
-                        self._add_edge(
-                            index_expr.dc_node,
-                            None,
-                            deref_node,
-                            deref_connector,
-                            dace.Memlet(data=index_expr.dc_node.data, subset="0"),
-                        )
-                    else:
-                        assert isinstance(index_expr, SymbolExpr)
-
-                dc_dtype = arg_expr.field.desc(self.sdfg).dtype
-                return self._construct_tasklet_result(
-                    dc_dtype, deref_node, "val", arg_expr.local_offset
-                )
+        field_desc = arg_expr.field.desc(self.sdfg)
+        if isinstance(field_desc, dace.data.Scalar):
+            # deref a zero-dimensional field
+            assert len(arg_expr.dimensions) == 0
+            assert isinstance(node.type, ts.ScalarType)
+            return MemletExpr(arg_expr.field, subset="0")
+        # default case: deref a field with one or more dimensions
+        assert len(field_desc.shape) == len(arg_expr.dimensions)
+        if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()):
+            # when all indices are symbolic expressions, we can perform direct field access through a memlet
+            field_subset = sbs.Range(
+                (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1)  # type: ignore[union-attr]
+                if dim in arg_expr.indices
+                else (0, size - 1, 1)
+                for dim, size in zip(arg_expr.dimensions, field_desc.shape)
+            )
+            return MemletExpr(arg_expr.field, field_subset, arg_expr.local_offset)
 
         else:
-            # dereferencing a scalar or a literal node results in the node itself
-            return arg_expr
+            # we use a tasklet to dereference an iterator when one or more indices are the result of some computation,
+            # either indirection through connectivity table or dynamic cartesian offset.
+            assert all(dim in arg_expr.indices for dim in arg_expr.dimensions)
+            field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions]
+            index_connectors = [
+                IndexConnectorFmt.format(dim=dim.value)
+                for dim, index in field_indices
+                if not isinstance(index, SymbolExpr)
+            ]
+            # here `internals` refer to the names used as index in the tasklet code string:
+            # an index can be either a connector name (for dynamic/indirect indices)
+            # or a symbol value (for literal values and scalar arguments).
+            index_internals = ",".join(
+                str(index.value)
+                if isinstance(index, SymbolExpr)
+                else IndexConnectorFmt.format(dim=dim.value)
+                for dim, index in field_indices
+            )
+            deref_node = self._add_tasklet(
+                "runtime_deref",
+                {"field"} | set(index_connectors),
+                {"val"},
+                code=f"val = field[{index_internals}]",
+            )
+            # add new termination point for the field parameter
+            self._add_input_data_edge(
+                arg_expr.field,
+                sbs.Range.from_array(field_desc),
+                deref_node,
+                "field",
+            )
+
+            for dim, index_expr in field_indices:
+                # add termination points for the dynamic iterator indices
+                deref_connector = IndexConnectorFmt.format(dim=dim.value)
+                if isinstance(index_expr, MemletExpr):
+                    self._add_input_data_edge(
+                        index_expr.dc_node,
+                        index_expr.subset,
+                        deref_node,
+                        deref_connector,
+                    )
+
+                elif isinstance(index_expr, ValueExpr):
+                    self._add_edge(
+                        index_expr.dc_node,
+                        None,
+                        deref_node,
+                        deref_connector,
+                        dace.Memlet(data=index_expr.dc_node.data, subset="0"),
+                    )
+                else:
+                    assert isinstance(index_expr, SymbolExpr)
+
+            return self._construct_tasklet_result(
+                field_desc.dtype, deref_node, "val", arg_expr.local_offset
+            )
 
     def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
         assert len(node.args) == 2
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py
index 28eef5c260..e489f130db 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py
@@ -255,6 +255,10 @@ def _add_storage(
             return tuple_fields
 
         elif isinstance(gt_type, ts.FieldType):
+            if len(gt_type.dims) == 0:
+                # represent zero-dimensional fields as scalar arguments
+                return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient)
+            # handle default case: field with one or more dimensions
             dc_dtype = dace_utils.as_dace_type(gt_type.dtype)
             # use symbolic shape, which allows to invoke the program with fields of different size;
             # and symbolic strides, which enables decoupling the memory layout from generated code.
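# Illustrative sketch (plain Python, not gt4py/dace code): the sizing rule documented in the
# new '_get_field_shape' helper above. For a domain entry (dim, lower_bound, upper_bound) the
# temporary array is sized with the upper bound rather than (upper - lower), so array accesses
# need no index offset. The domain tuples and names below are made-up stand-ins.
domain = [("IDim", 2, "N - 2"), ("JDim", 0, "M")]

dims, _, upper_bounds = zip(*domain)  # same unpacking as in _get_field_shape
field_dims, field_shape = list(dims), list(upper_bounds)

assert field_dims == ["IDim", "JDim"]
assert field_shape == ["N - 2", "M"]  # positions [0:2] along IDim are simply left unused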
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 36d6debf9d..7540d52fb3 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -219,6 +219,7 @@ def testee(a: tuple[int32, tuple[int32, int32]]) -> cases.VField:
 
 
 @pytest.mark.uses_tuple_args
+@pytest.mark.uses_zero_dimensional_fields
 def test_zero_dim_tuple_arg(unstructured_case):
     @gtx.field_operator
     def testee(
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
index 5377654b55..41f540d3cf 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py
@@ -372,6 +372,35 @@ def test_gtir_tuple_broadcast_scalar():
     assert np.allclose(d, a + 2 * b + 3 * c)
 
 
+def test_gtir_zero_dim_fields():
+    domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
+    testee = gtir.Program(
+        id="gtir_zero_dim_fields",
+        function_definitions=[],
+        params=[
+            gtir.Sym(id="x", type=ts.FieldType(dims=[], dtype=IFTYPE.dtype)),
+            gtir.Sym(id="y", type=IFTYPE),
+            gtir.Sym(id="size", type=SIZE_TYPE),
+        ],
+        declarations=[],
+        body=[
+            gtir.SetAt(
+                expr=im.as_fieldop("deref", domain)("x"),
+                domain=domain,
+                target=gtir.SymRef(id="y"),
+            )
+        ],
+    )
+
+    a = np.asarray(np.random.rand())
+    b = np.empty(N)
+
+    sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)
+
+    sdfg(a.item(), b, **FSYMBOLS)
+    assert np.allclose(a, b)
+
+
 def test_gtir_tuple_return():
     domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
     testee = gtir.Program(
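# Illustrative sketch (plain numpy, not gt4py code) of the scalar-extraction behaviour that the
# '_convert_arg' change and the new test above rely on: indexing a 0-d array with the empty tuple
# keeps the numpy dtype, while 'ndarray.item()' returns a Python scalar and may change precision.
# The values and variable names below are made up.
import numpy as np

a = np.asarray(np.float32(1.5))  # 0-d array, like the buffer backing a zero-dimensional field
assert a.shape == () and a.dtype == np.float32

v = a[()]  # numpy scalar, dtype preserved
assert isinstance(v, np.float32)

w = a.item()  # Python float, i.e. double precision
assert isinstance(w, float) and not isinstance(w, np.float32)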