diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 22636afb97e1..7c5682d462fc 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2613,7 +2613,35 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
 }
 
-// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
+InferLayoutOutput InferLayoutScatterElements(
+    const Call& call, const ffi::Map<String, ffi::Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+  LayoutDecision layout = data_layout;
+  if (NLayoutEqual()(indices_layout, updates_layout)) {
+    layout = indices_layout;
+  }
+
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  ObjectPtr<ScatterElementsAttrs> new_attrs = ffi::make_object<ScatterElementsAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis->value);
+  return InferLayoutOutput({layout, layout, layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.scatter_elements")
     .set_attrs_type<ScatterElementsAttrs>()
     .set_num_inputs(3)
@@ -2621,6 +2649,7 @@ TVM_REGISTER_OP("relax.scatter_elements")
     .add_argument("indices", "Tensor", "The indices tensor.")
     .add_argument("updates", "Tensor", "The input tensor of updates.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterElements)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.scatter_nd */
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 328fbf456e4b..3f70dce36eb4 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2443,22 +2443,22 @@ def forward(self, data, index, src):
 
     expected1 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": ""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1},
     }
     expected2 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
-            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": ""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": "AB"},
+            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
     }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 8ae96e9c07d3..26990bc44db3 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5327,5 +5327,60 @@ def main(
     verify(Input, Expected)
 
 
+def test_conv2d_scatter_elements():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 4, 26, 26), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+                gv = R.scatter_elements(data, indices, updates, axis=1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 4, 26, 26), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="int64") = R.permute_dims(
+                    indices, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_elements(
+                    data, lv2, updates, axis=3, reduction="update"
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()