Skip to content

Commit

Permalink
fix prelu, now can use on 2d input and add one test (apache#2875)
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaolongMeng authored and MarisaKirisame committed Apr 9, 2019
1 parent 997ae1e commit 2fc7662
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 2 deletions.
1 change: 0 additions & 1 deletion topi/include/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
const int axis = 1,
std::string name = "tensor",
std::string tag = kBroadcast) {
CHECK_EQ(4, x->shape.size());
CHECK((size_t)axis < x->shape.size()) <<
"Wrong axis (" << axis << ")value. ";
CHECK(topi::detail::GetConstInt(slope->shape[0]) ==
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/nn/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def prelu(x, slope, axis=1):
[http://arxiv.org/pdf/1502.01852v1.pdf]
"""

assert len(x.shape) == 4 and len(slope.shape) == 1
assert len(slope.shape) == 1
assert axis < len(x.shape)
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])

Expand Down
1 change: 1 addition & 0 deletions topi/tests/python/test_topi_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_leaky_relu():
def test_prelu():
verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
verify_prelu((1, 3), (3,), 1, (3, ))

if __name__ == "__main__":
test_schedule_big_array()
Expand Down

0 comments on commit 2fc7662

Please sign in to comment.