Skip to content

Commit

Permalink
Set default value of p in LpPool as 2 (#8866)
Browse files Browse the repository at this point in the history
* Set default value of p in LpPool as 2

* Update test_forward.py

Fix bug in test.

* Update test_forward.py

update with correct shape.

* Update onnx.py

* Update python/tvm/relay/frontend/onnx.py

Co-authored-by: Wuwei Lin <vincentl13x@gmail.com>

Co-authored-by: luyaor <luyaor@luyaordeMacBook-Pro.local>
Co-authored-by: Wuwei Lin <vincentl13x@gmail.com>
  • Loading branch information
3 people authored Sep 2, 2021
1 parent 27be462 commit 7c9811c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,9 @@ def _impl_v1(cls, inputs, attr, params):
else:
attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2), op_name="LpPool")

p = _expr.const(attr["p"], dtype)
reci_p = _expr.const(1.0 / attr["p"], dtype)
p_value = attr.get("p", 2)
p = _expr.const(p_value, dtype)
reci_p = _expr.const(1.0 / p_value, dtype)
data = _op.power(data, p)

out = AttrCvt(
Expand Down
16 changes: 14 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3184,25 +3184,28 @@ def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_sh
@tvm.testing.parametrize_targets
def test_lppool(target, dev):
def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"):
kwargs = {}
if p is not None:
kwargs["p"] = p
if pads is None:
pool_node = helper.make_node(
"LpPool",
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
p=p,
auto_pad=auto_pad,
strides=strides,
**kwargs,
)
else:
pool_node = helper.make_node(
"LpPool",
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
p=p,
pads=pads,
strides=strides,
**kwargs,
)

graph = helper.make_graph(
Expand Down Expand Up @@ -3295,6 +3298,15 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="
out_shape=[1, 1, 16, 16, 16],
auto_pad="SAME_UPPER",
)
# Pool2D with empty p
verify_lppool(
x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
p=None,
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
)


def verify_rnn(
Expand Down

0 comments on commit 7c9811c

Please sign in to comment.