Skip to content

Commit

Permalink
Fix PRelu layout in Relay (apache#3013)
Browse files Browse the repository at this point in the history
* Fix PRelu layout in Relay

* Fix cpplint

* Add PRelu test case
  • Loading branch information
FrozenGene authored and Wei Chen committed May 13, 2019
1 parent 6750d95 commit 00dd7a0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
1 change: 0 additions & 1 deletion nnvm/src/top/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);

// The case of parametric relu
CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim();
CHECK(size_t(param.axis) < dshape.Size())
<< "Wrong axis (" << param.axis << ")value.";

Expand Down
19 changes: 18 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,23 @@ bool PReluRel(const Array<Type>& types,
return true;
}

template<typename T>
Array<Array<Layout> > PReluInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {

CHECK_EQ(old_in_layouts.size(), 2U);
CHECK_EQ(old_in_shapes.size(), 2U);
Layout data_layout = old_in_layouts[0];
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 2U);
}
return Array<Array<Layout> >{{data_layout, Layout("C")},
{data_layout}};
}

// Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data,
Expr alpha,
Expand Down Expand Up @@ -265,7 +282,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,52 @@ def expected():

assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_prelu():
"""Test PRelu operator"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight")
alpha = relay.var("alpha", relay.IncompleteType())
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.prelu(y, alpha)
y = relay.Function(free_vars(y), y)
return y

@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("weight")
alpha = relay.var("alpha", relay.IncompleteType())

y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.nn.prelu(y, alpha)
y = relay.Function(free_vars(y), y)
return y

a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)

b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))


if __name__ == "__main__":
test_alter_op()
Expand All @@ -525,3 +571,4 @@ def expected():
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
test_alter_layout_prelu()

0 comments on commit 00dd7a0

Please sign in to comment.