Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX][TOPI][RELAY] Resize refactor #7883

Merged
merged 11 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
std::string layout;
std::string method;
std::string coordinate_transformation_mode;
std::string rounding_method;
double bicubic_alpha;
int bicubic_exclude;
DataType out_dtype;

TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
Expand All @@ -61,6 +64,17 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
TVM_ATTR_FIELD(rounding_method)
.set_default("round")
.describe(
"indicates how to find the \"nearest\" pixel in nearest_neighbor method"
"Available options are round, floor, and ceil.");
TVM_ATTR_FIELD(bicubic_alpha)
.set_default(-0.5)
.describe("Spline Coefficient for Bicubic Interpolation");
TVM_ATTR_FIELD(bicubic_exclude)
.set_default(0)
.describe("Flag to exclude exterior of the image during bicubic interpolation");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
}
};
Expand Down
41 changes: 21 additions & 20 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2045,11 +2045,13 @@ class Resize(OnnxOpConverter):

@classmethod
def _impl_v10(cls, inputs, attr, params):
mode = attr.get("mode")
if mode == b"nearest":
mode = attr.get("mode").decode("ascii")
if mode == "nearest":
method = "nearest_neighbor"
elif mode == b"linear":
elif mode == "linear":
method = "bilinear"
elif mode == "cubic":
method = "bicubic"
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)
Expand All @@ -2063,16 +2065,25 @@ def _impl_v10(cls, inputs, attr, params):

@classmethod
def _impl_v11(cls, inputs, attr, params):
mode = attr.get("mode")
if mode == b"nearest":
layout = "NCHW" # ONNX assumes NCHW layout

mode = attr.get("mode").decode("ascii")
if mode == "nearest":
method = "nearest_neighbor"
elif mode == b"linear":
elif mode == "linear":
method = "bilinear"
elif mode == "cubic":
method = "bicubic"
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode)
)

coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii")
nearest_mode = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii")
alpha = attr.get("cubic_coeff_a", -0.75)
exclude = attr.get("exclude_outside", 0)

scale = inputs[2]
scale_shape = infer_shape(scale)
if len(inputs) == 4:
Expand All @@ -2083,21 +2094,11 @@ def _impl_v11(cls, inputs, attr, params):
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale

coord_trans = attr.get("coordinate_transformation_mode")
if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
coord_trans = "half_pixel"
elif coord_trans == b"align_corners":
coord_trans = "align_corners"
elif coord_trans == b"asymmetric" or method == "nearest_neighbor":
coord_trans = "asymmetric"
else:
raise tvm.error.OpAttributeInvalid(
"Unsupported coordinate_transformation_mode: {}".format(coord_trans)
)
layout = "NCHW" # ONNX assumes NCHW layout
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)

return _op.image.resize(
inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude
)


class NonZero(OnnxOpConverter):
Expand Down
14 changes: 13 additions & 1 deletion python/tvm/relay/op/dyn/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,22 @@ def compute_resize(attrs, inputs, out_type):
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
bicubic_alpha = attrs.bicubic_alpha
bicubic_exclude = attrs.bicubic_exclude
out_dtype = attrs.out_dtype
return [
tvm.topi.image.resize(
inputs[0], inputs[1], layout, method, coord_trans, out_dtype, out_type.shape
inputs[0],
inputs[1],
layout,
method,
coord_trans,
rounding_method,
bicubic_alpha,
bicubic_exclude,
out_dtype,
out_type.shape,
)
]

Expand Down
18 changes: 17 additions & 1 deletion python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,28 @@
# resize
@reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type):
""" compute definition for resize op """
size = attrs.size
layout = attrs.layout
method = attrs.method
coord_trans = attrs.coordinate_transformation_mode
rounding_method = attrs.rounding_method
bicubic_alpha = attrs.bicubic_alpha
bicubic_exclude = attrs.bicubic_exclude
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
return [
topi.image.resize(
inputs[0],
size,
layout,
method,
coord_trans,
rounding_method,
bicubic_alpha,
bicubic_exclude,
out_dtype,
)
]


reg.register_injective_schedule("image.resize")
Expand Down
37 changes: 34 additions & 3 deletions python/tvm/relay/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def resize(
layout="NCHW",
method="bilinear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
bicubic_alpha=-0.5,
bicubic_exclude=0,
out_dtype=None,
):
"""Image resize operator.
Expand Down Expand Up @@ -58,6 +61,16 @@ def resize(
Refer to the ONNX Resize operator specification for details.
[half_pixel, align_corners, asymmetric]
rounding_method: string, optional
indicates how to find the "nearest" pixel in nearest_neighbor method
[round, floor, ceil]
bicubic_alpha: float
Spline Coefficient for Bicubic Interpolation
bicubic_exclude: int
Flag to exclude exterior of the image during bicubic interpolation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a bool rather than int?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ONNX is using an Int, so I did this to be consistent, but it might be clearer to use a bool. A wash?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if onnx uses an in then im cool with it.

out_dtype : str, optional
Type to return. If left None returns the same type as input.
Expand All @@ -70,9 +83,27 @@ def resize(
size = list(size.data.asnumpy().astype("int32"))
if isinstance(size, Expr):
return _dyn_make.resize(
data, size, layout, method, coordinate_transformation_mode, out_dtype
data,
size,
layout,
method,
coordinate_transformation_mode,
rounding_method,
bicubic_alpha,
bicubic_exclude,
out_dtype,
)
return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
return _make.resize(
data,
size,
layout,
method,
coordinate_transformation_mode,
rounding_method,
bicubic_alpha,
bicubic_exclude,
out_dtype,
)


def resize3d(
Expand Down Expand Up @@ -151,7 +182,7 @@ def crop_and_resize(
A 1-D tensor of shape [num_boxes], box_ind[i] specifies the data that
the i-th box refers to.
crop_size : Tuple of Expr
crop_size : Tuple of PrimExpr
The target size to which each box will be resized.
layout : str, optional
Expand Down
Loading