From 822520e499165466deaf1f237e4c1c972f070218 Mon Sep 17 00:00:00 2001 From: "honghua.cao" Date: Tue, 27 Oct 2020 17:39:08 +0800 Subject: [PATCH] [RELAY][OP] roi_pool operator alter layout --- python/tvm/relay/op/vision/_rcnn.py | 44 ++++++++++++++- src/relay/op/vision/rcnn_op.cc | 30 +++++++++-- .../relay/test_pass_convert_op_layout.py | 54 +++++++++++++++++++ 3 files changed, 123 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py index a5cc266f1566..359375218df4 100644 --- a/python/tvm/relay/op/vision/_rcnn.py +++ b/python/tvm/relay/op/vision/_rcnn.py @@ -69,11 +69,53 @@ def convert_roi_align(attrs, inputs, tinfos, desired_layouts): raise ValueError("Layout %s is not yet supported." % desired_data_layout) +@reg.register_convert_op_layout("vision.roi_pool") +def convert_roi_pool(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for roi_pool op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current roi_pool + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and rois inputs respectively. 
+ + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + + data, rois = inputs + new_attrs = dict(attrs) + assert ( + len(desired_layouts) == 2 + ), "A desired layout is expected for both of vision.roi_pool's inputs" + + desired_data_layout, desired_rois_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + assert desired_rois_layout == "default", "Rois layout must be default" + + new_attrs["layout"] = desired_data_layout + # the rois layout is left unchanged + if desired_data_layout in ["NCHW", "NHWC"]: + return relay.vision.roi_pool(data, rois, **new_attrs) + + raise ValueError("Layout %s is not yet supported." % desired_data_layout) + + # roi_pool @reg.register_compute("vision.roi_pool") def compute_roi_pool(attrs, inputs, _): """Compute definition of roi_pool""" - assert attrs.layout == "NCHW" + assert attrs.layout == "NCHW", "only support nchw for now" return [ topi.vision.rcnn.roi_pool_nchw( inputs[0], diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 8be38d020480..f7bbf378d09c 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -119,14 +119,35 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, ICHECK(roi_pool_attrs); ICHECK_EQ(dshape.size(), 4) << "Input data should be 4-D."; ICHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D."; - ICHECK_EQ(roi_pool_attrs->layout, "NCHW") << "ROI Pool only supports NCHW layout"; // assign output type - std::vector<IndexExpr> oshape( - {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]}); + std::vector<IndexExpr> oshape; + if (roi_pool_attrs->layout == "NCHW") { + oshape = {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]}; + } else if (roi_pool_attrs->layout == "NHWC") { + oshape = {rshape[0], roi_pool_attrs->pooled_size[0], 
roi_pool_attrs->pooled_size[1], dshape[3]}; + } else { + LOG(FATAL) << "vision.roi_pool does not support " << roi_pool_attrs->layout << " layout"; + } + reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; } +template <typename T> +Array<Array<Layout> > ROIPoolInferCorrectLayout(const Attrs& attrs, + const Array<Layout>& new_in_layouts, + const Array<Layout>& old_in_layouts, + const Array<tvm::relay::Type>& old_in_types) { + // NOTE: Discard "const" qualifier here. + T* params = const_cast<T*>(attrs.as<T>()); + Layout data_layout = params->layout; + + // Layout inference needs to define the layout for all inputs and output data layouts. + // For roi_pool, the second input is a 2-D tensor with shape [num_roi, 5]. + // So, we set the layout as "N5". + return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}}; +} + Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale, String layout) { auto attrs = make_object<ROIPoolAttrs>(); @@ -153,7 +174,8 @@ RELAY_REGISTER_OP("vision.roi_pool") .add_argument("data", "Tensor", "The input tensor.") .add_argument("rois", "Tensor", "The input rois") .set_support_level(5) - .add_type_rel("ROIPool", ROIPoolRel); + .add_type_rel("ROIPool", ROIPoolRel) + .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ROIPoolInferCorrectLayout<ROIPoolAttrs>); TVM_REGISTER_NODE_TYPE(ProposalAttrs); diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 1fc5d39b9486..7fc896a72905 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -966,6 +966,59 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_roi_pool_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var("weight1", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + rois = relay.var("rois", shape=(32, 5)) + y = 
relay.vision.roi_pool( + y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NCHW" + ) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var("weight1", shape=(64, 64, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NHWC") + weight1 = relay.layout_transform(weight1, "OIHW", "HWIO") + y = relay.nn.conv2d( + x, + weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + rois = relay.var("rois", shape=(32, 5)) + y = relay.vision.roi_pool( + y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NHWC" + ) + ret = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before() + desired_layouts = { + "nn.conv2d": ["NHWC", "HWIO"], + "vision.roi_pool": ["NHWC", "default"], + } + a = run_opt_pass(a, transform.ConvertLayout(desired_layouts)) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_default_keyword(): """ Check that the default keyword selects correct TVM default layout. """ @@ -1253,6 +1306,7 @@ def expected(): test_conv_convert_kernel_layout() test_conv_transpose_convert_layout() test_conv_roi_align_convert_layout() + test_conv_roi_pool_convert_layout() test_conv_strided_slice_convert_layout() test_default_keyword() test_different_ops_convert_layout()