From 1f54fd260524ee7b534c6dc239f640ac1e6931f9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 15 Apr 2021 16:54:44 -0600 Subject: [PATCH 01/11] adds rounding mode for nearest neighbor, passing onnx unit tests for nearest neighbor --- include/tvm/relay/attrs/image.h | 6 +++ python/tvm/relay/frontend/onnx.py | 32 +++++++-------- python/tvm/relay/op/dyn/image/_image.py | 10 ++++- python/tvm/relay/op/image/_image.py | 5 ++- python/tvm/relay/op/image/image.py | 11 +++++- python/tvm/topi/image/resize.py | 46 +++++++++++++++++----- src/relay/op/dyn/image/resize.cc | 3 +- src/relay/op/image/resize.cc | 3 +- src/relay/op/make_op.h | 2 +- src/relay/transforms/dynamic_to_static.cc | 3 +- tests/python/frontend/onnx/test_forward.py | 10 +++-- 11 files changed, 92 insertions(+), 39 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index cf5a6eff74bc..f4c09fe3e04d 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -38,6 +38,7 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string layout; std::string method; std::string coordinate_transformation_mode; + std::string rounding_method; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -61,6 +62,11 @@ struct ResizeAttrs : public tvm::AttrsNode { "to the coordinate in the original tensor." "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ffeb0dd73171..26ece366513b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2045,11 +2045,13 @@ class Resize(OnnxOpConverter): @classmethod def _impl_v10(cls, inputs, attr, params): - mode = attr.get("mode") - if mode == b"nearest": + mode = attr.get("mode").decode("ascii") + if mode == "nearest": method = "nearest_neighbor" - elif mode == b"linear": + elif mode == "linear": method = "bilinear" + elif mode == "cubic": + method = "bicubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2063,11 +2065,13 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - mode = attr.get("mode") - if mode == b"nearest": + mode = attr.get("mode").decode("ascii") + if mode == "nearest": method = "nearest_neighbor" - elif mode == b"linear": + elif mode == "linear": method = "bilinear" + elif mode == "cubic": + method = "bicubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2084,20 +2088,12 @@ def _impl_v11(cls, inputs, attr, params): assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - coord_trans = attr.get("coordinate_transformation_mode") - if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]: - coord_trans = "half_pixel" - elif coord_trans == b"align_corners": - coord_trans = "align_corners" - elif coord_trans == b"asymmetric" or method == "nearest_neighbor": - coord_trans = "asymmetric" - else: - raise tvm.error.OpAttributeInvalid( - "Unsupported coordinate_transformation_mode: {}".format(coord_trans) - ) + coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") + nearest_mode = attr.get("nearest_mode", "round_prefer_floor") + layout = "NCHW" # ONNX assumes NCHW layout out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, coord_trans) + return _op.image.resize(inputs[0], out_size, layout, method, coord_trans, nearest_mode) class NonZero(OnnxOpConverter): diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index e3415795712e..208d80b48b50 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -31,10 +31,18 @@ def compute_resize(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method out_dtype = attrs.out_dtype return [ tvm.topi.image.resize( - inputs[0], inputs[1], layout, method, coord_trans, out_dtype, out_type.shape + inputs[0], + inputs[1], + layout, + method, + coord_trans, + rounding_method, + out_dtype, + out_type.shape, ) ] diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index ee8a5b3883b1..7e04529bcb34 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -35,8 +35,11 @@ def compute_resize(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method out_dtype = attrs.out_dtype - return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] + return [ + topi.image.resize(inputs[0], size, layout, method, coord_trans, rounding_method, out_dtype) + ] reg.register_injective_schedule("image.resize") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 153439b1e20c..b551bdef255e 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -26,6 +26,7 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", + rounding_method="round", out_dtype=None, ): """Image resize operator. @@ -58,6 +59,10 @@ def resize( Refer to the ONNX Resize operator specification for details. [half_pixel, align_corners, asymmetric] + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -70,9 +75,11 @@ def resize( size = list(size.data.asnumpy().astype("int32")) if isinstance(size, Expr): return _dyn_make.resize( - data, size, layout, method, coordinate_transformation_mode, out_dtype + data, size, layout, method, coordinate_transformation_mode, rounding_method, out_dtype ) - return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) + return _make.resize( + data, size, layout, method, coordinate_transformation_mode, rounding_method, out_dtype + ) def resize3d( diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 433a92008b6e..e2e6400f6fbe 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -71,6 +71,7 @@ def resize_nearest_neighbor( extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", + rounding_method="round", out_dtype=None, ): @@ -120,6 +121,10 @@ def resize_nearest_neighbor( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + out_dtype: string, optional Type to return. If left None will be same as input type. @@ -150,29 +155,48 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width + scale_y = te.div(image_height.astype("float"), target_height.astype("float")) + scale_x = te.div(image_width.astype("float"), target_width.astype("float")) + if coordinate_transformation_mode == "half_pixel": + in_y = (y + 0.5) * scale_y - 0.5 + in_x = (x + 0.5) * scale_x - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_y = y * (image_height - 1).astype("float") / (target_height - 1) + in_x = x * (image_width - 1).astype("float") / (target_width - 1) + elif coordinate_transformation_mode == "asymmetric": + in_y = y * scale_y + in_x = x * scale_x + elif coordinate_transformation_mode in ["pytorch_half_pixel", "tf_half_pixel_for_nn"]: + in_y = (y + 0.5) * scale_y + in_x = (x + 0.5) * scale_x else: raise ValueError( "Unsupported coordinate_transformation_mode: {}".format( coordinate_transformation_mode ) ) - in_y = h_scale * y - in_x = w_scale * x - if coordinate_transformation_mode == "align_corners" or boxes is not None: + if rounding_method == "round" or boxes is not None: closest_x_index = te.round(in_x).astype("int32") closest_y_index = te.round(in_y).astype("int32") - else: + elif rounding_method == "round_prefer_floor": + closest_x_index = te.ceil(in_x - 0.5).astype("int32") + closest_y_index = te.ceil(in_y - 0.5).astype("int32") + elif rounding_method == "round_prefer_ceil": + closest_x_index = te.floor(in_x + 0.5).astype("int32") + closest_y_index = te.floor(in_y + 0.5).astype("int32") + elif rounding_method == "floor": # Add epsilon to floor to prevent gpu rounding errors. epsilon = 1e-5 closest_y_index = te.floor(in_y + epsilon).astype("int32") closest_x_index = te.floor(in_x + epsilon).astype("int32") + elif rounding_method == "ceil": + # Subract epsilon from ceil to prevent gpu rounding errors. + epsilon = 1e-5 + closest_y_index = te.ceil(in_y - epsilon).astype("int32") + closest_x_index = te.ceil(in_x - epsilon).astype("int32") + else: + raise ValueError("Uknown rounding method: {}".format(rounding_method)) value = get_2d_pixel( data, @@ -611,6 +635,7 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", + rounding_method="round", out_dtype=None, output_shape=None, ): @@ -683,6 +708,7 @@ def _nearest_neighbor(*indices): size[1], layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, out_dtype=out_dtype, ) diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 6581250db0cd..54e26bfe0a6c 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -67,11 +67,12 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, DataType out_dtype) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("dyn.image.resize"); return Call(op, {data, size}, Attrs(attrs), {}); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b8875e48ed0f..b73bf17d9284 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -66,12 +66,13 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return Call(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index e5a20abd7624..c81d75cc8694 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -98,7 +98,7 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype); + String coordinate_transformation_mode, String rounding_method, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 0590b41550ce..734f3cf946a6 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -118,7 +118,8 @@ class DynamicToStaticMutator : public MixedModeMutator { size_prim.push_back(size_int[i]); } return MakeResize(call_node->args[0], size_prim, param->layout, param->method, - param->coordinate_transformation_mode, param->out_dtype); + param->coordinate_transformation_mode, param->rounding_method, + param->out_dtype); } return Expr(nullptr); }}, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6d22b5afd0df..f49685afc584 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4186,11 +4186,11 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", "test_resize_downsample_scales_cubic_align_corners/", "test_resize_downsample_scales_linear/", - "test_resize_downsample_scales_nearest/", + # "test_resize_downsample_scales_nearest/", "test_resize_downsample_sizes_cubic/", "test_resize_downsample_sizes_linear_pytorch_half_pixel/", - "test_resize_downsample_sizes_nearest/", - "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", + # "test_resize_downsample_sizes_nearest/", + # "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", "test_resize_tf_crop_and_resize/", "test_resize_upsample_scales_cubic/", "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", @@ -4198,9 +4198,11 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_resize_upsample_scales_cubic_asymmetric/", "test_resize_upsample_scales_linear/", "test_resize_upsample_sizes_cubic/", + ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/", + # ---- "test_reversesequence_batch/", "test_reversesequence_time/", "test_rnn_seq_length/", @@ -4261,6 +4263,8 @@ def test_onnx_nodes(test): outputs.append(numpy_helper.to_array(new_tensor)) else: raise ImportError(str(tensor) + " not labeled as an import or an output") + ort_val = get_onnxruntime_output(onnx_model, inputs) + tvm.testing.assert_allclose(outputs[0], ort_val, rtol=1e-5, atol=1e-5) tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) if len(outputs) == 1: tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) From 38d90ff5ae2fd7fc4ee70961bc4f3ecb1807f25b Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 16 Apr 2021 10:38:02 -0600 Subject: [PATCH 02/11] passing all linear test. passing all nearest tests except crop and resize, which needs a dynamic implementation of crop and resize --- python/tvm/relay/frontend/onnx.py | 10 ++- python/tvm/relay/op/image/image.py | 2 +- python/tvm/topi/image/resize.py | 86 +++++++++++----------- tests/python/frontend/onnx/test_forward.py | 8 -- 4 files changed, 54 insertions(+), 52 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 26ece366513b..f9fdf581b82b 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2066,6 +2066,7 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): mode = attr.get("mode").decode("ascii") + layout = "NCHW" # ONNX assumes NCHW layout if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": @@ -2089,9 +2090,14 @@ def _impl_v11(cls, inputs, attr, params): size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") - nearest_mode = attr.get("nearest_mode", "round_prefer_floor") + ## TODO(mbrookhart): Need Dynamic Crop and Resize :( + # if coord_trans == "tf_crop_and_resize": + # extrapolation_value = attr.get("extrapolation_value", 0.0) + # boxes = _op.reshape(inputs[1], [-1, 4]) + # box_indices = fold_constant(_op.take(shape_of(boxes), _op.const(0, "int64"), axis=0)) + # return _op.image.crop_and_resize(inputs[1], boxes, box_indices, size, layout, method, extrapolation_value) - layout = "NCHW" # ONNX assumes NCHW layout + nearest_mode = attr.get("nearest_mode", "round_prefer_floor") out_size = fold_constant(_op.strided_slice(size, [2], [4])) return _op.image.resize(inputs[0], out_size, layout, method, coord_trans, nearest_mode) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index b551bdef255e..914a9aa24671 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -158,7 +158,7 @@ def crop_and_resize( A 1-D tensor of shape [num_boxes], box_ind[i] specifies the data that the i-th box refers to. - crop_size : Tuple of Expr + crop_size : Tuple of PrimExpr The target size to which each box will be resized. layout : str, optional diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index e2e6400f6fbe..0f72b62a9be5 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -59,6 +59,33 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") +def get_iny_inx( + y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode +): + scale_y = te.div(image_height.astype("float"), target_height.astype("float")) + scale_x = te.div(image_width.astype("float"), target_width.astype("float")) + if coordinate_transformation_mode == "half_pixel": + in_y = (y + 0.5) * scale_y - 0.5 + in_x = (x + 0.5) * scale_x - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_y = y * (image_height - 1).astype("float") / (target_height - 1) + in_x = x * (image_width - 1).astype("float") / (target_width - 1) + elif coordinate_transformation_mode == "asymmetric": + in_y = y * scale_y + in_x = x * scale_x + elif coordinate_transformation_mode == "pytorch_half_pixel": + in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) + in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) + elif coordinate_transformation_mode == "tf_half_pixel_for_nn": + in_y = (y + 0.5) * scale_y + in_x = (x + 0.5) * scale_x + else: + raise ValueError( + "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) + ) + return in_y, in_x + + def resize_nearest_neighbor( indices, data, @@ -155,26 +182,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - scale_y = te.div(image_height.astype("float"), target_height.astype("float")) - scale_x = te.div(image_width.astype("float"), target_width.astype("float")) - if coordinate_transformation_mode == "half_pixel": - in_y = (y + 0.5) * scale_y - 0.5 - in_x = (x + 0.5) * scale_x - 0.5 - elif coordinate_transformation_mode == "align_corners": - in_y = y * (image_height - 1).astype("float") / (target_height - 1) - in_x = x * (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode == "asymmetric": - in_y = y * scale_y - in_x = x * scale_x - elif coordinate_transformation_mode in ["pytorch_half_pixel", "tf_half_pixel_for_nn"]: - in_y = (y + 0.5) * scale_y - in_x = (x + 0.5) * scale_x - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) if rounding_method == "round" or boxes is not None: closest_x_index = te.round(in_x).astype("int32") @@ -323,25 +339,15 @@ def _lerp(A, B, t): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - if coordinate_transformation_mode == "half_pixel": - in_y = h_scale * (y + 0.5) - 0.5 - in_x = w_scale * (x + 0.5) - 0.5 - else: - in_y = h_scale * y - in_x = w_scale * x + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) top_y_index = te.floor(in_y).astype("int32") bottom_y_index = te.ceil(in_y).astype("int32") @@ -678,7 +684,6 @@ def resize( or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] """ method = method.lower() - if layout == "NHWC": in_n, in_h, in_w, in_c = data.shape if output_shape is None: @@ -802,7 +807,6 @@ def crop_and_resize( method = method.lower() target_h = crop_size[0] target_w = crop_size[1] - if layout == "NHWC": output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]] image_h = data.shape[1].astype("int32") diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f49685afc584..277b31eea13b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4185,18 +4185,12 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_resize_downsample_scales_cubic/", "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", "test_resize_downsample_scales_cubic_align_corners/", - "test_resize_downsample_scales_linear/", - # "test_resize_downsample_scales_nearest/", "test_resize_downsample_sizes_cubic/", - "test_resize_downsample_sizes_linear_pytorch_half_pixel/", - # "test_resize_downsample_sizes_nearest/", - # "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/", "test_resize_tf_crop_and_resize/", "test_resize_upsample_scales_cubic/", "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", "test_resize_upsample_scales_cubic_align_corners/", "test_resize_upsample_scales_cubic_asymmetric/", - "test_resize_upsample_scales_linear/", "test_resize_upsample_sizes_cubic/", ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/", @@ -4263,8 +4257,6 @@ def test_onnx_nodes(test): outputs.append(numpy_helper.to_array(new_tensor)) else: raise ImportError(str(tensor) + " not labeled as an import or an output") - ort_val = get_onnxruntime_output(onnx_model, inputs) - tvm.testing.assert_allclose(outputs[0], ort_val, rtol=1e-5, atol=1e-5) tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0)) if len(outputs) == 1: tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5) From c5b30701d5b3458785e321b42328855779195aad Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 16 Apr 2021 13:30:22 -0600 Subject: [PATCH 03/11] most of the bicubic tests are working --- include/tvm/relay/attrs/image.h | 8 ++ python/tvm/relay/frontend/onnx.py | 19 +-- python/tvm/relay/op/dyn/image/_image.py | 4 + python/tvm/relay/op/image/_image.py | 14 ++- python/tvm/relay/op/image/image.py | 28 ++++- python/tvm/topi/image/resize.py | 130 ++++++++------------- src/relay/op/dyn/image/resize.cc | 5 +- src/relay/op/image/resize.cc | 5 +- src/relay/op/make_op.h | 3 +- src/relay/transforms/dynamic_to_static.cc | 2 +- tests/python/frontend/onnx/test_forward.py | 7 -- 11 files changed, 121 insertions(+), 104 deletions(-) diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index f4c09fe3e04d..baceb04958f0 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -39,6 +39,8 @@ struct ResizeAttrs : public tvm::AttrsNode { std::string method; std::string coordinate_transformation_mode; std::string rounding_method; + double bicubic_alpha; + int bicubic_exclude; DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { @@ -67,6 +69,12 @@ struct ResizeAttrs : public tvm::AttrsNode { .describe( "indicates how to find the \"nearest\" pixel in nearest_neighbor method" "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(bicubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(bicubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f9fdf581b82b..7340d61c6fea 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2065,8 +2065,9 @@ def _impl_v10(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): - mode = attr.get("mode").decode("ascii") layout = "NCHW" # ONNX assumes NCHW layout + + mode = attr.get("mode").decode("ascii") if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": @@ -2078,6 +2079,11 @@ def _impl_v11(cls, inputs, attr, params): 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) ) + coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") + nearest_mode = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii") + alpha = attr.get("cubic_coeff_a", -0.75) + exclude = attr.get("exclude_outside", 0) + scale = inputs[2] scale_shape = infer_shape(scale) if len(inputs) == 4: @@ -2088,18 +2094,17 @@ def _impl_v11(cls, inputs, attr, params): else: assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale + out_size = fold_constant(_op.strided_slice(size, [2], [4])) - coord_trans = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") ## TODO(mbrookhart): Need Dynamic Crop and Resize :( # if coord_trans == "tf_crop_and_resize": # extrapolation_value = attr.get("extrapolation_value", 0.0) # boxes = _op.reshape(inputs[1], [-1, 4]) # box_indices = fold_constant(_op.take(shape_of(boxes), _op.const(0, "int64"), axis=0)) - # return _op.image.crop_and_resize(inputs[1], boxes, box_indices, size, layout, method, extrapolation_value) - - nearest_mode = attr.get("nearest_mode", "round_prefer_floor") - out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, coord_trans, nearest_mode) + # return _op.image.crop_and_resize(inputs[1], boxes, box_indices, out_size, layout, method, extrapolation_value) + return _op.image.resize( + inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude + ) class NonZero(OnnxOpConverter): diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 208d80b48b50..32bd88456ffc 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -32,6 +32,8 @@ def compute_resize(attrs, inputs, out_type): method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method + bicubic_alpha = attrs.bicubic_alpha + bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype return [ tvm.topi.image.resize( @@ -41,6 +43,8 @@ def compute_resize(attrs, inputs, out_type): method, coord_trans, rounding_method, + bicubic_alpha, + bicubic_exclude, out_dtype, out_type.shape, ) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 7e04529bcb34..b1b364dd8036 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -36,9 +36,21 @@ def compute_resize(attrs, inputs, out_type): method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method + bicubic_alpha = attrs.bicubic_alpha + bicubic_exclude = attrs.bicubic_exclude out_dtype = attrs.out_dtype return [ - topi.image.resize(inputs[0], size, layout, method, coord_trans, rounding_method, out_dtype) + topi.image.resize( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, + ) ] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 914a9aa24671..0276a1974dc4 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -27,6 +27,8 @@ def resize( method="bilinear", coordinate_transformation_mode="half_pixel", rounding_method="round", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, ): """Image resize operator. @@ -63,6 +65,12 @@ def resize( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] + bicubic_alpha: float + Spline Coefficient for Bicubic Interpolation + + bicubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -75,10 +83,26 @@ def resize( size = list(size.data.asnumpy().astype("int32")) if isinstance(size, Expr): return _dyn_make.resize( - data, size, layout, method, coordinate_transformation_mode, rounding_method, out_dtype + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, ) return _make.resize( - data, size, layout, method, coordinate_transformation_mode, rounding_method, out_dtype + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + bicubic_alpha, + bicubic_exclude, + out_dtype, ) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 0f72b62a9be5..2799a4ad8218 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -446,10 +446,13 @@ def resize_bicubic( layout="NCHW", coordinate_transformation_mode="align_corners", out_dtype=None, + alpha=-0.5, + exclude_outside=0, ): """Perform resize operation with bicubic method on the data. More details about Bicubic interpolation please refer to https://en.wikipedia.org/wiki/Bicubic_interpolation. + This algorithm is doing a bicubic spline interpolation Parameters ---------- @@ -459,7 +462,7 @@ def resize_bicubic( data : tvm.te.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] + or [:batch, in_height, in_width, channel] image_height : integer Input image height @@ -473,6 +476,7 @@ def resize_bicubic( target_width : integer The target resized image width + boxes : tvm.te.Tensor, optional A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies the coordinates of a box. @@ -496,6 +500,9 @@ def resize_bicubic( out_dtype: string, optional Type to return. If left None will be same as input type. + alpha: float, optional + Bicubic spline coefficient + Returns ------- output : out_dtype @@ -503,9 +510,9 @@ def resize_bicubic( """ def _cubic_kernel(A, B, C, D, t): - a = -A / 2.0 + (3.0 * B) / 2.0 - (3.0 * C) / 2.0 + D / 2.0 - b = A - (5.0 * B) / 2.0 + 2.0 * C - D / 2.0 - c = -A / 2.0 + C / 2.0 + a = -(alpha * (D - A) + (alpha + 2) * (C - B)) + b = 2 * alpha * (C - A) + 3 * (C - B) + alpha * (D - B) + c = -alpha * (C - A) d = B return a * t * t * t + b * t * t + c * t + d @@ -531,25 +538,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - if coordinate_transformation_mode == "align_corners": - h_scale = (image_height - 1).astype("float") / (target_height - 1) - w_scale = (image_width - 1).astype("float") / (target_width - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - h_scale = image_height.astype("float") / target_height - w_scale = image_width.astype("float") / target_width - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - if coordinate_transformation_mode == "half_pixel": - in_y = h_scale * (y + 0.5) - 0.5 - in_x = w_scale * (x + 0.5) - 0.5 - else: - in_y = h_scale * y - in_x = w_scale * x + in_y, in_x = get_iny_inx( + y, + x, + image_height, + image_width, + target_height, + target_width, + coordinate_transformation_mode, + ) xint = te.floor(in_x).astype("int32") xfract = in_x - te.floor(in_x) @@ -557,67 +554,30 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): yint = te.floor(in_y).astype("int32") yfract = in_y - te.floor(in_y) - # 1st row - p00 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint - 1, cc, inum, ic - ) - p10 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 0, cc, inum, ic - ) - p20 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 1, cc, inum, ic - ) - p30 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint - 1, xint + 2, cc, inum, ic - ) - - # 2nd row - p01 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint - 1, cc, inum, ic - ) - p11 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 0, cc, inum, ic - ) - p21 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 1, cc, inum, ic - ) - p31 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 0, xint + 2, cc, inum, ic - ) - - # 3rd row - p02 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint - 1, cc, inum, ic - ) - p12 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 0, cc, inum, ic - ) - p22 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 1, cc, inum, ic - ) - p32 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 1, xint + 2, cc, inum, ic - ) - - # 4th row - p03 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint - 1, cc, inum, ic - ) - p13 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 0, cc, inum, ic - ) - p23 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 1, cc, inum, ic - ) - p33 = _get_pixel( - data, layout, boxes, image_height, image_width, box_idx, c, yint + 2, xint + 2, cc, inum, ic - ) + # Get the surrounding values + p = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + yint + j - 1, + xint + i - 1, + cc, + inum, + ic, + ) # Interpolate bicubically - col0 = _cubic_kernel(p00, p10, p20, p30, xfract) - col1 = _cubic_kernel(p01, p11, p21, p31, xfract) - col2 = _cubic_kernel(p02, p12, p22, p32, xfract) - col3 = _cubic_kernel(p03, p13, p23, p33, xfract) + col0 = _cubic_kernel(*p[0], xfract) + col1 = _cubic_kernel(*p[1], xfract) + col2 = _cubic_kernel(*p[2], xfract) + col3 = _cubic_kernel(*p[3], xfract) value = _cubic_kernel(col0, col1, col2, col3, yfract) # use extrapolation_value if in_y/in_x is out of boundary @@ -642,6 +602,8 @@ def resize( method="bilinear", coordinate_transformation_mode="half_pixel", rounding_method="round", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, output_shape=None, ): @@ -738,9 +700,11 @@ def _bicubic(*indices): in_w, size[0], size[1], - layout, + layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, out_dtype=out_dtype, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, ) # Determine which interpolation method to use then run it. diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 54e26bfe0a6c..87cf89a223ec 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -67,12 +67,15 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + double bicubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; + attrs->bicubic_alpha = bicubic_alpha; + attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("dyn.image.resize"); return Call(op, {data, size}, Attrs(attrs), {}); diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b73bf17d9284..9c3d60198add 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -66,13 +66,16 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, DataType out_dtype) { + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + int bicubic_exclude, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; + attrs->bicubic_alpha = bicubic_alpha; + attrs->bicubic_exclude = bicubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize"); return Call(op, {data}, Attrs(attrs), {}); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index c81d75cc8694..2e59e418f810 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -98,7 +98,8 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, DataType out_dtype); + String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, + int bicubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 734f3cf946a6..ae3676fd58b9 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -119,7 +119,7 @@ class DynamicToStaticMutator : public MixedModeMutator { } return MakeResize(call_node->args[0], size_prim, param->layout, param->method, param->coordinate_transformation_mode, param->rounding_method, - param->out_dtype); + param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); } return Expr(nullptr); }}, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 277b31eea13b..cf069de04a65 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4182,16 +4182,9 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_qlinearmatmul_3D/", "test_range_float_type_positive_delta_expanded/", "test_range_int32_type_negative_delta_expanded/", - "test_resize_downsample_scales_cubic/", "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", - "test_resize_downsample_scales_cubic_align_corners/", - "test_resize_downsample_sizes_cubic/", "test_resize_tf_crop_and_resize/", - "test_resize_upsample_scales_cubic/", "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", - "test_resize_upsample_scales_cubic_align_corners/", - "test_resize_upsample_scales_cubic_asymmetric/", - "test_resize_upsample_sizes_cubic/", ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", From 6e8a2d863bea32d011207134be6e361bd8c71854 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Apr 2021 09:50:36 -0600 Subject: [PATCH 04/11] fix exclude outside --- python/tvm/topi/image/resize.py | 39 +++++++++++++++------- tests/python/frontend/onnx/test_forward.py | 2 -- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 2799a4ad8218..8bf5be008d74 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -509,13 +509,6 @@ def resize_bicubic( The computed result with type out_dtype """ - def _cubic_kernel(A, B, C, D, t): - a = -(alpha * (D - A) + (alpha + 2) * (C - B)) - b = 2 * alpha * (C - A) + 3 * (C - B) + alpha * (D - B) - c = -alpha * (C - A) - d = B - return a * t * t * t + b * t * t + c * t + d - def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: dtype = out_dtype @@ -574,11 +567,33 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): ) # Interpolate bicubically - col0 = _cubic_kernel(*p[0], xfract) - col1 = _cubic_kernel(*p[1], xfract) - col2 = _cubic_kernel(*p[2], xfract) - col3 = _cubic_kernel(*p[3], xfract) - value = _cubic_kernel(col0, col1, col2, col3, yfract) + def _cubic_spline_weights(t): + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + def _cubic_kernel(inputs, w): + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + wx = _cubic_spline_weights(xfract) + wy = _cubic_spline_weights(yfract) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else(te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i]) + wy[i] = te.if_then_else(te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i]) + sum_wx = sum(wx) + sum_wy = sum(wy) + wx = [w / sum_wx for w in wx] + wy = [w / sum_wy for w in wy] + col0 = _cubic_kernel(p[0], wx) + col1 = _cubic_kernel(p[1], wx) + col2 = _cubic_kernel(p[2], wx) + col3 = _cubic_kernel(p[3], wx) + value = _cubic_kernel([col0, col1, col2, col3], wy) # use extrapolation_value if in_y/in_x is out of boundary if extrapolation_value is not None: diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index cf069de04a65..7e170f4145c9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4182,9 +4182,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_qlinearmatmul_3D/", "test_range_float_type_positive_delta_expanded/", "test_range_int32_type_negative_delta_expanded/", - "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/", "test_resize_tf_crop_and_resize/", - "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/", ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/", "test_resize_upsample_sizes_nearest_floor_align_corners/", From 4b06e6e0b3bd6ad6b19b1d4d2353897dd09b4c39 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Apr 2021 09:53:12 -0600 Subject: [PATCH 05/11] remove dead code --- python/tvm/relay/frontend/onnx.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7340d61c6fea..f8ab253a847f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2096,12 +2096,6 @@ def _impl_v11(cls, inputs, attr, params): size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) - ## TODO(mbrookhart): Need Dynamic Crop and Resize :( - # if coord_trans == "tf_crop_and_resize": - # extrapolation_value = attr.get("extrapolation_value", 0.0) - # boxes = _op.reshape(inputs[1], [-1, 4]) - # box_indices = fold_constant(_op.take(shape_of(boxes), _op.const(0, "int64"), axis=0)) - # return _op.image.crop_and_resize(inputs[1], boxes, box_indices, out_size, layout, method, extrapolation_value) return _op.image.resize( inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude ) From ee227814e4e7316cc1ede093576504a7a1a877a1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Apr 2021 10:10:24 -0600 Subject: [PATCH 06/11] fix lint --- python/tvm/relay/op/image/_image.py | 1 + python/tvm/topi/image/resize.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index b1b364dd8036..e23d2026661d 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -31,6 +31,7 @@ # resize @reg.register_compute("image.resize") def compute_resize(attrs, inputs, out_type): + """ compute definition for resize op """ size = attrs.size layout = attrs.layout method = attrs.method diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 8bf5be008d74..289bf674a9ff 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -62,6 +62,7 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, def get_iny_inx( y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode ): + """ Infer input x,y from output x,y with various coordinate transformation methods """ scale_y = te.div(image_height.astype("float"), target_height.astype("float")) scale_x = te.div(image_width.astype("float"), target_width.astype("float")) if coordinate_transformation_mode == "half_pixel": @@ -476,7 +477,6 @@ def resize_bicubic( target_width : integer The target resized image width - boxes : tvm.te.Tensor, optional A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies the coordinates of a box. From 091a30ddd662a6375671df60bbcd0b28d0c3bb30 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Apr 2021 13:21:59 -0600 Subject: [PATCH 07/11] fix defaults to match old implementation --- python/tvm/relay/op/image/image.py | 2 +- python/tvm/topi/image/resize.py | 9 +++++++-- tests/python/relay/test_op_level2.py | 2 +- tests/python/relay/test_op_level5.py | 6 +----- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 0276a1974dc4..dbfe20aee72a 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -26,7 +26,7 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", - rounding_method="round", + rounding_method="", bicubic_alpha=-0.5, bicubic_exclude=0, out_dtype=None, diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 289bf674a9ff..c4dd8d14c741 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -99,7 +99,7 @@ def resize_nearest_neighbor( extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", - rounding_method="round", + rounding_method="", out_dtype=None, ): @@ -161,6 +161,11 @@ def resize_nearest_neighbor( output : out_dtype The computed result with type out_dtype """ + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: @@ -616,7 +621,7 @@ def resize( layout="NCHW", method="bilinear", coordinate_transformation_mode="half_pixel", - rounding_method="round", + rounding_method="", bicubic_alpha=-0.5, bicubic_exclude=0, out_dtype=None, diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index c5843758c3d2..5c594dc6521d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1362,7 +1362,7 @@ def get_shape(): for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) - tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=2e-5) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 466b1b19a582..e9f22ac1dae2 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -69,12 +69,8 @@ def verify_resize(dshape, scale, method, layout, coord_trans): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric", "half_pixel", "align_corners"]: + for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NHWC", "NCHW"]: - # TODO: Topi test does not have a function to produce numpy output for resize with - # nearest_neighbors and align_corners. Enable when topi test has this option - if coord_trans == "align_corners" and method == "nearest_neighbor": - continue verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) From ab7910b7491488a9f7894a2849e1357b3e03c03a Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Apr 2021 14:24:35 -0600 Subject: [PATCH 08/11] fix lint --- tests/python/relay/test_op_level5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e9f22ac1dae2..aa92e10c4d06 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -69,7 +69,7 @@ def verify_resize(dshape, scale, method, layout, coord_trans): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric + for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) From 0addc2bf800ad57c2e886f33b0acc1d352f304e8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 20 Apr 2021 13:14:33 -0600 Subject: [PATCH 09/11] fix gpu tests --- python/tvm/topi/image/resize.py | 7 +++++++ tests/python/frontend/onnx/test_forward.py | 4 +++- tests/python/topi/python/test_topi_image.py | 7 +------ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index c4dd8d14c741..45e3f1fe01de 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -685,6 +685,13 @@ def resize( else: raise ValueError("%s layout is not supported." % layout) + if isinstance(size, tuple): + size = list(size) + + for i in range(2): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) + def _nearest_neighbor(*indices): return resize_nearest_neighbor( indices, diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7e170f4145c9..0d10b364e6e3 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3386,7 +3386,9 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) + verify_with_ort( + model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True, atol=2e-5 + ) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index b766e599c679..3be87bfb31c3 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -92,12 +92,8 @@ def test_resize(): # Scale NHWC + Align Corners verify_resize(6, 32, 64, 64, 20, 20, "NHWC") for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric", "half_pixel", "align_corners"]: + for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric for layout in ["NCHW", "NHWC"]: - # TODO: When topi test has an option for align corners and nearest neighbor that - # produces correct results, re-enable it. - if coord_trans == "align_corners" and method == "nearest_neighbor": - continue verify_resize(4, 16, 32, 32, 50, 50, layout, coord_trans, method=method) @@ -167,7 +163,6 @@ def check_target(target, dev): for target, dev in tvm.testing.enabled_targets(): check_target(target, dev) - @tvm.testing.uses_gpu def test_resize3d(): # Trilinear From 2dacfd5e3d2f22c00420ca46d6395da36edc66ee Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 20 Apr 2021 14:03:53 -0600 Subject: [PATCH 10/11] fix lint again --- tests/python/topi/python/test_topi_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 3be87bfb31c3..7ca46a375906 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -163,6 +163,7 @@ def check_target(target, dev): for target, dev in tvm.testing.enabled_targets(): check_target(target, dev) + @tvm.testing.uses_gpu def test_resize3d(): # Trilinear From d917223da0fbe995de3440b7cd25ebd0707a21b1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 21 Apr 2021 09:04:57 -0600 Subject: [PATCH 11/11] change order of operations to prevent GPU rounding errors --- python/tvm/topi/image/resize.py | 8 ++++---- tests/python/frontend/onnx/test_forward.py | 4 +--- tests/python/relay/test_op_level2.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 45e3f1fe01de..f0d564581d95 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -69,11 +69,11 @@ def get_iny_inx( in_y = (y + 0.5) * scale_y - 0.5 in_x = (x + 0.5) * scale_x - 0.5 elif coordinate_transformation_mode == "align_corners": - in_y = y * (image_height - 1).astype("float") / (target_height - 1) - in_x = x * (image_width - 1).astype("float") / (target_width - 1) + in_y = (image_height - 1).astype("float") / (target_height - 1) * y + in_x = (image_width - 1).astype("float") / (target_width - 1) * x elif coordinate_transformation_mode == "asymmetric": - in_y = y * scale_y - in_x = x * scale_x + in_y = scale_y * y + in_x = scale_x * x elif coordinate_transformation_mode == "pytorch_half_pixel": in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0d10b364e6e3..7e170f4145c9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3386,9 +3386,7 @@ def verify(ishape, oshape, scales, mode, coord_trans): model = helper.make_model(graph, producer_name="resize_test") - verify_with_ort( - model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True, atol=2e-5 - ) + verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True) # upsampling verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 5c594dc6521d..c5843758c3d2 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1362,7 +1362,7 @@ def get_shape(): for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) - tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=2e-5) + tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) @tvm.testing.uses_gpu