Skip to content

Commit

Permalink
Add op_name in error message for Pool (apache#7243)
Browse files Browse the repository at this point in the history
* add op_name in error message for Pool

* fix tiny issue for arguments

* fix tiny issue for LpPool

Co-authored-by: luyaor <luyaor@luyaordeMacBook-Pro.local>
  • Loading branch information
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent 5976f75 commit 801f0fb
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,19 @@ def get_pad_pair(input1d, kernel1d, stride1d):
return [pad_before, pad_after]


def onnx_default_layout(dims):
def onnx_default_layout(dims, op_name):
if dims == 1:
return "NCW"
if dims == 2:
return "NCHW"
if dims == 3:
return "NCDHW"

msg = "Only 1D, 2D and 3D layouts are currently supported"
msg = "Only 1D, 2D and 3D layouts are currently supported for operator {}."
raise tvm.error.OpAttributeInvalid(msg.format(op_name))


def onnx_storage_order2layout(storage_order, dims=2):
def onnx_storage_order2layout(storage_order, dims, op_name):
"""converter of onnx storage order parameter to tvm storage order format"""
if storage_order not in (0, 1):
raise tvm.error.OpAttributeInvalid("Mode of storage_order must be either 0 or 1")
Expand All @@ -191,7 +191,7 @@ def onnx_storage_order2layout(storage_order, dims=2):
if dims == 3:
return "NCDHW" if storage_order == 0 else "NDHWC"

msg = "Only 1D, 2D and 3D layouts are currently supported"
msg = "Only 1D, 2D and 3D layouts are currently supported for operator {}."
raise tvm.error.OpAttributeInvalid(msg.format(op_name))


Expand Down Expand Up @@ -300,10 +300,10 @@ def _impl_v1(cls, inputs, attr, params):

if "storage_order" in attr:
attr["layout"] = onnx_storage_order2layout(
attr["storage_order"], dims=(len(input_shape) - 2)
attr["storage_order"], dims=(len(input_shape) - 2), op_name=cls.name
)
else:
attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2))
attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2), op_name=cls.name)

return AttrCvt(
op_name=dimension_picker(cls.name),
Expand Down Expand Up @@ -709,10 +709,10 @@ def _impl_v1(cls, inputs, attr, params):

if "storage_order" in attr:
attr["layout"] = onnx_storage_order2layout(
attr["storage_order"], dims=(len(input_shape) - 2)
attr["storage_order"], dims=(len(input_shape) - 2), op_name="LpPool"
)
else:
attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2))
attr["layout"] = onnx_default_layout(dims=(len(input_shape) - 2), op_name="LpPool")

p = _expr.const(attr["p"], dtype)
reci_p = _expr.const(1.0 / attr["p"], dtype)
Expand Down

0 comments on commit 801f0fb

Please sign in to comment.