Skip to content

Commit

Permalink
[RELAY][OP] roi_pool operator alter layout
Browse files Browse the repository at this point in the history
  • Loading branch information
honghua.cao committed Oct 15, 2020
1 parent 98c2096 commit eff8210
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 5 deletions.
44 changes: 43 additions & 1 deletion python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,53 @@ def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
raise ValueError("Layout %s is not yet supported." % desired_data_layout)


@reg.register_convert_op_layout("vision.roi_pool")
def convert_roi_pool(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for roi_pool op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current roi_pool
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of layout strings
List of layouts defining our desired
layout for the data and rois inputs respectively.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

data, rois = inputs
new_attrs = dict(attrs)
assert (
len(desired_layouts) == 2
), "A desired layout is expected for both of vision.roi_pool's inputs"

desired_data_layout, desired_rois_layout = map(str, desired_layouts)
assert desired_data_layout != "default", "Data layout cannot be default"
assert desired_rois_layout == "default", "Rois layout must be default"

new_attrs["layout"] = desired_data_layout
# rois layout not change
if desired_data_layout in ["NCHW", "NHWC"]:
return relay.vision.roi_pool(data, rois, **new_attrs)

raise ValueError("Layout %s is not yet supported." % desired_data_layout)


# roi_pool
@reg.register_compute("vision.roi_pool")
def compute_roi_pool(attrs, inputs, _):
"""Compute definition of roi_pool"""
assert attrs.layout == "NCHW"
assert attrs.layout == "NCHW", "only support nchw for now"
return [
topi.vision.rcnn.roi_pool_nchw(
inputs[0],
Expand Down
30 changes: 26 additions & 4 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,35 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK(roi_pool_attrs);
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
CHECK_EQ(roi_pool_attrs->layout, "NCHW") << "ROI Pool only supports NCHW layout";
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
std::vector<IndexExpr> oshape;
if (roi_pool_attrs->layout == "NCHW") {
oshape = {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]};
} else if (roi_pool_attrs->layout == "NHWC") {
oshape = {rshape[0], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1], dshape[3]};
} else {
LOG(FATAL) << "vision.roi_pool does not support " << roi_pool_attrs->layout << " layout";
}

reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

template <typename T>
Array<Array<Layout> > ROIPoolInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
Layout data_layout = params->layout;

// Layout inference needs to define the layout for all inputs and output data layouts.
// For roi_pool, the second inputs is 2-D tensor with shape [num_roi, 5].
// So, we set the layout as "N5".
return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}};
}

Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
String layout) {
auto attrs = make_object<ROIPoolAttrs>();
Expand All @@ -153,7 +174,8 @@ RELAY_REGISTER_OP("vision.roi_pool")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
.add_type_rel("ROIPool", ROIPoolRel);
.add_type_rel("ROIPool", ROIPoolRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ROIPoolInferCorrectLayout<ROIPoolAttrs>);

TVM_REGISTER_NODE_TYPE(ProposalAttrs);

Expand Down
54 changes: 54 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,59 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_roi_pool_convert_layout():
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(
x,
weight1,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
rois = relay.var("rois", shape=(32, 5))
y = relay.vision.roi_pool(
y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NCHW"
)
y = relay.Function(analysis.free_vars(y), y)
return y

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
x = relay.layout_transform(x, "NCHW", "NHWC")
weight1 = relay.layout_transform(weight1, "OIHW", "HWIO")
y = relay.nn.conv2d(
x,
weight1,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
rois = relay.var("rois", shape=(32, 5))
y = relay.vision.roi_pool(
y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NHWC"
)
ret = relay.layout_transform(y, "NHWC", "NCHW")
y = relay.Function(analysis.free_vars(ret), ret)
return y

a = before()
desired_layouts = {
"nn.conv2d": ["NHWC", "HWIO"],
"vision.roi_pool": ["NHWC", "default"],
}
a = run_opt_pass(a, transform.ConvertLayout(desired_layouts))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_default_keyword():
""" Check that the default keyword selects correct TVM default layout. """

Expand Down Expand Up @@ -1181,6 +1234,7 @@ def expected():
test_conv_convert_kernel_layout()
test_conv_transpose_convert_layout()
test_conv_roi_align_convert_layout()
test_conv_roi_pool_convert_layout()
test_conv_strided_slice_convert_layout()
test_default_keyword()
test_different_ops_convert_layout()
Expand Down

0 comments on commit eff8210

Please sign in to comment.