diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index 7d6b3369e88d..57c288bdefc4 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -525,7 +525,6 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
   NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
 
   // The case of parametric relu
-  CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim();
   CHECK(size_t(param.axis) < dshape.Size())
       << "Wrong axis (" << param.axis << ")value.";
 
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 2356634c4ed0..9aa4116e76cf 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -226,6 +226,23 @@ bool PReluRel(const Array<Type>& types,
   return true;
 }
 
+template<typename T>
+Array<Array<Layout> > PReluInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+
+  CHECK_EQ(old_in_layouts.size(), 2U);
+  CHECK_EQ(old_in_shapes.size(), 2U);
+  Layout data_layout = old_in_layouts[0];
+  if (new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 2U);
+  }
+  return Array<Array<Layout> >{{data_layout, Layout("C")},
+                               {data_layout}};
+}
+
 // Positional relay function to create prelu operator used by frontend FFI.
 Expr MakePRelu(Expr data,
                Expr alpha,
@@ -253,7 +270,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 .add_argument("alpha", "Tensor", "Input channelwise alpha.")
 .set_support_level(3)
 .add_type_rel("PRelu", PReluRel)
-.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<Tensor>& inputs,
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index f7a1c83ddff1..94b3fcd8e30e 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -513,6 +513,52 @@ def expected():
     assert(alpha_equal(a, b))
 
 
+def test_alter_layout_prelu():
+    """Test PRelu operator"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight")
+        alpha = relay.var("alpha", relay.IncompleteType())
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.nn.prelu(y, alpha)
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    @register_alter_op_layout("nn.conv2d", level=110)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        w = relay.var("weight")
+        alpha = relay.var("alpha", relay.IncompleteType())
+
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        y = relay.layout_transform(y, "NCHW16c", "NCHW")
+        y = relay.nn.prelu(y, alpha)
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
 if __name__ == "__main__":
     test_alter_op()
@@ -525,3 +571,4 @@ def expected():
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()
     test_alter_layout_strided_slice()
+    test_alter_layout_prelu()