Skip to content

Commit

Permalink
[ONNX] [Relay] Resize Opset 13 (apache#9265)
Browse files Browse the repository at this point in the history
* Fix handling of optional inputs.

* Missed one test in the ignore list.

* split 11 and 13

* removed comments, adjusted for git review

Co-authored-by: Josh Fromm <jwfromm@uw.edu>
Co-authored-by: Matthew <mbrookhart@octoml.ai>
Co-authored-by: CircleSpin <jocelyn@pop-os.localdomain>
  • Loading branch information
4 people authored and ylc committed Jan 13, 2022
1 parent c5f5dc3 commit 05c125b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 18 deletions.
44 changes: 34 additions & 10 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,40 @@ def _impl_v10(cls, inputs, attr, params):

@classmethod
def _impl_v11(cls, inputs, attr, params):
scale = inputs[2]
scale_shape = infer_shape(scale)
if len(inputs) == 4:
assert (
len(scale_shape) == 0 or scale_shape[0] == 0
), "One of scale or size should be passed, not both."
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
return cls.v11_13_common(inputs, size, attr, params)

@classmethod
def _impl_v13(cls, inputs, attr, params):
scale = inputs[2]
size = inputs[3]
if size is not None:
assert scale is None, "One of scale or size should be passed, not both."
else:
scale_type = infer_type(scale)
scale_shape = scale_type.checked_type.shape
scale_dtype = scale_type.checked_type.dtype
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), scale_dtype) * scale

return cls.v11_13_common(inputs, size, attr, params)

@classmethod
def v11_13_common(cls, inputs, size, attr, params):
"""
Resize v11 and Resize v13 are identical except in how
they handle the passing of scale and size. This utility
provides the implementation for both
"""
ndims = len(infer_shape(inputs[0]))
mode = attr.get("mode").decode("ascii")
if mode == "nearest":
Expand All @@ -2715,16 +2749,6 @@ def _impl_v11(cls, inputs, attr, params):
alpha = attr.get("cubic_coeff_a", -0.75)
exclude = attr.get("exclude_outside", 0)

scale = inputs[2]
scale_shape = infer_shape(scale)
if len(inputs) == 4:
assert (
len(scale_shape) == 0 or scale_shape[0] == 0
), "One of scale or size should be passed, not both."
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
out = None
if ndims == 3:
Expand Down
9 changes: 1 addition & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3970,6 +3970,7 @@ def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, ex
make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales),
]
input_names = ["X", "roi", "scales"]

if oshape != []:
nodes.append(
make_constant_node("sizes", onnx.TensorProto.INT64, (len(oshape),), oshape)
Expand Down Expand Up @@ -4954,15 +4955,7 @@ def verify_eyelike(indata):
"test_reduce_sum_keepdims_random",
"test_reduce_sum_negative_axes_keepdims_example",
"test_reduce_sum_negative_axes_keepdims_random",
"test_resize_downsample_sizes_cubic",
"test_resize_downsample_sizes_linear_pytorch_half_pixel",
"test_resize_downsample_sizes_nearest",
"test_resize_tf_crop_and_resize",
"test_resize_upsample_sizes_cubic",
"test_resize_upsample_sizes_nearest",
"test_resize_upsample_sizes_nearest_ceil_half_pixel",
"test_resize_upsample_sizes_nearest_floor_align_corners",
"test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
"test_rnn_seq_length",
"test_round",
"test_scan9_sum",
Expand Down

0 comments on commit 05c125b

Please sign in to comment.