From 9ed88cc7f9e74f52ab30e3100ce2483de6c3151e Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Wed, 21 Apr 2021 15:35:27 -0600 Subject: [PATCH] [ONNX][TOPI][RELAY] Resize refactor (#7883) * adds rounding mode for nearest neighbor, passing onnx unit tests for nearest neighbor * passing all linear test. passing all nearest tests except crop and resize, which needs a dynamic implementation of crop and resize * most of the bicubic tests are working * fix exclude outside * remove dead code * fix lint * fix defaults to match old implementation * fix lint * fix gpu tests * fix lint again * change order of operations to prevent GPU rounding errors --- include/tvm/relay/attrs/image.h | 14 + python/tvm/relay/frontend/onnx.py | 41 +-- python/tvm/relay/op/dyn/image/_image.py | 14 +- python/tvm/relay/op/image/_image.py | 18 +- python/tvm/relay/op/image/image.py | 37 ++- python/tvm/topi/image/resize.py | 271 +++++++++++--------- src/relay/op/dyn/image/resize.cc | 6 +- src/relay/op/image/resize.cc | 6 +- src/relay/op/make_op.h | 3 +- src/relay/transforms/dynamic_to_static.cc | 3 +- tests/python/frontend/onnx/test_forward.py | 17 +- tests/python/relay/test_op_level5.py | 6 +- tests/python/topi/python/test_topi_image.py | 6 +- 13 files changed, 263 insertions(+), 179 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index cf5a6eff74bce..baceb04958f08 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -38,6 +38,9 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string layout; std::string method; std::string coordinate_transformation_mode; + std::string rounding_method; + double bicubic_alpha; + int bicubic_exclude; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -61,6 +64,17 @@ struct ResizeAttrs : public tvm::AttrsNode { "to the coordinate in the original tensor." "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(bicubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(bicubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cc66cd3c6fe83..b8cb1f602656f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2088,11 +2088,13 @@ class Resize(OnnxOpConverter): @classmethod def _impl_v10(cls, inputs, attr, params): - mode = attr.get("mode") - if mode == b"nearest": + mode = attr.get("mode").decode("ascii") + if mode == "nearest": method = "nearest_neighbor" - elif mode == b"linear": + elif mode == "linear": method = "bilinear" + elif mode == "cubic": + method = "bicubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2106,16 +2108,25 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - mode = attr.get("mode") - if mode == b"nearest": + layout = "NCHW" # ONNX assumes NCHW layout + + mode = attr.get("mode").decode("ascii") + if mode == "nearest": method = "nearest_neighbor" - elif mode == b"linear": + elif mode == "linear": method = "bilinear" + elif mode == "cubic": + method = "bicubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) ) + coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") + nearest_mode = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii") + alpha = attr.get("cubic_coeff_a", -0.75) + exclude = attr.get("exclude_outside", 0) + scale = inputs[2] scale_shape = infer_shape(scale) if len(inputs) == 4: @@ -2126,21 +2137,11 @@ def _impl_v11(cls, inputs, attr, params): else: assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - - coord_trans = attr.get("coordinate_transformation_mode") - if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]: - coord_trans = "half_pixel" - elif coord_trans == b"align_corners": - coord_trans = "align_corners" - elif coord_trans == b"asymmetric" or method == "nearest_neighbor": - coord_trans = "asymmetric" - else: - raise tvm.error.OpAttributeInvalid( - "Unsupported coordinate_transformation_mode: {}".format(coord_trans) - ) - layout = "NCHW" # ONNX assumes NCHW layout out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) + + return _op.image.resize( + inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude + ) class NonZero(OnnxOpConverter): diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index e3415795712ec..32bd88456ffcc 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -31,10 +31,22 @@ def compute_resize(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + bicubic_alpha = attrs.bicubic_alpha + bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype return [ tvm.topi.image.resize( - inputs[0], inputs[1], layout, method, coord_trans, out_dtype, out_type.shape + inputs[0], + inputs[1], + layout, + method, + coord_trans, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + out_type.shape, ) ] diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index ee8a5b3883b15..e23d2026661d6 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -31,12 +31,28 @@ # resize @reg.register_compute("image.resize") def compute_resize(attrs, inputs, out_type): + """ compute definition for resize op """ size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + bicubic_alpha = attrs.bicubic_alpha + bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype - return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] + return [ + topi.image.resize( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + ) + ] reg.register_injective_schedule("image.resize") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 153439b1e20c3..dbfe20aee72ab 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -26,6 +26,9 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, ): """Image resize operator. @@ -58,6 +61,16 @@ def resize( Refer to the ONNX Resize operator specification for details. [half_pixel, align_corners, asymmetric] + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + bicubic_alpha: float + Spline Coefficient for Bicubic Interpolation + + bicubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -70,9 +83,27 @@ def resize( size = list(size.data.asnumpy().astype("int32")) if isinstance(size, Expr): return _dyn_make.resize( - data, size, layout, method, coordinate_transformation_mode, out_dtype + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, ) - return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) + return _make.resize( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + ) def resize3d( @@ -151,7 +182,7 @@ def crop_and_resize( A 1-D tensor of shape [num_boxes], box_ind[i] specifies the data that the i-th box refers to. - crop_size : Tuple of Expr + crop_size : Tuple of PrimExpr The target size to which each box will be resized. layout : str, optional diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 433a92008b6e1..f0d564581d957 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -59,6 +59,34 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") +def get_iny_inx( + y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode +): + """ Infer input x,y from output x,y with various coordinate transformation methods """ + scale_y = te.div(image_height.astype("float"), target_height.astype("float")) + scale_x = te.div(image_width.astype("float"), target_width.astype("float")) + if coordinate_transformation_mode == "half_pixel": + in_y = (y + 0.5) * scale_y - 0.5 + in_x = (x + 0.5) * scale_x - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_y = (image_height - 1).astype("float") / (target_height - 1) * y + in_x = (image_width - 1).astype("float") / (target_width - 1) * x + elif coordinate_transformation_mode == "asymmetric": + in_y = scale_y * y + in_x = scale_x * x + elif coordinate_transformation_mode == "pytorch_half_pixel": + in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) + in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) + elif coordinate_transformation_mode == "tf_half_pixel_for_nn": + in_y = (y + 0.5) * scale_y + in_x = (x + 0.5) * scale_x + else: + raise ValueError( + "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) + ) + return in_y, in_x + + def resize_nearest_neighbor( indices, data, @@ -71,6 +99,7 @@ def resize_nearest_neighbor( extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", + rounding_method="", out_dtype=None, ): @@ -120,6 +149,10 @@ def resize_nearest_neighbor( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + out_dtype: string, optional Type to return. If left None will be same as input type. @@ -128,6 +161,11 @@ def resize_nearest_neighbor( output : out_dtype The computed result with type out_dtype """ + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: @@ -150,29 +188,37 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - in_y = h_scale * y - in_x = w_scale * x + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) - if coordinate_transformation_mode == "align_corners" or boxes is not None: + if rounding_method == "round" or boxes is not None: closest_x_index = te.round(in_x).astype("int32") closest_y_index = te.round(in_y).astype("int32") - else: + elif rounding_method == "round_prefer_floor": + closest_x_index = te.ceil(in_x - 0.5).astype("int32") + closest_y_index = te.ceil(in_y - 0.5).astype("int32") + elif rounding_method == "round_prefer_ceil": + closest_x_index = te.floor(in_x + 0.5).astype("int32") + closest_y_index = te.floor(in_y + 0.5).astype("int32") + elif rounding_method == "floor": # Add epsilon to floor to prevent gpu rounding errors. epsilon = 1e-5 closest_y_index = te.floor(in_y + epsilon).astype("int32") closest_x_index = te.floor(in_x + epsilon).astype("int32") + elif rounding_method == "ceil": + # Subract epsilon from ceil to prevent gpu rounding errors. + epsilon = 1e-5 + closest_y_index = te.ceil(in_y - epsilon).astype("int32") + closest_x_index = te.ceil(in_x - epsilon).astype("int32") + else: + raise ValueError("Uknown rounding method: {}".format(rounding_method)) value = get_2d_pixel( data, @@ -299,25 +345,15 @@ def _lerp(A, B, t): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - if coordinate_transformation_mode == "half_pixel": - in_y = h_scale * (y + 0.5) - 0.5 - in_x = w_scale * (x + 0.5) - 0.5 - else: - in_y = h_scale * y - in_x = w_scale * x + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) top_y_index = te.floor(in_y).astype("int32") bottom_y_index = te.ceil(in_y).astype("int32") @@ -416,10 +452,13 @@ def resize_bicubic( layout="NCHW", coordinate_transformation_mode="align_corners", out_dtype=None, + alpha=-0.5, + exclude_outside=0, ): """Perform resize operation with bicubic method on the data. More details about Bicubic interpolation please refer to https://en.wikipedia.org/wiki/Bicubic_interpolation. + This algorithm is doing a bicubic spline interpolation Parameters ---------- @@ -429,7 +468,7 @@ def resize_bicubic( data : tvm.te.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] + or [:batch, in_height, in_width, channel] image_height : integer Input image height @@ -466,19 +505,15 @@ def resize_bicubic( out_dtype: string, optional Type to return. If left None will be same as input type. + alpha: float, optional + Bicubic spline coefficient + Returns ------- output : out_dtype The computed result with type out_dtype """ - def _cubic_kernel(A, B, C, D, t): - a = -A / 2.0 + (3.0 * B) / 2.0 - (3.0 * C) / 2.0 + D / 2.0 - b = A - (5.0 * B) / 2.0 + 2.0 * C - D / 2.0 - c = -A / 2.0 + C / 2.0 - d = B - return a * t * t * t + b * t * t + c * t + d - def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: dtype = out_dtype @@ -501,25 +536,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - if coordinate_transformation_mode == "half_pixel": - in_y = h_scale * (y + 0.5) - 0.5 - in_x = w_scale * (x + 0.5) - 0.5 - else: - in_y = h_scale * y - in_x = w_scale * x + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) xint = te.floor(in_x).astype("int32") xfract = in_x - te.floor(in_x) @@ -527,68 +552,53 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): yint = te.floor(in_y).astype("int32") yfract = in_y - te.floor(in_y) - # 1st row - p00 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint - 1, cc, inum, ic - ) - p10 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 0, cc, inum, ic - ) - p20 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 1, cc, inum, ic - ) - p30 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 2, cc, inum, ic - ) - - # 2nd row - p01 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint - 1, cc, inum, ic - ) - p11 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 0, cc, inum, ic - ) - p21 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 1, cc, inum, ic - ) - p31 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 2, cc, inum, ic - ) - - # 3rd row - p02 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint - 1, cc, inum, ic - ) - p12 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 0, cc, inum, ic - ) - p22 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 1, cc, inum, ic - ) - p32 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 2, cc, inum, ic - ) - - # 4th row - p03 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint - 1, cc, inum, ic - ) - p13 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 0, cc, inum, ic - ) - p23 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 1, cc, inum, ic - ) - p33 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 2, cc, inum, ic - ) + # Get the surrounding values + p = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + yint + j - 1, + xint + i - 1, + cc, + inum, + ic, + ) # Interpolate bicubically - col0 = _cubic_kernel(p00, p10, p20, p30, xfract) - col1 = _cubic_kernel(p01, p11, p21, p31, xfract) - col2 = _cubic_kernel(p02, p12, p22, p32, xfract) - col3 = _cubic_kernel(p03, p13, p23, p33, xfract) - value = _cubic_kernel(col0, col1, col2, col3, yfract) + def _cubic_spline_weights(t): + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + def _cubic_kernel(inputs, w): + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + wx = _cubic_spline_weights(xfract) + wy = _cubic_spline_weights(yfract) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else(te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i]) + wy[i] = te.if_then_else(te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i]) + sum_wx = sum(wx) + sum_wy = sum(wy) + wx = [w / sum_wx for w in wx] + wy = [w / sum_wy for w in wy] + col0 = _cubic_kernel(p[0], wx) + col1 = _cubic_kernel(p[1], wx) + col2 = _cubic_kernel(p[2], wx) + col3 = _cubic_kernel(p[3], wx) + value = _cubic_kernel([col0, col1, col2, col3], wy) # use extrapolation_value if in_y/in_x is out of boundary if extrapolation_value is not None: @@ -611,6 +621,9 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, output_shape=None, ): @@ -653,7 +666,6 @@ def resize( or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] """ method = method.lower() - if layout == "NHWC": in_n, in_h, in_w, in_c = data.shape if output_shape is None: @@ -673,6 +685,13 @@ def resize( else: raise ValueError("%s layout is not supported." % layout) + if isinstance(size, tuple): + size = list(size) + + for i in range(2): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) + def _nearest_neighbor(*indices): return resize_nearest_neighbor( indices, @@ -683,6 +702,7 @@ def _nearest_neighbor(*indices): size[1], layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, out_dtype=out_dtype, ) @@ -707,9 +727,11 @@ def _bicubic(*indices): in_w, size[0], size[1], - layout, + layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, out_dtype=out_dtype, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, ) # Determine which interpolation method to use then run it. @@ -776,7 +798,6 @@ def crop_and_resize( method = method.lower() target_h = crop_size[0] target_w = crop_size[1] - if layout == "NHWC": output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]] image_h = data.shape[1].astype("int32") diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 6581250db0cde..87cf89a223ecc 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -67,11 +67,15 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + double bicubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->bicubic_alpha = bicubic_alpha; + attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("dyn.image.resize"); return Call(op, {data, size}, Attrs(attrs), {}); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b8875e48ed0fa..9c3d60198add1 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -66,12 +66,16 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + int bicubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->bicubic_alpha = bicubic_alpha; + attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return Call(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index a3a182515b6d3..cc1ff44952efe 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -98,7 +98,8 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype); + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + int bicubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 8bfe1d83bd9ef..7c947ba109bfe 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -118,7 +118,8 @@ class DynamicToStaticMutator : public MixedModeMutator { size_prim.push_back(size_int[i]); } return MakeResize(call_node->args[0], size_prim, param->layout, param->method, - param->coordinate_transformation_mode, param->out_dtype); + param->coordinate_transformation_mode, param->rounding_method, + param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); } return Expr(nullptr); }}, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 595a3b1c89b32..783408e7c6d95 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4183,25 +4183,12 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_qlinearmatmul_3D/", "test_range_float_type_positive_delta_expanded/", "test_range_int32_type_negative_delta_expanded/", - "test_resize_downsample_scales_cubic/", - "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", - "test_resize_downsample_scales_cubic_align_corners/", - "test_resize_downsample_scales_linear/", - "test_resize_downsample_scales_nearest/", - "test_resize_downsample_sizes_cubic/", - "test_resize_downsample_sizes_linear_pytorch_half_pixel/", - "test_resize_downsample_sizes_nearest/", - "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", "test_resize_tf_crop_and_resize/", - "test_resize_upsample_scales_cubic/", - "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", - "test_resize_upsample_scales_cubic_align_corners/", - "test_resize_upsample_scales_cubic_asymmetric/", - "test_resize_upsample_scales_linear/", - "test_resize_upsample_sizes_cubic/", + ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", + # ---- "test_reversesequence_batch/", "test_reversesequence_time/", "test_rnn_seq_length/", diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 466b1b19a582d..aa92e10c4d064 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -69,12 +69,8 @@ def verify_resize(dshape, scale, method, layout, coord_trans): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric", "half_pixel", "align_corners"]: + for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NHWC", "NCHW"]: - # TODO: Topi test does not have a function to produce numpy output for resize with - # nearest_neighbors and align_corners. Enable when topi test has this option - if coord_trans == "align_corners" and method == "nearest_neighbor": - continue verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index b766e599c6794..7ca46a3759065 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -92,12 +92,8 @@ def test_resize(): # Scale NHWC + Align Corners verify_resize(6, 32, 64, 64, 20, 20, "NHWC") for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric", "half_pixel", "align_corners"]: + for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NCHW", "NHWC"]: - # TODO: When topi test has an option for align corners and nearest neighbor that - # produces correct results, re-enable it. - if coord_trans == "align_corners" and method == "nearest_neighbor": - continue verify_resize(4, 16, 32, 32, 50, 50, layout, coord_trans, method=method)