diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 4ac7affb0c86..22636afb97e1 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2047,11 +2047,44 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) {
   return data_sinfo;
 }
 
+InferLayoutOutput InferLayoutFlip(
+    const Call& call, const ffi::Map<String, ffi::Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<FlipAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  int axis = attrs->axis.IntValue();
+  if (axis < 0) {
+    axis += ndim;
+  }
+
+  const int new_axis = FindAxis(existing_layout->layout, axis);
+  ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
+
+  ObjectPtr<FlipAttrs> new_attrs = ffi::make_object<FlipAttrs>(*attrs);
+  new_attrs->axis = Integer(new_axis);
+
+  return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.flip")
     .set_attrs_type<FlipAttrs>()
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.gather_elements */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 5ba0c4d86771..8ae96e9c07d3 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5283,5 +5283,49 @@ def main(
     verify(Input, Expected)
 
 
+def test_conv2d_flip():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.flip(gv, axis=1)
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.flip(gv, axis=3)
+                gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()