From d3fc562a6f3b8cd4d0a5f86e1e3ebc503ebeba2b Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 5 Jul 2021 04:09:26 -0600 Subject: [PATCH] [Relay][TOPI] Resize 1D (#8346) * rename resize to resize2d * refactor resize_2d * Add resize1d op, normalize attribute names across ops * normalize resize3d to match the API of 1D and 2D * fix lint * fix relay tests from API change * refactor topi tests, docs * fix method naming in framework frontends fix more frontend issues * refactor resize tests to reuse components, add more coordinate tranform modes to tests * add cubic resize reference kernel and tests, add relay tests for resize1d * fix pylint * fix test typo --- docs/langref/relay_op.rst | 4 +- include/tvm/relay/attrs/image.h | 107 +- python/tvm/relay/frontend/keras.py | 1 + python/tvm/relay/frontend/mxnet.py | 2 +- python/tvm/relay/frontend/onnx.py | 59 +- python/tvm/relay/frontend/pytorch.py | 16 +- python/tvm/relay/frontend/tensorflow_ops.py | 6 +- python/tvm/relay/frontend/tflite.py | 6 +- python/tvm/relay/op/dyn/image/_image.py | 26 +- python/tvm/relay/op/image/_image.py | 149 ++- python/tvm/relay/op/image/image.py | 148 +- python/tvm/relay/op/op_attrs.py | 21 +- python/tvm/topi/image/resize.py | 1191 ++++++++++------- python/tvm/topi/nn/upsampling.py | 6 +- python/tvm/topi/testing/__init__.py | 4 +- .../topi/testing/bilinear_resize_python.py | 105 -- python/tvm/topi/testing/resize_python.py | 294 ++++ .../topi/testing/trilinear_resize3d_python.py | 111 -- python/tvm/topi/testing/upsampling_python.py | 136 -- python/tvm/topi/utils.py | 10 + src/relay/op/dyn/image/resize.cc | 30 +- src/relay/op/image/resize.cc | 129 +- src/relay/op/make_op.h | 6 +- src/relay/transforms/dynamic_to_static.cc | 10 +- tests/python/frontend/coreml/test_forward.py | 11 +- tests/python/frontend/onnx/test_forward.py | 97 +- .../relay/dyn/test_dynamic_op_level2.py | 33 +- .../relay/dyn/test_dynamic_op_level5.py | 25 +- tests/python/relay/test_any.py | 8 +- 
tests/python/relay/test_op_level2.py | 34 +- tests/python/relay/test_op_level5.py | 89 +- .../relay/test_pass_convert_op_layout.py | 20 +- .../relay/test_pass_dynamic_to_static.py | 29 +- tests/python/topi/python/test_topi_image.py | 74 +- .../topi/python/test_topi_upsampling.py | 32 +- 35 files changed, 1813 insertions(+), 1216 deletions(-) delete mode 100644 python/tvm/topi/testing/bilinear_resize_python.py create mode 100644 python/tvm/topi/testing/resize_python.py delete mode 100644 python/tvm/topi/testing/trilinear_resize3d_python.py delete mode 100644 python/tvm/topi/testing/upsampling_python.py diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index febe542b83b1..3e797fc93b31 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -181,7 +181,9 @@ This level enables additional math and transform operators. .. autosummary:: :nosignatures: - tvm.relay.image.resize + tvm.relay.image.resize1d + tvm.relay.image.resize2d + tvm.relay.image.resize3d tvm.relay.image.crop_and_resize tvm.relay.image.dilation2d tvm.relay.vision.multibox_prior diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index baceb04958f0..b851add61e4a 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -32,31 +32,74 @@ namespace tvm { namespace relay { -/*! \brief Attributes used in image resize operator */ -struct ResizeAttrs : public tvm::AttrsNode { +/*! 
\brief Attributes used in image resize1d operator */ +struct Resize1DAttrs : public tvm::AttrsNode { Array size; std::string layout; std::string method; std::string coordinate_transformation_mode; std::string rounding_method; - double bicubic_alpha; - int bicubic_exclude; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; - TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { + TVM_DECLARE_ATTRS(Resize1DAttrs, "relay.attrs.Resize1DAttrs") { + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel and width" + "dimensions respectively. Resize is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Linear Interpolation" + "cubic - Cubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for cubic interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during cubic interpolation"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + } +}; + +/*! 
\brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array size; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relay.attrs.Resize2DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCHW").describe( "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Resize is applied on the 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("bilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( @@ -69,10 +112,10 @@ struct ResizeAttrs : public tvm::AttrsNode { .describe( "indicates how to find the \"nearest\" pixel in nearest_neighbor method" "Available options are round, floor, and ceil."); - TVM_ATTR_FIELD(bicubic_alpha) + TVM_ATTR_FIELD(cubic_alpha) .set_default(-0.5) .describe("Spline Coefficient for Bicubic Interpolation"); - TVM_ATTR_FIELD(bicubic_exclude) + TVM_ATTR_FIELD(cubic_exclude) .set_default(0) .describe("Flag to exclude exterior of the image during bicubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); @@ -80,32 +123,46 @@ struct ResizeAttrs : public tvm::AttrsNode { }; /*! 
\brief Attributes used in image resize3d operator */ -struct Resize3dAttrs : public tvm::AttrsNode { +struct Resize3DAttrs : public tvm::AttrsNode { Array size; - String layout; - String method; - String coordinate_transformation_mode; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + std::string rounding_method; + double cubic_alpha; + int cubic_exclude; DataType out_dtype; - TVM_DECLARE_ATTRS(Resize3dAttrs, "relay.attrs.Resize3dAttrs") { + TVM_DECLARE_ATTRS(Resize3DAttrs, "relay.attrs.Resize3DAttrs") { TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Resize3d is applied on the 'D', 'H' and" "'W' dimensions."); - TVM_ATTR_FIELD(method) - .set_default("trilinear") - .describe( - "Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Trilinear Interpolation" + "cubic - Tricubic Interpolation"); TVM_ATTR_FIELD(coordinate_transformation_mode) .set_default("half_pixel") .describe( "Describes how to transform the coordinate in the resized tensor" "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Tricubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during tricubic interpolation"); TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 63521a67b065..aa185923d02e 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -725,6 +725,7 @@ def _convert_upsample3d(inexpr, keras_layer, etab): params["scale_h"] = h params["scale_w"] = w params["layout"] = etab.data_layout + params["coordinate_transformation_mode"] = "asymmetric" out = _op.nn.upsampling3d(inexpr, **params) return out diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3b940bd15f5b..59b4e99de999 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -963,7 +963,7 @@ def _mx_resize(inputs, attrs): if scale_width is not None: width = (scale_width * shape[3]).astype("int32") size = (height, width) - return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") + return _op.image.resize2d(inputs[0], size, coordinate_transformation_mode="align_corners") def _mx_amp_multicast(inputs, attrs): diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7135fccdf43b..c3108ff890b1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -457,6 +457,7 @@ def _impl_v1(cls, inputs, attr, params): kernel_type = 
infer_type(inputs[1]) kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)] + if "kernel_shape" not in attr: attr["kernel_shape"] = kernel_shapes[0][2:] @@ -1199,7 +1200,13 @@ def _impl_v9(cls, inputs, attr, params): layout = "NCDHW" out = _op.nn.upsampling3d( - inputs[0], scale_d, scale_h, scale_w, layout=layout, method=method + inputs[0], + scale_d, + scale_h, + scale_w, + layout=layout, + method=method, + coordinate_transformation_mode="asymmetric", ) # in 2d case, use dynamic op else: @@ -2388,9 +2395,9 @@ def _impl_v10(cls, inputs, attr, params): if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2398,21 +2405,31 @@ def _impl_v10(cls, inputs, attr, params): scale = inputs[1] size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale - layout = "NCHW" # ONNX assumes NCHW layout - out_size = fold_constant(_op.strided_slice(size, [2], [4])) - return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric") + ndims = len(infer_shape(inputs[0])) + out = None + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + out = _op.image.resize1d(inputs[0], out_size, "NCW", method, "asymmetric") + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = _op.image.resize2d(inputs[0], out_size, "NCHW", method, "asymmetric") + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + out = _op.image.resize3d(inputs[0], out_size, "NCDHW", method, "asymmetric") + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") + return out @classmethod def _impl_v11(cls, inputs, attr, params): - layout = "NCHW" # ONNX assumes NCHW layout - + ndims = len(infer_shape(inputs[0])) mode = 
attr.get("mode").decode("ascii") if mode == "nearest": method = "nearest_neighbor" elif mode == "linear": - method = "bilinear" + method = "linear" elif mode == "cubic": - method = "bicubic" + method = "cubic" else: raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "mode" of operator Resize is not valid.'.format(mode) @@ -2434,10 +2451,26 @@ def _impl_v11(cls, inputs, attr, params): assert len(scale_shape) != 0, "One of scale or size should be passed." size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = None + if ndims == 3: + out_size = fold_constant(_op.strided_slice(size, [2], [3])) + out = _op.image.resize1d( + inputs[0], out_size, "NCW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 4: + out_size = fold_constant(_op.strided_slice(size, [2], [4])) + out = _op.image.resize2d( + inputs[0], out_size, "NCHW", method, coord_trans, nearest_mode, alpha, exclude + ) + elif ndims == 5: + out_size = fold_constant(_op.strided_slice(size, [2], [5])) + out = _op.image.resize3d( + inputs[0], out_size, "NCDHW", method, coord_trans, nearest_mode, alpha, exclude + ) + else: + raise NotImplementedError("Resize only supports 3, 4, or 5 dims") - return _op.image.resize( - inputs[0], out_size, layout, method, coord_trans, nearest_mode, alpha, exclude - ) + return out class NonZero(OnnxOpConverter): diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 909b8049d46a..6bcedbfebab3 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1798,7 +1798,7 @@ def get_upsample_out_size(self, inputs, method): else: out_size.append(size) else: - scale_index = 3 if method in ["bilinear", "trilinear"] else 2 + scale_index = 3 if method == "linear" else 2 scales = inputs[scale_index] assert scales is not None, "neither out size nor scale provided" assert isinstance(scales, list) @@ 
-1813,7 +1813,7 @@ def upsample(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "bilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -1826,7 +1826,7 @@ def upsample(inputs, input_types): coord_trans = "half_pixel" def func(x): - return _op.image.resize(x, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(x, out_size, "NCHW", method, coord_trans) if self.is_quantized_tensor(data): # input qparams are manually appended by us @@ -1845,7 +1845,7 @@ def upsample3d(inputs, input_types): data = inputs[0] out_size = self.get_upsample_out_size(inputs, method) - if len(inputs) > 2 and method == "trilinear": + if len(inputs) > 2 and method == "linear": align_corners = inputs[2] else: align_corners = False @@ -2192,6 +2192,8 @@ def interpolate(self, inputs, input_types): method = inputs[3] if method.startswith("nearest"): method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] if method == "nearest_neighbor": coord_trans = "asymmetric" @@ -2200,7 +2202,7 @@ def interpolate(self, inputs, input_types): else: coord_trans = "half_pixel" - return _op.image.resize(data, out_size, "NCHW", method, coord_trans) + return _op.image.resize2d(data, out_size, "NCHW", method, coord_trans) def numel(self, inputs, input_types): return _op.ndarray_size(inputs[0]) @@ -2475,9 +2477,9 @@ def create_convert_map(self): "aten::clamp": self.clamp, "aten::clamp_": self.clamp, "aten::detach": self.identity, - "aten::upsample_bilinear2d": self.make_upsample("bilinear"), + "aten::upsample_bilinear2d": self.make_upsample("linear"), "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"), - "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"), + "aten::upsample_trilinear3d": self.make_upsample3d("linear"), "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"), "aten::expand_as": self.expand_as, 
"aten::lt": self.make_elemwise("less"), diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 004174f076fd..797ff51ace7a 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1075,7 +1075,7 @@ def _impl(inputs, attr, params, mod): # Ignore the new attributes from TF2.0, for now. return AttrCvt( - op_name="resize", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} + op_name="resize2d", ignores=["Tdim", "half_pixel_centers"], extras={"method": method} )(inputs, attr) return _impl @@ -2943,8 +2943,8 @@ def _impl(inputs, attr, params, mod): "Relu": AttrCvt("relu"), "Relu6": _relu6(), "Reshape": _reshape(), - "ResizeBicubic": _resize("bilinear"), - "ResizeBilinear": _resize("bilinear"), + "ResizeBicubic": _resize("cubic"), + "ResizeBilinear": _resize("linear"), "ResizeNearestNeighbor": _resize("nearest_neighbor"), "ReverseV2": _reverse_v2(), "RightShift": AttrCvt("right_shift"), diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a47fdf0141b5..42096ad9af2f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -630,7 +630,7 @@ def _convert_resize(self, method, op): # Options - align_corners (bool) resize_options = None align_corners = False - bilinear_method = method == "bilinear" + bilinear_method = method == "linear" if bilinear_method: assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions resize_options = ResizeBilinearOptions() @@ -647,7 +647,7 @@ def _convert_resize(self, method, op): coord_trans = "align_corners" if align_corners else "asymmetric" if bilinear_method and input_tensor.qnn_params: in_expr = self.dequantize(in_expr, input_tensor) - out = _op.image.resize( + out = _op.image.resize2d( in_expr, target_size, "NHWC", method, coordinate_transformation_mode=coord_trans ) if bilinear_method and output_tensor.qnn_params: @@ -656,7 +656,7 
@@ def _convert_resize(self, method, op): def convert_resize_bilinear(self, op): """Convert TFLite RESIZE_BILINEAR""" - return self._convert_resize("bilinear", op) + return self._convert_resize("linear", op) def convert_resize_nearest_neighbor(self, op): """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index 32bd88456ffc..5e97d2461100 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -26,36 +26,36 @@ # resize -@reg.register_compute("dyn.image.resize") -def compute_resize(attrs, inputs, out_type): +@reg.register_compute("dyn.image.resize2d") +def compute_resize2d(attrs, inputs, out_type): layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ - tvm.topi.image.resize( + tvm.topi.image.resize2d( inputs[0], inputs[1], layout, method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, out_type.shape, ) ] -reg.register_injective_schedule("dyn.image.resize") +reg.register_injective_schedule("dyn.image.resize2d") @script -def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): +def _resize2d_shape_func(dshape, size, ndim, height_axis, width_axis): out = output_tensor((ndim,), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) @@ -64,15 +64,15 @@ def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): return out -@reg.register_shape_func("dyn.image.resize", True) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("dyn.image.resize2d", True) +def resize2d_shape_func(attrs, inputs, _): """ Shape function for dyn.image.resize op. 
""" layout = attrs.layout if nchw_pack_layout(layout) or nchw_xc_layout(layout): out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), convert(2), convert(3) ) ] @@ -84,7 +84,7 @@ def resize_shape_func(attrs, inputs, _): if letter == "W": width_axis = i out = [ - _resize_shape_func( + _resize2d_shape_func( inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 2071a43f828b..ec24ff76b90e 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,42 +26,42 @@ from .. import op as reg from .. import strategy from ..op import OpPattern -from .image import resize +from .image import resize2d # resize -@reg.register_compute("image.resize") -def compute_resize(attrs, inputs, out_type): - """compute definition for resize op""" +@reg.register_compute("image.resize1d") +def compute_resize1d(attrs, inputs, out_type): + """compute definition for resize1d op""" size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode rounding_method = attrs.rounding_method - bicubic_alpha = attrs.bicubic_alpha - bicubic_exclude = attrs.bicubic_exclude + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype return [ - topi.image.resize( + topi.image.resize1d( inputs[0], size, layout, method, coord_trans, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) ] -reg.register_injective_schedule("image.resize") +reg.register_injective_schedule("image.resize1d") -@reg.register_convert_op_layout("image.resize") -def convert_image_resize(attrs, inputs, tinfos, desired_layouts): - """Convert Layout pass registration for image resize op. 
+@reg.register_convert_op_layout("image.resize1d") +def convert_image_resize1d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize1d op. Parameters ---------- @@ -86,11 +86,104 @@ def convert_image_resize(attrs, inputs, tinfos, desired_layouts): desired_layout = str(desired_layouts[0]) assert desired_layout != "default", "Layout cannot be default" new_attrs["layout"] = desired_layout - return resize(*inputs, **new_attrs) + return resize1d(*inputs, **new_attrs) @script -def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): +def _resize1d_shape_func(image_shape, size, batch_axis, width_axis, channel_axis): + out = output_tensor((3,), "int64") + out[batch_axis] = int64(image_shape[0]) + out[width_axis] = int64(size[1]) + out[channel_axis] = image_shape[channel_axis] + return out + + +@reg.register_shape_func("image.resize1d", False) +def resize1d_shape_func(attrs, inputs, _): + """ + Shape function for resize2d op. 
+ """ + layout = attrs.layout + width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "N": + batch_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + size = get_const_tuple(attrs.size) + return [ + _resize1d_shape_func( + inputs[0], + convert(size), + convert(batch_axis), + convert(width_axis), + convert(channel_axis), + ) + ] + + +@reg.register_compute("image.resize2d") +def compute_resize2d(attrs, inputs, out_type): + """compute definition for resize2d op""" + size = attrs.size + layout = attrs.layout + method = attrs.method + coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude + out_dtype = attrs.out_dtype + return [ + topi.image.resize2d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + ] + + +reg.register_injective_schedule("image.resize2d") + + +@reg.register_convert_op_layout("image.resize2d") +def convert_image_resize2d(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize2d op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. 
+ Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize2d(*inputs, **new_attrs) + + +@script +def _resize2d_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") out[batch_axis] = int64(image_shape[0]) out[height_axis] = int64(size[0]) @@ -99,10 +192,10 @@ def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, c return out -@reg.register_shape_func("image.resize", False) -def resize_shape_func(attrs, inputs, _): +@reg.register_shape_func("image.resize2d", False) +def resize2d_shape_func(attrs, inputs, _): """ - Shape function for resize op. + Shape function for resize2d op. """ layout = attrs.layout height_axis = width_axis = channel_axis = 1 @@ -117,7 +210,7 @@ def resize_shape_func(attrs, inputs, _): channel_axis = i size = get_const_tuple(attrs.size) return [ - _resize_shape_func( + _resize2d_shape_func( inputs[0], convert(size), convert(batch_axis), @@ -130,12 +223,28 @@ def resize_shape_func(attrs, inputs, _): @reg.register_compute("image.resize3d") def compute_resize3d(attrs, inputs, out_type): + """compute definition for resize3d op""" size = attrs.size layout = attrs.layout method = attrs.method coord_trans = attrs.coordinate_transformation_mode + rounding_method = attrs.rounding_method + cubic_alpha = attrs.cubic_alpha + cubic_exclude = attrs.cubic_exclude out_dtype = attrs.out_dtype - return [topi.image.resize3d(inputs[0], size, layout, method, coord_trans, out_dtype)] + return [ + topi.image.resize3d( + inputs[0], + size, + layout, + method, + coord_trans, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + ] reg.register_injective_schedule("image.resize3d") diff --git 
a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 6d7d79264844..7f5bd80159f9 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -20,18 +20,94 @@ from ...expr import Expr, Constant -def resize( +def resize1d( + data, + size, + layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.5, + cubic_exclude=0, + out_dtype=None, +): + """Image resize1d operator. + + This operator takes data as input and does 1D scaling to the given scale factor. + In the default case, where the data_layout is `NCW` + with data of shape (n, c, w) + out will have a shape (n, c, size[0]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Int or Expr + The out size to which the image will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + [half_pixel, align_corners, asymmetric] + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for cubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The resized result. 
+ """ + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize1d is not yet implemented, got size", size) + return _make.resize1d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) + + +def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, - bicubic_exclude=0, + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): - """Image resize operator. + """Image resize2d operator. This operator takes data as input and does 2D scaling to the given scale factor. In the default case, where the data_layout is `NCHW` @@ -39,7 +115,7 @@ def resize( out will have a shape (n, c, size[0], size[1]) method indicates the algorithm to be used while calculating the out value - and method can be one of ("bilinear", "nearest_neighbor", "bicubic") + and method can be one of ("linear", "nearest_neighbor", "cubic") Parameters ---------- @@ -53,7 +129,7 @@ def resize( Layout of the input. method : str, optional - Scale method to used [nearest_neighbor, bilinear, bicubic]. + Scale method to used [nearest_neighbor, linear, cubic]. 
coordinate_transformation_mode : string, optional Describes how to transform the coordinate in the resized tensor @@ -65,10 +141,10 @@ def resize( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] - bicubic_alpha: float - Spline Coefficient for Bicubic Interpolation + cubic_alpha: float + Spline Coefficient for bicubic interpolation - bicubic_exclude: int + cubic_exclude: int Flag to exclude exterior of the image during bicubic interpolation out_dtype : str, optional @@ -82,26 +158,26 @@ def resize( if isinstance(size, Constant): size = list(size.data.numpy().astype("int32")) if isinstance(size, Expr): - return _dyn_make.resize( + return _dyn_make.resize2d( data, size, layout, method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) - return _make.resize( + return _make.resize2d( data, size, layout, method, coordinate_transformation_mode, rounding_method, - bicubic_alpha, - bicubic_exclude, + cubic_alpha, + cubic_exclude, out_dtype, ) @@ -110,11 +186,14 @@ def resize3d( data, size, layout="NCDHW", - method="trilinear", + method="linear", coordinate_transformation_mode="half_pixel", + rounding_method="", + cubic_alpha=-0.5, + cubic_exclude=0, out_dtype=None, ): - """Image resize 3D operator. + """Image resize3d operator. This operator takes data as input and does 3D scaling to the given scale factor. In the default case, where the data_layout is `NCDHW` @@ -122,27 +201,38 @@ def resize3d( out will have a shape `(n, c, size[0], size[1], size[2])` method indicates the algorithm to be used while calculating the out value - and method can be one of ("trilinear", "nearest_neighbor") + and method can be one of ("linear", "nearest_neighbor", "cubic") Parameters ---------- data : relay.Expr The input data to the operator. - size: Tuple of Expr + size: Tuple of Int or Expr The out size to which the image will be resized. 
layout : str, optional Layout of the input. method : str, optional - Scale method to used [nearest_neighbor, trilinear]. + Scale method to used [nearest_neighbor, linear, cubic]. coordinate_transformation_mode : string, optional Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. [half_pixel, align_corners, asymmetric] + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for cubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during cubic interpolation + out_dtype : str, optional Type to return. If left None returns the same type as input. @@ -151,7 +241,21 @@ def resize3d( result: relay.Expr The resized result. """ - return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype) + if isinstance(size, Constant): + size = list(size.data.numpy().astype("int32")) + if isinstance(size, Expr): + raise NotImplementedError("dyn.resize3d is not yet implemented, got size", size) + return _make.resize3d( + data, + size, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + out_dtype, + ) def crop_and_resize( diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 780badc89fc4..2d185bcee798 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -139,9 +139,19 @@ class DeformableConv2DAttrs(Attrs): """Attributes for nn.deformable_conv2d""" -@tvm._ffi.register_object("relay.attrs.ResizeAttrs") -class ResizeAttrs(Attrs): - """Attributes for image.resize""" +@tvm._ffi.register_object("relay.attrs.Resize1DAttrs") +class Resize1DAttrs(Attrs): + """Attributes for image.resize1d""" + + +@tvm._ffi.register_object("relay.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + 
"""Attributes for image.resize2d""" + + +@tvm._ffi.register_object("relay.attrs.Resize3DAttrs") +class Resize3DAttrs(Attrs): + """Attributes used in resize3d operators""" @tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs") @@ -499,11 +509,6 @@ class RequantizeAttrs(Attrs): """Attributes used in requantize operators""" -@tvm._ffi.register_object("relay.attrs.Resize3dAttrs") -class Resize3dAttrs(Attrs): - """Attributes used in resize3d operators""" - - @tvm._ffi.register_object("relay.attrs.ScatterAttrs") class ScatterAttrs(Attrs): """Attributes used in scatter operators""" diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 42d0455665a1..5d9d96036282 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -23,6 +23,25 @@ from .. import tag +def get_1d_indices(indices, layout="NCW"): + """Get 1d indices""" + (cc, inum, ic) = (0, 0, 0) + if layout == "NWC": + n, x, c = indices + cc = None + elif layout == "NCW": + n, c, x = indices + cc = None + elif ncw_pack_layout(layout): + n, c, x, inum, ic = indices + else: + # else must be NCHWxc + assert ncw_xc_layout(layout) + n, c, x, cc = indices + + return n, c, x, cc, inum, ic + + def get_2d_indices(indices, layout="NCHW"): """Get 2d indices""" (cc, inum, ic) = (0, 0, 0) @@ -42,6 +61,36 @@ def get_2d_indices(indices, layout="NCHW"): return n, c, y, x, cc, inum, ic +def get_3d_indices(indices, layout="NCDHW"): + """Get 3d indices""" + if layout == "NDHWC": + n, z, y, x, c = indices + cc = None + elif layout == "NCDHW": + n, c, z, y, x = indices + cc = None + else: + n, c, z, y, x, cc = indices + + return n, c, z, y, x, cc + + +def get_1d_pixel(data, layout, boxes, image_width, n, c, x, cc, ib, ic): + """Get 1d pixel""" + if boxes is None: + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NWC": + return data(n, x, c).astype("float") + if layout == "NCW": + return data(n, c, x).astype("float") + if ncw_pack_layout(layout): + return 
data(n, c, x, ib, ic).astype("float") + + # else must be NCHWxc + assert ncw_xc_layout(layout) + return data(n, c, x, cc).astype("float") + + def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic): """Get 2d pixel""" if boxes is None: @@ -59,53 +108,99 @@ def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, return data(n, c, y, x, cc).astype("float") -def get_iny_inx( - y, x, image_height, image_width, target_height, target_width, coordinate_transformation_mode -): - """Infer input x,y from output x,y with various coordinate transformation methods""" - scale_y = te.div(image_height.astype("float"), target_height.astype("float")) +def get_3d_pixel(data, layout, image_depth, image_height, image_width, n, c, z, y, x, cc): + """Get 3d pixel""" + z = tvm.te.max(tvm.te.min(z, image_depth - 1), 0) + y = tvm.te.max(tvm.te.min(y, image_height - 1), 0) + x = tvm.te.max(tvm.te.min(x, image_width - 1), 0) + if layout == "NDHWC": + return data(n, z, y, x, c).astype("float") + if layout == "NCDHW": + return data(n, c, z, y, x).astype("float") + # else must be NCDHWxc + return data(n, c, z, y, x, cc).astype("float") + + +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" scale_x = te.div(image_width.astype("float"), target_width.astype("float")) if coordinate_transformation_mode == "half_pixel": - in_y = (y + 0.5) * scale_y - 0.5 in_x = (x + 0.5) * scale_x - 0.5 elif coordinate_transformation_mode == "align_corners": - in_y = (image_height - 1).astype("float") / (target_height - 1) * y in_x = (image_width - 1).astype("float") / (target_width - 1) * x elif coordinate_transformation_mode == "asymmetric": - in_y = scale_y * y in_x = scale_x * x elif coordinate_transformation_mode == "pytorch_half_pixel": - in_y = te.if_then_else(target_height > 1, (y + 0.5) * scale_y - 0.5, 0.0) in_x = te.if_then_else(target_width > 
1, (x + 0.5) * scale_x - 0.5, 0.0) elif coordinate_transformation_mode == "tf_half_pixel_for_nn": - in_y = (y + 0.5) * scale_y in_x = (x + 0.5) * scale_x else: raise ValueError( "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) ) - return in_y, in_x + return in_x -def resize_nearest_neighbor( +def get_closest_index(in_x, rounding_method, boxes): + """get the closest index to a value based on a certain rounding method""" + if rounding_method == "round" or boxes is not None: + closest_x_index = te.round(in_x).astype("int32") + elif rounding_method == "round_prefer_floor": + closest_x_index = te.ceil(in_x - 0.5).astype("int32") + elif rounding_method == "round_prefer_ceil": + closest_x_index = te.floor(in_x + 0.5).astype("int32") + elif rounding_method == "floor": + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + closest_x_index = te.floor(in_x + epsilon).astype("int32") + elif rounding_method == "ceil": + # Subract epsilon from ceil to prevent gpu rounding errors. 
+ epsilon = 1e-5 + closest_x_index = te.ceil(in_x - epsilon).astype("int32") + else: + raise ValueError("Uknown rounding method: {}".format(rounding_method)) + return closest_x_index + + +def _lerp(A, B, t): + """Perform Linear interpolation in 1D""" + return A * (1.0 - t) + B * t + + +def _cubic_spline_weights(t, alpha): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + +def _cubic_kernel(inputs, w): + """perform cubic interpolation in 1D""" + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + +def _resize_1d( indices, data, - image_height, image_width, - target_height, target_width, boxes=None, box_indices=None, + method=None, extrapolation_value=None, - layout="NCHW", + layout="NCW", coordinate_transformation_mode="align_corners", rounding_method="", + alpha=-0.5, + exclude_outside=0, out_dtype=None, ): - """Perform resize operation with nearest neighbor method on the data. - For details about Nearest-neighbor interpolation please refer to - https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. + """Perform resize operation on the data with selected method and options. Parameters ---------- @@ -113,19 +208,13 @@ def resize_nearest_neighbor( The indices of input data data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] - - image_height : integer - Input image height + inputs is a 3-D tensor with shape + [batch, channel, in_width] + or [batch, in_width, channel] image_width : integer Input image width - target_height : integer - The target resized image height - target_width : integer The target resized image width @@ -141,7 +230,7 @@ def resize_nearest_neighbor( Value used for extrapolation, when applicable. 
layout: string, optional - "NCHW", "NHWC", or "NCHWc". + "NCW", "NWC", or "NCWc". coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor @@ -153,6 +242,12 @@ def resize_nearest_neighbor( indicates how to find the "nearest" pixel in nearest_neighbor method [round, floor, ceil] + alpha: float, optional + Bicubic spline coefficient + + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + out_dtype: string, optional Type to return. If left None will be same as input type. @@ -161,11 +256,6 @@ def resize_nearest_neighbor( output : out_dtype The computed result with type out_dtype """ - if rounding_method == "": - if coordinate_transformation_mode == "align_corners": - rounding_method = "round" - else: - rounding_method = "floor" def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: @@ -174,136 +264,130 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): dtype = data_dtype return value.astype(dtype) - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) + n, c, x, cc, inum, ic = get_1d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize1d with image boxes not yet implemented") + in_x = get_inx( + x, + image_width, + target_width, + coordinate_transformation_mode, + ) - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x - 
else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_1d_pixel( + data, + layout, + boxes, image_width, - target_height, - target_width, - coordinate_transformation_mode, + box_idx, + c, + closest_x_index, + cc, + inum, + ic, ) + elif method == "linear": + x_int = te.floor(in_x).astype("int32") - if rounding_method == "round" or boxes is not None: - closest_x_index = te.round(in_x).astype("int32") - closest_y_index = te.round(in_y).astype("int32") - elif rounding_method == "round_prefer_floor": - closest_x_index = te.ceil(in_x - 0.5).astype("int32") - closest_y_index = te.ceil(in_y - 0.5).astype("int32") - elif rounding_method == "round_prefer_ceil": - closest_x_index = te.floor(in_x + 0.5).astype("int32") - closest_y_index = te.floor(in_y + 0.5).astype("int32") - elif rounding_method == "floor": - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - closest_y_index = te.floor(in_y + epsilon).astype("int32") - closest_x_index = te.floor(in_x + epsilon).astype("int32") - elif rounding_method == "ceil": - # Subract epsilon from ceil to prevent gpu rounding errors. 
- epsilon = 1e-5 - closest_y_index = te.ceil(in_y - epsilon).astype("int32") - closest_x_index = te.ceil(in_x - epsilon).astype("int32") - else: - raise ValueError("Uknown rounding method: {}".format(rounding_method)) + x_lerp = in_x - x_int - value = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - closest_y_index, - closest_x_index, - cc, - inum, - ic, - ) + p = [0 for i in range(2)] + for i in range(2): + p[i] = get_1d_pixel( + data, + layout, + boxes, + image_width, + box_idx, + c, + x_int + i, + cc, + inum, + ic, + ) + + value = _lerp(*p, x_lerp) + + elif method == "cubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [0 for i in range(4)] + for i in range(4): + p[i] = get_1d_pixel( + data, + layout, + boxes, + image_width, + box_idx, + c, + xint + i - 1, + cc, + inum, + ic, + ) + + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wx = sum(wx) + wx = [w / sum_wx for w in wx] + value = _cubic_kernel(p, wx) + + else: + raise ValueError("Unknown resize method:", method) if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) # use extrapolation_value if in_x is out of boundary value = tvm.tir.if_then_else( in_x < 0, extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, value), ) return _cast_output(value, data.dtype, out_dtype=out_dtype) -def resize_bilinear( - indices, +def resize1d( data, - image_height, - image_width, - target_height, - target_width, - boxes=None, - box_indices=None, - extrapolation_value=None, - layout="NCHW", - coordinate_transformation_mode="align_corners", + size, + 
layout="NCW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, + output_shape=None, ): - - """Perform resize operation with bilinear method on the data. - For details about Bilinear interpolation please refer to - https://en.wikipedia.org/wiki/Bilinear_interpolation. + """Perform resize operation on the data. Parameters ---------- - indices : tuple - The indices of input data - data : tvm.te.Tensor - inputs is a 4-D tensor with shape - [batch, channel, in_height, in_width] - or [batch, in_height, in_width, channel] - - image_height : integer - Input image height - - image_width : integer - Input image width - - target_height : integer - The target resized image height - - target_width : integer - The target resized image width - - boxes : tvm.te.Tensor, optional - A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies - the coordinates of a box. + inputs is a 3-D tensor with shape + [batch, channel in_width] + or [batch in_width, channel] - box_indices : tvm.te.Tensor, optional - A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that - the i-th box refers to. - - extrapolation_value: float, optional - Value used for extrapolation, when applicable. + size: Tuple + Output resolution scale to layout: string, optional - "NCHW", "NHWC", or "NCHWc". + "NCW", "NWC", or "NCWc". coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor @@ -311,135 +395,69 @@ def resize_bilinear( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". + method: {"linear", "nearest_neighbor", "cubic"} + Method to be used for resizing. + out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: tvm.tir.container.Array, optional + Shape to return. 
If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- - output : out_dtype - The computed result with type out_dtype + output : tvm.te.Tensor + 4-D with shape [batch, chananel, in_width*scale] + or [batch, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_width*scale, channel-minor] """ - - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout=layout) - box_idx = box_indices(n) if box_indices is not None else n - - if boxes is not None: - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) - - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) - - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x + method = method.lower() + if layout == "NWC": + in_n, in_w, in_c = data.shape + if output_shape is None: + output_shape = [in_n, size[0], in_c] + elif layout == "NCW": + in_n, in_c, in_w = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0]] + elif ncw_pack_layout(layout): # for NCWinic + in_n, in_c, in_w, in_inum, in_ic = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_inum, in_ic] + elif ncw_xc_layout(layout): # for NCWxc + in_n, in_c, in_w, in_cc = data.shape + if output_shape is None: + output_shape = [in_n, in_c, size[0], in_cc] else: - in_y, in_x = get_iny_inx( - y, - x, - image_height, - image_width, - target_height, - target_width, - coordinate_transformation_mode, - ) - - top_y_index = te.floor(in_y).astype("int32") - bottom_y_index = te.ceil(in_y).astype("int32") - y_lerp = in_y - top_y_index - - left_x_index 
= te.floor(in_x).astype("int32") - right_x_index = te.ceil(in_x).astype("int32") - x_lerp = in_x - left_x_index + raise ValueError("%s layout is not supported." % layout) - top_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - left_x_index, - cc, - inum, - ic, - ) - top_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - top_y_index, - right_x_index, - cc, - inum, - ic, - ) - bottom_left = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - left_x_index, - cc, - inum, - ic, - ) - bottom_right = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - bottom_y_index, - right_x_index, - cc, - inum, - ic, - ) + if isinstance(size, tuple): + size = list(size) - top = _lerp(top_left, top_right, x_lerp) - bottom = _lerp(bottom_left, bottom_right, x_lerp) - value = _lerp(top, bottom, y_lerp) + for i in range(1): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) - # use extrapolation_value if in_y/in_x is out of boundary - if extrapolation_value is not None: - out = tvm.tir.if_then_else( - in_y < 0, - extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), - ) - value = tvm.tir.if_then_else( - in_x < 0, - extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + def compute_func(*indices): + return _resize_1d( + indices, + data, + in_w, + size[0], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) + + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) -def resize_bicubic( +def _resize_2d( indices, data, image_height, @@ 
-448,17 +466,17 @@ def resize_bicubic( target_width, boxes=None, box_indices=None, + method=None, extrapolation_value=None, layout="NCHW", coordinate_transformation_mode="align_corners", - out_dtype=None, + rounding_method="", alpha=-0.5, exclude_outside=0, + out_dtype=None, ): - """Perform resize operation with bicubic method on the data. - More details about Bicubic interpolation please refer to - https://en.wikipedia.org/wiki/Bicubic_interpolation. - This algorithm is doing a bicubic spline interpolation + + """Perform resize operation on the data with selected method and options. Parameters ---------- @@ -468,7 +486,7 @@ def resize_bicubic( data : tvm.te.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] - or [:batch, in_height, in_width, channel] + or [batch, in_height, in_width, channel] image_height : integer Input image height @@ -502,12 +520,19 @@ def resize_bicubic( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - out_dtype: string, optional - Type to return. If left None will be same as input type. + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] alpha: float, optional Bicubic spline coefficient + exclude_oiutside: bool, optional: + Exclude values outside the image fdor bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. 
+ Returns ------- output : out_dtype @@ -523,7 +548,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n - if boxes is not None: y1, x1 = boxes(n, 0), boxes(n, 1) y2, x2 = boxes(n, 2), boxes(n, 3) @@ -536,77 +560,118 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): in_y = y1 * (image_height - 1) + h_scale * y in_x = x1 * (image_width - 1) + w_scale * x else: - in_y, in_x = get_iny_inx( - y, - x, + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) + + value = get_2d_pixel( + data, + layout, + boxes, image_height, image_width, - target_height, - target_width, - coordinate_transformation_mode, + box_idx, + c, + closest_y_index, + closest_x_index, + cc, + inum, + ic, ) + elif method == "linear": + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + y_int + j, + x_int + i, + cc, + inum, + ic, + ) - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + top = _lerp(*p[0], x_lerp) + bottom = _lerp(*p[1], x_lerp) + value = _lerp(top, bottom, y_lerp) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + elif method == "cubic": + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) 
- # Get the surrounding values - p = [[0 for i in range(4)] for j in range(4)] - for j in range(4): - for i in range(4): - p[j][i] = get_2d_pixel( - data, - layout, - boxes, - image_height, - image_width, - box_idx, - c, - yint + j - 1, - xint + i - 1, - cc, - inum, - ic, - ) + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) + + # Get the surrounding values + p = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + p[j][i] = get_2d_pixel( + data, + layout, + boxes, + image_height, + image_width, + box_idx, + c, + yint + j - 1, + xint + i - 1, + cc, + inum, + ic, + ) + + wx = _cubic_spline_weights(xfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + if exclude_outside: + for i in range(4): + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + sum_wx = sum(wx) + sum_wy = sum(wy) + wx = [w / sum_wx for w in wx] + wy = [w / sum_wy for w in wy] + col0 = _cubic_kernel(p[0], wx) + col1 = _cubic_kernel(p[1], wx) + col2 = _cubic_kernel(p[2], wx) + col3 = _cubic_kernel(p[3], wx) + value = _cubic_kernel([col0, col1, col2, col3], wy) + + else: + raise ValueError("Unknown resize method:", method) - # Interpolate bicubically - def _cubic_spline_weights(t): - t2 = t * t - t3 = t * t * t - w1 = alpha * (t3 - 2 * t2 + t) - w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 - w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t - w4 = -alpha * t3 + alpha * t2 - return [w1, w2, w3, w4] - - def _cubic_kernel(inputs, w): - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) - - wx = _cubic_spline_weights(xfract) - wy = _cubic_spline_weights(yfract) - if exclude_outside: - for i in range(4): - wx[i] = te.if_then_else(te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i]) - wy[i] = te.if_then_else(te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i]) - sum_wx = sum(wx) - 
sum_wy = sum(wy) - wx = [w / sum_wx for w in wx] - wy = [w / sum_wy for w in wy] - col0 = _cubic_kernel(p[0], wx) - col1 = _cubic_kernel(p[1], wx) - col2 = _cubic_kernel(p[2], wx) - col3 = _cubic_kernel(p[3], wx) - value = _cubic_kernel([col0, col1, col2, col3], wy) - - # use extrapolation_value if in_y/in_x is out of boundary if extrapolation_value is not None: out = tvm.tir.if_then_else( in_y < 0, extrapolation_value, tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), ) + # use extrapolation_value if in_x is out of boundary value = tvm.tir.if_then_else( in_x < 0, extrapolation_value, @@ -615,11 +680,11 @@ def _cubic_kernel(inputs, w): return _cast_output(value, data.dtype, out_dtype=out_dtype) -def resize( +def resize2d( data, size, layout="NCHW", - method="bilinear", + method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", bicubic_alpha=-0.5, @@ -648,7 +713,7 @@ def resize( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"bilinear", "nearest_neighbor", "bicubic"} + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. 
out_dtype: string, optional @@ -692,58 +757,23 @@ def resize( if isinstance(size[i], int): size[i] = tvm.tir.IntImm("int32", size[i]) - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( + def compute_func(*indices): + return _resize_2d( indices, data, in_h, in_w, size[0], size[1], + method=method, layout=layout, coordinate_transformation_mode=coordinate_transformation_mode, rounding_method=rounding_method, - out_dtype=out_dtype, - ) - - def _bilinear(*indices): - return resize_bilinear( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, - ) - - def _bicubic(*indices): - return resize_bicubic( - indices, - data, - in_h, - in_w, - size[0], - size[1], - layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, - out_dtype=out_dtype, alpha=bicubic_alpha, exclude_outside=bicubic_exclude, + out_dtype=out_dtype, ) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear - elif method == "bicubic": - compute_func = _bicubic - else: - raise ValueError("%s method is not supported." % method) - return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) @@ -818,9 +848,11 @@ def crop_and_resize( image_w = data.shape[3].astype("int32") else: raise ValueError("%s layout is not supported." 
% layout) + if method == "bilinear": + method = "linear" - def _bilinear(*indices): - return resize_bilinear( + def compute_func(*indices): + return _resize_2d( indices, data, image_h, @@ -829,50 +861,280 @@ def _bilinear(*indices): target_w, boxes, box_indices, - extrapolation_value, - layout, + method=method, + extrapolation_value=extrapolation_value, + layout=layout, out_dtype=out_dtype, ) - def _nearest_neighbor(*indices): - return resize_nearest_neighbor( - indices, + return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) + + +def _resize_3d( + indices, + data, + image_depth, + image_height, + image_width, + target_depth, + target_height, + target_width, + boxes=None, + box_indices=None, + method=None, + extrapolation_value=None, + layout="NCHW", + coordinate_transformation_mode="align_corners", + rounding_method="", + alpha=-0.5, + exclude_outside=0, + out_dtype=None, +): + + """Perform resize operation on the data with selected method and options. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.te.Tensor + inputs is a 5-D tensor with shape + [batch, channel, in_depth, in_height, in_width] + or [batch, in_depth, in_height, in_width, channel] + + image_depth : integer + Input image depth + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_depth : integer + The target resized image depth + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.te.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.te.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCDHW", "NDHWC", or "NCDHWc". 
+ + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + rounding_method: string, optional + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + alpha: float, optional + Bicubic spline coefficient + + exclude_outside: bool, optional + Exclude values outside the image for bicubic interpolation + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + n, c, z, y, x, cc = get_3d_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + # TODO(mbrookhart): Find an example of this + raise NotImplementedError("resize3d with image boxes not yet implemented") + in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode) + in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode) + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + + if method == "nearest_neighbor": + if rounding_method == "": + if coordinate_transformation_mode == "align_corners": + rounding_method = "round" + else: + rounding_method = "floor" + + closest_z_index = get_closest_index(in_z, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + + value = get_3d_pixel( data, - image_h, - image_w, - target_h, - target_w, - boxes, - box_indices, - extrapolation_value, layout, - out_dtype=out_dtype, + image_depth, + 
image_height, + image_width, + box_idx, + c, + closest_z_index, + closest_y_index, + closest_x_index, + cc, ) + elif method == "linear": + z_int = te.floor(in_z).astype("int32") + y_int = te.floor(in_y).astype("int32") + x_int = te.floor(in_x).astype("int32") + + z_lerp = in_z - z_int + y_lerp = in_y - y_int + x_lerp = in_x - x_int + + p = [[[0 for i in range(2)] for j in range(2)] for k in range(2)] + for k in range(2): + for j in range(2): + for i in range(2): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + z_int + k, + y_int + j, + x_int + i, + cc, + ) + l = [[0 for i in range(2)] for j in range(2)] + for j in range(2): + for i in range(2): + l[j][i] = _lerp(*p[j][i], x_lerp) + + top = _lerp(*l[0], y_lerp) + bottom = _lerp(*l[1], y_lerp) + value = _lerp(top, bottom, z_lerp) + + elif method == "cubic": + zint = te.floor(in_z).astype("int32") + zfract = in_z - te.floor(in_z) + + yint = te.floor(in_y).astype("int32") + yfract = in_y - te.floor(in_y) + + xint = te.floor(in_x).astype("int32") + xfract = in_x - te.floor(in_x) + + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for k in range(4): + for j in range(4): + for i in range(4): + p[k][j][i] = get_3d_pixel( + data, + layout, + image_depth, + image_height, + image_width, + box_idx, + c, + zint + k - 1, + yint + j - 1, + xint + i - 1, + cc, + ) + + wz = _cubic_spline_weights(zfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + wx = _cubic_spline_weights(xfract, alpha) + if exclude_outside: + for i in range(4): + wz[i] = te.if_then_else( + te.any(zint - 1 + i < 0, zint + i > image_depth), 0.0, wz[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wz = sum(wz) + sum_wy = sum(wy) + sum_wx = sum(wx) + wz = [w / sum_wz for w in wz] + wy = [w / 
sum_wy for w in wy] + wx = [w / sum_wx for w in wx] + + l = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in range(4): + l[j][i] = _cubic_kernel(p[j][i], wx) + col0 = _cubic_kernel(l[0], wy) + col1 = _cubic_kernel(l[1], wy) + col2 = _cubic_kernel(l[2], wy) + col3 = _cubic_kernel(l[3], wy) + value = _cubic_kernel([col0, col1, col2, col3], wz) - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "bilinear": - compute_func = _bilinear else: - raise ValueError("%s method is not supported." % method) + raise ValueError("Unknown resize method:", method) - return te.compute(output_shape, compute_func, name="crop_and_resize", tag=tag.INJECTIVE) + if extrapolation_value is not None: + out = tvm.tir.if_then_else( + in_z < 0, + extrapolation_value, + tvm.tir.if_then_else(in_z > image_depth - 1, extrapolation_value, value), + ) + out = tvm.tir.if_then_else( + in_y < 0, + extrapolation_value, + tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, out), + ) + # use extrapolation_value if in_x is out of boundary + value = tvm.tir.if_then_else( + in_x < 0, + extrapolation_value, + tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + ) + return _cast_output(value, data.dtype, out_dtype=out_dtype) def resize3d( data, size, layout="NCDHW", - method="nearest_neighbor", - coordinate_transformation_mode="align_corners", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="", + bicubic_alpha=-0.5, + bicubic_exclude=0, out_dtype=None, + output_shape=None, ): """Perform resize operation on the data. 
Parameters ---------- - inputs: tvm.te.Tensor + data : tvm.te.Tensor inputs is a 5-D tensor with shape [batch, channel, in_depth, in_height, in_width] or [batch, in_depth, in_height, in_width, channel] @@ -887,24 +1149,28 @@ def resize3d( Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. Refer to the ONNX Resize operator specification for details. - Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"trilinear", "nearest_neighbor"} + + method: {"linear", "nearest_neighbor", "cubic"} Method to be used for resizing. out_dtype: string, optional Type to return. If left None will be same as input type. + output_shape: tvm.tir.container.Array, optional + Shape to return. If left None will be inferred + (If shape is determined dynamically, pass out_dtype.shape as output_shape) + Returns ------- output : tvm.te.Tensor - 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] + 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, - channel-minor] + or 6-D with shape + [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, channel-minor] """ - method = method.lower() + method = method.lower() if layout == "NDHWC": in_n, in_d, in_h, in_w, in_c = data.shape output_shape = [in_n, size[0], size[1], size[2], in_c] @@ -916,125 +1182,30 @@ def resize3d( in_n, in_c, in_d, in_h, in_w, in_cc = data.shape output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc] - if coordinate_transformation_mode == "align_corners": - z_ratio = (in_d - 1).astype("float") / (size[0] - 1) - y_ratio = (in_h - 1).astype("float") / (size[1] - 1) - x_ratio = (in_w - 1).astype("float") / (size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - z_ratio = 
(in_d).astype("float") / (size[0]) - y_ratio = (in_h).astype("float") / (size[1]) - x_ratio = (in_w).astype("float") / (size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _get_pixel(n, c, z, y, x, cc): - z = tvm.te.max(tvm.te.min(z, in_d - 1), 0) - y = tvm.te.max(tvm.te.min(y, in_h - 1), 0) - x = tvm.te.max(tvm.te.min(x, in_w - 1), 0) - if layout == "NDHWC": - return data(n, z, y, x, c).astype("float") - if layout == "NCDHW": - return data(n, c, z, y, x).astype("float") - # else must be NCDHWxc - return data(n, c, z, y, x, cc).astype("float") - - def _get_indices(*indices): - if layout == "NDHWC": - n, z, y, x, c = indices - cc = None - elif layout == "NCDHW": - n, c, z, y, x = indices - cc = None - else: - n, c, z, y, x, cc = indices - - return n, c, z, y, x, cc - - def _cast_output(value): - if out_dtype: - dtype = out_dtype - else: - dtype = data.dtype - return value.astype(dtype) - - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - if coordinate_transformation_mode == "align_corners": - zint = te.round(in_z).astype("int32") - yint = te.round(in_y).astype("int32") - xint = te.round(in_x).astype("int32") - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - zint = te.floor(in_z + epsilon).astype("int32") - yint = te.floor(in_y + epsilon).astype("int32") - xint = te.floor(in_x + epsilon).astype("int32") - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode - ) - ) - - return _cast_output(_get_pixel(n, c, zint, yint, xint, cc)) - - # Trilinear helper functions and computation. 
- def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _trilinear(*indices): - n, c, z, y, x, cc = _get_indices(*indices) - - if coordinate_transformation_mode == "half_pixel": - in_z = z_ratio * (z + 0.5) - 0.5 - in_y = y_ratio * (y + 0.5) - 0.5 - in_x = x_ratio * (x + 0.5) - 0.5 - else: - in_z = z_ratio * z - in_y = y_ratio * y - in_x = x_ratio * x - - zint = te.floor(in_z).astype("int32") - zfract = in_z - te.floor(in_z) - - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + if isinstance(size, tuple): + size = list(size) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + for i in range(3): + if isinstance(size[i], int): + size[i] = tvm.tir.IntImm("int32", size[i]) - p000 = _get_pixel(n, c, zint, yint, xint, cc) - p001 = _get_pixel(n, c, zint, yint, xint + 1, cc) - p010 = _get_pixel(n, c, zint, yint + 1, xint, cc) - p011 = _get_pixel(n, c, zint, yint + 1, xint + 1, cc) - p100 = _get_pixel(n, c, zint + 1, yint, xint, cc) - p101 = _get_pixel(n, c, zint + 1, yint, xint + 1, cc) - p110 = _get_pixel(n, c, zint + 1, yint + 1, xint, cc) - p111 = _get_pixel(n, c, zint + 1, yint + 1, xint + 1, cc) - - dep00 = _lerp(p000, p100, zfract) - dep01 = _lerp(p001, p101, zfract) - dep10 = _lerp(p010, p110, zfract) - dep11 = _lerp(p011, p111, zfract) - col0 = _lerp(dep00, dep01, xfract) - col1 = _lerp(dep10, dep11, xfract) - value = _lerp(col0, col1, yfract) - return _cast_output(value) - - # Determine which interpolation method to use then run it. - if method == "nearest_neighbor": - compute_func = _nearest_neighbor - elif method == "trilinear": - compute_func = _trilinear - else: - raise ValueError("%s method is not supported." 
% method) + def compute_func(*indices): + return _resize_3d( + indices, + data, + in_d, + in_h, + in_w, + size[0], + size[1], + size[2], + method=method, + layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + alpha=bicubic_alpha, + exclude_outside=bicubic_exclude, + out_dtype=out_dtype, + ) - return te.compute(output_shape, compute_func, name="resize3d", tag=tag.INJECTIVE) + return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE) diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index b95835f6e103..36b9349a139d 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -92,7 +92,9 @@ def upsampling( else: raise ValueError("not support this layout {} yet".format(layout)) coord_trans = "align_corners" if align_corners else "asymmetric" - return topi.image.resize( + if method[0:2] == "bi": + method = method[2:] + return topi.image.resize2d( data, reshape_size, layout=layout, @@ -188,6 +190,8 @@ def upsampling3d( ) else: raise ValueError("not support this layout {} yet".format(layout)) + if method[0:3] == "tri": + method = method[3:] return topi.image.resize3d( data, resize_shape, diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index afb251417315..e23ecfa8fc69 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -35,9 +35,7 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python -from .upsampling_python import upsampling_python, upsampling3d_python -from .bilinear_resize_python import bilinear_resize_python -from .trilinear_resize3d_python import trilinear_resize3d_python +from .resize_python import resize1d_python, resize2d_python, resize3d_python from .reorg_python import reorg_python from 
.roi_align_python import roi_align_nchw_python, roi_align_nhwc_python from .roi_pool_python import roi_pool_nchw_python diff --git a/python/tvm/topi/testing/bilinear_resize_python.py b/python/tvm/topi/testing/bilinear_resize_python.py deleted file mode 100644 index b1fb8b0b4845..000000000000 --- a/python/tvm/topi/testing/bilinear_resize_python.py +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Bilinear Scale in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): - """Bilinear scaling using python""" - (new_h, new_w) = out_size - (ib, ic) = (1, 1) - - if layout == "NHWC": - (batch, h, w, channel) = image.shape - scaled_image = np.ones((batch, new_h, new_w, channel)) - # NCHWinic - elif nchw_pack_layout(layout): - (batch, channel, h, w, ib, ic) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w, ib, ic)) - else: - (batch, channel, h, w) = image.shape - scaled_image = np.ones((batch, channel, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - height_scale = np.float32(h - 1) / np.float32(out_size[0] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[1] - 1) - else: - height_scale = np.float32(h) / np.float32(out_size[0]) - width_scale = np.float32(w) / np.float32(out_size[1]) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _img_scale(b, m, i, n): - for j in range(new_h): - for k in range(new_w): - if coordinate_transformation_mode == "half_pixel": - in_y = (j + 0.5) * height_scale - 0.5 - else: - in_y = j * height_scale - y0 = int(math.floor(in_y)) - y1 = max(min(y0 + 1, h - 1), 0) - y0 = max(y0, 0) - y_lerp = in_y - math.floor(in_y) - - if coordinate_transformation_mode == "half_pixel": - in_x = (k + 0.5) * width_scale - 0.5 - else: - in_x = k * width_scale - x0 = int(math.floor(in_x)) - x1 = max(min(x0 + 1, w - 1), 0) - x0 = max(x0, 0) - x_lerp = in_x - math.floor(in_x) - - if layout == "NHWC": - A = image[b][y0][x0][i] - B = image[b][y0][x1][i] - C = image[b][y1][x0][i] - D = image[b][y1][x1][i] - elif nchw_pack_layout(layout): - A = image[b][i][y0][x0][m][n] - B = image[b][i][y0][x1][m][n] - C = image[b][i][y1][x0][m][n] - D = image[b][i][y1][x1][m][n] - else: - A 
= image[b][i][y0][x0] - B = image[b][i][y0][x1] - C = image[b][i][y1][x0] - D = image[b][i][y1][x1] - - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NHWC": - scaled_image[b][j][k][i] = pixel - elif nchw_pack_layout(layout): - scaled_image[b][i][j][k][m][n] = pixel - else: - scaled_image[b][i][j][k] = pixel - - for b in range(batch): - for m in range(ib): - for i in range(channel): - for n in range(ic): - _img_scale(b, m, i, n) - - return scaled_image diff --git a/python/tvm/topi/testing/resize_python.py b/python/tvm/topi/testing/resize_python.py new file mode 100644 index 000000000000..e8d5c0599887 --- /dev/null +++ b/python/tvm/topi/testing/resize_python.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Upsampling in python""" +import math +import numpy as np +from tvm.topi.utils import nchw_pack_layout + + +def get_inx(x, image_width, target_width, coordinate_transformation_mode): + """Infer input x from output x with various coordinate transformation methods""" + scale = image_width / target_width + if coordinate_transformation_mode == "half_pixel": + in_x = (x + 0.5) * scale - 0.5 + elif coordinate_transformation_mode == "align_corners": + in_x = (image_width - 1) / (target_width - 1) * x if target_width > 1 else 0 + elif coordinate_transformation_mode == "asymmetric": + in_x = scale * x + else: + raise ValueError( + "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) + ) + return in_x + + +def get_index(x, image_width, target_width, coordinate_transformation_mode): + """get and round the nearest index for nearest_neighbor""" + in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode) + if coordinate_transformation_mode == "align_corners": + # round prefer ceil + out = int(math.floor(in_x + 0.5)) + else: + out = int(math.floor(in_x)) + out = max(min(out, image_width - 1), 0) + return out + + +def resize3d_nearest(arr, scale, coordinate_transformation_mode): + """Populate the array by scale factor""" + d, h, w = arr.shape + out_d, out_h, out_w = [int(round(i * s)) for i, s in zip(arr.shape, scale)] + out = np.empty((out_d, out_h, out_w)) + for z in range(out_d): + for y in range(out_h): + for x in range(out_w): + in_z = get_index(z, d, out_d, coordinate_transformation_mode) + in_y = get_index(y, h, out_h, coordinate_transformation_mode) + in_x = get_index(x, w, out_w, coordinate_transformation_mode) + out[z, y, x] = arr[in_z, in_y, in_x] + return out + + +def resize3d_linear(data_in, scale, coordinate_transformation_mode): + """Trilinear 3d scaling using python""" + d, h, w = data_in.shape + new_d, new_h, new_w = 
[int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _in_coord(new_coord, in_shape, out_shape): + in_coord = get_inx(new_coord, in_shape, out_shape, coordinate_transformation_mode) + coord0 = int(math.floor(in_coord)) + coord1 = max(min(coord0 + 1, in_shape - 1), 0) + coord0 = max(coord0, 0) + coord_lerp = in_coord - math.floor(in_coord) + return coord0, coord1, coord_lerp + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + z0, z1, z_lerp = _in_coord(m, d, new_d) + y0, y1, y_lerp = _in_coord(j, h, new_h) + x0, x1, x_lerp = _in_coord(k, w, new_w) + + A0 = data_in[z0][y0][x0] + B0 = data_in[z0][y0][x1] + C0 = data_in[z0][y1][x0] + D0 = data_in[z0][y1][x1] + A1 = data_in[z1][y0][x0] + B1 = data_in[z1][y0][x1] + C1 = data_in[z1][y1][x0] + D1 = data_in[z1][y1][x1] + + A = _lerp(A0, A1, z_lerp) + B = _lerp(B0, B1, z_lerp) + C = _lerp(C0, C1, z_lerp) + D = _lerp(D0, D1, z_lerp) + top = _lerp(A, B, x_lerp) + bottom = _lerp(C, D, x_lerp) + + data_out[m][j][k] = np.float32(_lerp(top, bottom, y_lerp)) + + return data_out + + +def resize3d_cubic(data_in, scale, coordinate_transformation_mode): + """Tricubic 3d scaling using python""" + d, h, w = data_in.shape + new_d, new_h, new_w = [int(round(i * s)) for i, s in zip(data_in.shape, scale)] + data_out = np.ones((new_d, new_h, new_w)) + + def _cubic_spline_weights(t, alpha=-0.5): + """create cubic spline weights in 1D""" + t2 = t * t + t3 = t * t * t + w1 = alpha * (t3 - 2 * t2 + t) + w2 = (alpha + 2) * t3 - (3 + alpha) * t2 + 1 + w3 = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t + w4 = -alpha * t3 + alpha * t2 + return [w1, w2, w3, w4] + + def _cubic_kernel(inputs, w): + """perform cubic interpolation in 1D""" + return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + + def _get_input_value(z, y, x): + z = max(min(z, d - 1), 0) + y = max(min(y, h - 1), 0) + x = max(min(x, w - 1), 0) 
+ return data_in[z][y][x] + + def _get_patch(zint, yint, xint): + # Get the surrounding values + p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] + for kk in range(4): + for jj in range(4): + for ii in range(4): + p[kk][jj][ii] = _get_input_value( + zint + kk - 1, + yint + jj - 1, + xint + ii - 1, + ) + return p + + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + in_z = get_inx(m, d, new_d, coordinate_transformation_mode) + in_y = get_inx(j, h, new_h, coordinate_transformation_mode) + in_x = get_inx(k, w, new_w, coordinate_transformation_mode) + zint = math.floor(in_z) + zfract = in_z - math.floor(in_z) + + yint = math.floor(in_y) + yfract = in_y - math.floor(in_y) + + xint = math.floor(in_x) + xfract = in_x - math.floor(in_x) + + wz = _cubic_spline_weights(zfract) + wy = _cubic_spline_weights(yfract) + wx = _cubic_spline_weights(xfract) + + p = _get_patch(zint, yint, xint) + + l = [[0 for i in range(4)] for j in range(4)] + for jj in range(4): + for ii in range(4): + l[jj][ii] = _cubic_kernel(p[jj][ii], wx) + + col0 = _cubic_kernel(l[0], wy) + col1 = _cubic_kernel(l[1], wy) + col2 = _cubic_kernel(l[2], wy) + col3 = _cubic_kernel(l[3], wy) + data_out[m][j][k] = _cubic_kernel([col0, col1, col2, col3], wz) + + return data_out + + +def resize3d_ncdhw( + data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners" +): + """reference kernel for 3D image resizing""" + ishape = data.shape + + oshape = ( + ishape[0], + ishape[1], + int(round(ishape[2] * scale[0])), + int(round(ishape[3] * scale[1])), + int(round(ishape[4] * scale[2])), + ) + + output_np = np.zeros(oshape, dtype=data.dtype) + + for b in range(oshape[0]): + for c in range(oshape[1]): + if method == "nearest_neighbor": + output_np[b, c, :, :, :] = resize3d_nearest( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + elif method == "linear": + output_np[b, c, :, :, :] = resize3d_linear( + data[b, c, :, :, :], scale, 
coordinate_transformation_mode + ) + elif method == "cubic": + output_np[b, c, :, :, :] = resize3d_cubic( + data[b, c, :, :, :], scale, coordinate_transformation_mode + ) + else: + raise ValueError("Unknown resize method", method) + + return output_np + + +def resize1d_python( + data, + scale, + layout="NCW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python reference for 1D resize, implemented via the 3D reference kernel""" + + if layout == "NWC": + data = data.transpose([0, 2, 1]) + + data = np.expand_dims(data, axis=[2, 3]) + output_np = resize3d_ncdhw(data, (1, 1) + tuple(scale), method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NWC": + output_np = output_np.transpose([0, 2, 1]) + + return output_np + + +def resize2d_python( + data, + scale, + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python reference for 2D resize, implemented via the 3D reference kernel""" + + if layout == "NHWC": + data = data.transpose([0, 3, 1, 2]) + elif nchw_pack_layout(layout): + ishape = data.shape + transposed = data.transpose([0, 4, 1, 5, 2, 3]) + tshape = transposed.shape + data = transposed.reshape( + tshape[0] * tshape[1], tshape[2] * tshape[3], tshape[4], tshape[5] + ) + + data = np.expand_dims(data, axis=2) + output_np = resize3d_ncdhw(data, (1,) + tuple(scale), method, coordinate_transformation_mode) + output_np = np.squeeze(output_np, axis=2) + + if layout == "NHWC": + output_np = output_np.transpose([0, 2, 3, 1]) + elif nchw_pack_layout(layout): + output_np = output_np.reshape(tshape[0:4] + output_np.shape[2:]) + output_np = output_np.transpose([0, 2, 4, 5, 1, 3]) + + return output_np + + +def resize3d_python( + data, + scale, + layout="NCDHW", + method="nearest_neighbor", + coordinate_transformation_mode="align_corners", +): + """Python reference for 3D resize (nearest_neighbor, linear or cubic)""" + + if layout == "NDHWC": + data = 
data.transpose([0, 4, 1, 2, 3]) + + output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode) + + if layout == "NDHWC": + output_np = output_np.transpose([0, 2, 3, 4, 1]) + + return output_np diff --git a/python/tvm/topi/testing/trilinear_resize3d_python.py b/python/tvm/topi/testing/trilinear_resize3d_python.py deleted file mode 100644 index d603e987d5ef..000000000000 --- a/python/tvm/topi/testing/trilinear_resize3d_python.py +++ /dev/null @@ -1,111 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks -"""Trilinear 3D resize in python""" -import math -import numpy as np - - -def trilinear_resize3d_python( - data_in, out_size, layout, coordinate_transformation_mode="align_corners" -): - """Trilinear 3d scaling using python""" - (new_d, new_h, new_w) = out_size - - if layout == "NDHWC": - (batch, d, h, w, channel) = data_in.shape - data_out = np.ones((batch, new_d, new_h, new_w, channel)) - else: - (batch, channel, d, h, w) = data_in.shape - data_out = np.ones((batch, channel, new_d, new_h, new_w)) - - if coordinate_transformation_mode == "align_corners": - depth_scale = np.float32(d - 1) / np.float32(out_size[0] - 1) - height_scale = np.float32(h - 1) / np.float32(out_size[1] - 1) - width_scale = np.float32(w - 1) / np.float32(out_size[2] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - depth_scale = np.float32(d) / np.float32(out_size[0]) - height_scale = np.float32(h) / np.float32(out_size[1]) - width_scale = np.float32(w) / np.float32(out_size[2]) - else: - raise ValueError( - "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode) - ) - - def _lerp(A, B, t): - return A * (1.0 - t) + B * t - - def _in_coord(new_coord, scale, shape, mode): - if mode == "half_pixel": - in_coord = (new_coord + 0.5) * scale - 0.5 - else: - in_coord = new_coord * scale - coord0 = int(math.floor(in_coord)) - coord1 = max(min(coord0 + 1, shape - 1), 0) - coord0 = max(coord0, 0) - coord_lerp = in_coord - math.floor(in_coord) - return coord0, coord1, coord_lerp - - for b in range(batch): - for i in range(channel): - for m in range(new_d): - for j in range(new_h): - for k in range(new_w): - z0, z1, z_lerp = _in_coord( - m, depth_scale, d, coordinate_transformation_mode - ) - y0, y1, y_lerp = _in_coord( - j, height_scale, h, coordinate_transformation_mode - ) - x0, x1, x_lerp = _in_coord( - k, width_scale, w, 
coordinate_transformation_mode - ) - - if layout == "NDHWC": - A0 = data_in[b][z0][y0][x0][i] - B0 = data_in[b][z0][y0][x1][i] - C0 = data_in[b][z0][y1][x0][i] - D0 = data_in[b][z0][y1][x1][i] - A1 = data_in[b][z1][y0][x0][i] - B1 = data_in[b][z1][y0][x1][i] - C1 = data_in[b][z1][y1][x0][i] - D1 = data_in[b][z1][y1][x1][i] - else: - A0 = data_in[b][i][z0][y0][x0] - B0 = data_in[b][i][z0][y0][x1] - C0 = data_in[b][i][z0][y1][x0] - D0 = data_in[b][i][z0][y1][x1] - A1 = data_in[b][i][z1][y0][x0] - B1 = data_in[b][i][z1][y0][x1] - C1 = data_in[b][i][z1][y1][x0] - D1 = data_in[b][i][z1][y1][x1] - - A = _lerp(A0, A1, z_lerp) - B = _lerp(B0, B1, z_lerp) - C = _lerp(C0, C1, z_lerp) - D = _lerp(D0, D1, z_lerp) - top = _lerp(A, B, x_lerp) - bottom = _lerp(C, D, x_lerp) - - pixel = np.float32(_lerp(top, bottom, y_lerp)) - - if layout == "NDHWC": - data_out[b][m][j][k][i] = pixel - else: - data_out[b][i][m][j][k] = pixel - - return data_out diff --git a/python/tvm/topi/testing/upsampling_python.py b/python/tvm/topi/testing/upsampling_python.py deleted file mode 100644 index dd187c4d8cff..000000000000 --- a/python/tvm/topi/testing/upsampling_python.py +++ /dev/null @@ -1,136 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals -"""Upsampling in python""" -import math -import numpy as np -from tvm.topi.utils import nchw_pack_layout - - -def upsample_nearest(arr, scale): - """Populate the array by scale factor""" - h, w = arr.shape - out_h = int(round(h * scale[0])) - out_w = int(round(w * scale[1])) - out = np.empty((out_h, out_w)) - for y in range(out_h): - for x in range(out_w): - in_y = math.floor(y / scale[0]) - in_x = math.floor(x / scale[1]) - out[y, x] = arr[in_y, in_x] - return out - - -def upsampling_python(data, scale, layout="NCHW"): - """Python version of scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) - return output_np - # NCHWinic - if nchw_pack_layout(layout): - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - ishape[4], - ishape[5], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for ib in range(oshape[4]): - for c in range(oshape[1]): - for ic in range(oshape[5]): - output_np[b, c, :, :, ib, ic] = upsample_nearest( - data[b, c, :, :, ib, ic], scale - ) - return output_np - - if layout == "NHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - ishape[3], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[3]): - output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) - - -def upsample3d_nearest(arr, scale): - """Populate the array by scale factor""" - d, h, w = arr.shape - out_d 
= int(round(d * scale[0])) - out_h = int(round(h * scale[1])) - out_w = int(round(w * scale[2])) - out = np.empty((out_d, out_h, out_w)) - for z in range(out_d): - for y in range(out_h): - for x in range(out_w): - in_z = math.floor(z / scale[0]) - in_y = math.floor(y / scale[1]) - in_x = math.floor(x / scale[2]) - out[z, y, x] = arr[in_z, in_y, in_x] - return out - - -def upsampling3d_python(data, scale, layout="NCDHW"): - """Python version of 3D scaling using nearest neighbour""" - - ishape = data.shape - if layout == "NCDHW": - oshape = ( - ishape[0], - ishape[1], - int(round(ishape[2] * scale[0])), - int(round(ishape[3] * scale[1])), - int(round(ishape[4] * scale[2])), - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[1]): - output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale) - return output_np - if layout == "NDHWC": - oshape = ( - ishape[0], - int(round(ishape[1] * scale[0])), - int(round(ishape[2] * scale[1])), - int(round(ishape[3] * scale[2])), - ishape[4], - ) - output_np = np.zeros(oshape, dtype=data.dtype) - for b in range(oshape[0]): - for c in range(oshape[4]): - output_np[b, :, :, :, c] = upsample3d_nearest(data[b, :, :, :, c], scale) - return output_np - raise ValueError("not support this layout {} yet".format(layout)) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 3a056cfb4326..be3df2be5f6a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -31,6 +31,16 @@ class InvalidShapeError(ValueError): """Invalid shape for a topi function. i.e. 
call winograd template for non-3x3 kernel)""" +def ncw_pack_layout(layout_info): + """Check whether the layout type is NCWinic""" + return layout_info[:3] == "NCW" and "c" in layout_info and "n" in layout_info + + +def ncw_xc_layout(layout_info): + """Check whether the layout type is NCWxc""" + return layout_info[:3] == "NCW" and "c" in layout_info and layout_info[3:-1].isnumeric() + + def nchw_pack_layout(layout_info): """Check whether the layout type is NCHWinic""" return layout_info[:4] == "NCHW" and "c" in layout_info and "n" in layout_info diff --git a/src/relay/op/dyn/image/resize.cc b/src/relay/op/dyn/image/resize.cc index 87cf89a223ec..002105f4d565 100644 --- a/src/relay/op/dyn/image/resize.cc +++ b/src/relay/op/dyn/image/resize.cc @@ -31,10 +31,10 @@ namespace tvm { namespace relay { namespace dyn { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // {data, size, out} ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -42,7 +42,7 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -66,24 +66,24 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. 
-Expr MakeResize(Expr data, Expr size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - double bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Expr size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + double cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("dyn.image.resize"); + static const Op& op = Op::Get("dyn.image.resize2d"); return Call(op, {data, size}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.dyn.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("dyn.image.resize") +RELAY_REGISTER_OP("dyn.image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. 
- **data**: data is 4D array of shape @@ -100,12 +100,12 @@ RELAY_REGISTER_OP("dyn.image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(2) .add_argument("data", "Tensor", "The input tensor.") .add_argument("size", "Tensor", "The output size tensor.") .set_support_level(5) - .add_type_rel("DynResize", ResizeRel) + .add_type_rel("DynResize2D", Resize2DRel) .set_attr("TOpPattern", kInjective); } // namespace dyn diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b672c7f87c05..ee779841505c 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -31,8 +31,6 @@ namespace tvm { namespace relay { -TVM_REGISTER_NODE_TYPE(ResizeAttrs); - template InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, @@ -58,15 +56,90 @@ InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params)); } -bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { +TVM_REGISTER_NODE_TYPE(Resize1DAttrs); + +bool Resize1DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCW("NCW"); + + const Resize1DAttrs* param = attrs.as(); + ICHECK(param != nullptr); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCW); + ICHECK(layout_converter.defined()) + << "Resize only support input layouts that are convertible from NCW." 
+ << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, param->size[0]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // assign output type + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + return true; +} + +// Positional relay function to create image operator +// used by frontend FFI. +Expr MakeResize1D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); + attrs->size = std::move(size); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("image.resize1d"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.resize1d").set_body_typed(MakeResize1D); + +RELAY_REGISTER_OP("image.resize1d") + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. 
+ +- **data**: data is 3D array of shape + (batch_size, channels, in_width) for NCW + (batch_size, in_width, channels) for NWC + +- **out**: Output is 3D array of shape + for layout NCW + (batch_size, channels, size[0]) + + for layout NWC + (batch_size, size[0], channels) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize1D", Resize1DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .set_attr("TOpPattern", kInjective); + +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +bool Resize2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; static const Layout kNCHW("NCHW"); - const ResizeAttrs* param = attrs.as(); + const Resize2DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); @@ -90,25 +163,25 @@ bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. 
-Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double bicubic_alpha, - int bicubic_exclude, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; attrs->rounding_method = rounding_method; - attrs->bicubic_alpha = bicubic_alpha; - attrs->bicubic_exclude = bicubic_exclude; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; - static const Op& op = Op::Get("image.resize"); + static const Op& op = Op::Get("image.resize2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize2d").set_body_typed(MakeResize2D); -RELAY_REGISTER_OP("image.resize") +RELAY_REGISTER_OP("image.resize2d") .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. 
- **data**: data is 4D array of shape @@ -122,17 +195,17 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize", ResizeRel) - .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) + .add_type_rel("Resize2D", Resize2DRel) + .set_attr("FInferCorrectLayout", ResizeInferCorrectLayout) .set_attr("TOpPattern", kInjective); -TVM_REGISTER_NODE_TYPE(Resize3dAttrs); +TVM_REGISTER_NODE_TYPE(Resize3DAttrs); -bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, +bool Resize3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -140,7 +213,7 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, static const Layout kNCDHW("NCDHW"); - const Resize3dAttrs* param = attrs.as(); + const Resize3DAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); @@ -165,19 +238,23 @@ bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create image operator // used by frontend FFI. 
-Expr MakeResize3d(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, DataType out_dtype) { - auto attrs = make_object(); +Expr MakeResize3D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype) { + auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); attrs->method = std::move(method); attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->rounding_method = rounding_method; + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; attrs->out_dtype = out_dtype; static const Op& op = Op::Get("image.resize3d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3d); +TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3D); RELAY_REGISTER_OP("image.resize3d") .describe(R"code( @@ -194,11 +271,11 @@ Perform resize3d to input array with nearest neighbour or bilinear interpolation for layout NDHWC (batch_size, size[0], size[1], size[2], channels) )code" TVM_ADD_FILELINE) - .set_attrs_type() + .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) - .add_type_rel("Resize3d", Resize3dRel) + .add_type_rel("Resize3d", Resize3DRel) .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 6f4db5ab268a..1a47193bb91a 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -101,9 +101,9 @@ Expr MakeZeros(Array shape, DataType dtype); Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype); -Expr MakeResize(Expr data, Array size, String layout, String method, - String coordinate_transformation_mode, String rounding_method, double 
bicubic_alpha, - int bicubic_exclude, DataType out_dtype); +Expr MakeResize2D(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, DataType out_dtype); Expr MakeSparseToDense(Expr indices, Array output_shape, Expr values, Expr default_value); diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 7c947ba109bf..318022fb86f5 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -106,20 +106,20 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, - {Op::Get("dyn.image.resize"), + {Op::Get("dyn.image.resize2d"), [this](const CallNode* call_node) { auto args = PrepareArgs(call_node); if (const ConstantNode* size = args[1].as()) { - const ResizeAttrs* param = call_node->attrs.as(); + const Resize2DAttrs* param = call_node->attrs.as(); ICHECK(param); auto size_int = ToVector(size->data); Array size_prim; for (size_t i = 0; i < size_int.size(); ++i) { size_prim.push_back(size_int[i]); } - return MakeResize(call_node->args[0], size_prim, param->layout, param->method, - param->coordinate_transformation_mode, param->rounding_method, - param->bicubic_alpha, param->bicubic_exclude, param->out_dtype); + return MakeResize2D(call_node->args[0], size_prim, param->layout, param->method, + param->coordinate_transformation_mode, param->rounding_method, + param->cubic_alpha, param->cubic_exclude, param->out_dtype); } return Expr(nullptr); }}, diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 72dac9b2501f..ee9159573ea2 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -206,12 +206,15 @@ def verify_UpsampleLayerParams(input_dim, scale, mode): dtype = "float32" a_np = np.full(input_dim, 1, dtype=dtype) + if mode == "NN": - b_np = 
tvm.topi.testing.upsampling_python(a_np, (scale, scale)) + method = "nearest_neighbor" + coord_trans = "asymmetric" else: - new_h = input_dim[2] * scale - new_w = input_dim[3] * scale - b_np = tvm.topi.testing.bilinear_resize_python(a_np, (new_h, new_w), "NCHW") + method = "linear" + coord_trans = "align_corners" + + b_np = tvm.topi.testing.resize2d_python(a_np, (scale, scale), "NCHW", method, coord_trans) input = [("input", datatypes.Array(*input_dim))] output = [("output", datatypes.Array(*b_np.shape))] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 52c3346e5807..3c1098c2c1cd 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1366,11 +1366,12 @@ def verify_upsample3d_trilinear(): y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear") scales = [1.0, 1.0, 2.0, 2.0, 2.0] in_array = np.random.uniform(size=in_shape).astype(np.float32) - out_array = tvm.topi.testing.trilinear_resize3d_python( + out_array = tvm.topi.testing.resize3d_python( in_array, - (3 * scale, 3 * scale, 3 * scale), + (scale, scale, scale), "NCDHW", - coordinate_transformation_mode="half_pixel", + "linear", + coordinate_transformation_mode="asymmetric", ) ref_array = np.array(scales) @@ -3548,7 +3549,7 @@ def test_gru(): @tvm.testing.uses_gpu def test_resize(): - def verify(ishape, oshape, scales, mode, coord_trans): + def verify(ishape, oshape, scales, mode, coord_trans="asymmetric", alpha=0.5, exclude=False): nodes = [ make_constant_node("roi", onnx.TensorProto.FLOAT, (0,), []), make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), @@ -3566,6 +3567,8 @@ def verify(ishape, oshape, scales, mode, coord_trans): outputs=["Y"], mode=mode, coordinate_transformation_mode=coord_trans, + cubic_coeff_a=alpha, + exclude_outside=exclude, ) ) @@ -3582,29 +3585,69 @@ def verify(ishape, oshape, scales, mode, coord_trans): verify_with_ort(model, [ishape], 
[oshape], use_vm=True, opset=11, freeze_params=True) - # upsampling - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 64, 64], [], "linear", "half_pixel") - - # downsampling - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "asymmetric") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "align_corners") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "nearest", "half_pixel") - verify([1, 16, 32, 32], [1, 16, 16, 16], [], "linear", "half_pixel") - - # scales are specified instead of sizes - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "asymmetric") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "nearest", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 2, 2], "linear", "align_corners") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel") - verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "nearest", "half_pixel") + for ndim in [1, 2, 3]: + method = "nearest" + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method, coord_trans) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method, coord_trans) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method, coord_trans) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method, coord_trans) + + if ndim == 2: + ## TODO(mbrookhart): ONNX Runtime in CI only 
supports 2D linear resize + ## Remove this condition when updating CI + method = "linear" + # upsampling + verify([1, 16] + [32] * ndim, [1, 16] + [64] * ndim, [], method) + # downsampling + verify([1, 16] + [32] * ndim, [1, 16] + [16] * ndim, [], method) + # scales are specified instead of sizes + verify([1, 16] + [32] * ndim, [], [1, 1] + [0.5] * ndim, method) + verify([1, 16] + [32] * ndim, [], [1, 1] + [2] * ndim, method) + + if ndim == 2: + # ONNX Runtime only supports cubic interpolation for 2D images + method = "cubic" + for alpha in [0.5, 0.75]: + for exclude in [True, False]: + # upsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [64] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # downsampling + verify( + [1, 16] + [32] * ndim, + [1, 16] + [16] * ndim, + [], + method, + alpha=alpha, + exclude=exclude, + ) + # scales are specified instead of sizes + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [0.5] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) + verify( + [1, 16] + [32] * ndim, + [], + [1, 1] + [2] * ndim, + method, + alpha=alpha, + exclude=exclude, + ) def verify_opset_10(ishape, scales, mode): nodes = [ diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index dca5dd6d4384..a6ea609be1e2 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -40,12 +40,13 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa (n, h, w, c) = dshape x_data = np.random.uniform(size=(n, h, w, c)).astype("float32") - if method == "nearest_neighbor": - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout) - else: - ref_res = tvm.topi.testing.bilinear_resize_python( - x_data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref_res = tvm.topi.testing.resize2d_python( + x_data, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" 
else method, + "align_corners" if align_corners else "asymmetric", + ) x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) scale_w_var = relay.var("scale_h", relay.TensorType((), "float32")) @@ -87,7 +88,7 @@ def test_dyn_upsampling_infer_type_const(): @tvm.testing.uses_gpu def test_dyn_upsampling3d_run(): def verify_upsampling3d( - dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel" + dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="asymmetric" ): if layout == "NCDHW": @@ -98,16 +99,14 @@ def verify_upsampling3d( (n, d, h, w, c) = dshape x_data = np.random.uniform(size=(n, d, h, w, c)).astype("float32") - if method == "nearest_neighbor": - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d, scale_h, scale_w), layout - ) - else: - ref_res = tvm.topi.testing.trilinear_resize3d_python( - x_data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coord_trans, + ) + x = relay.Var("x", relay.TensorType(dshape, "float32")) scale_d_var = relay.var("scale_d", relay.TensorType((), "float32")) scale_h_var = relay.var("scale_h", relay.TensorType((), "float32")) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 78e2c232c08e..d3459afaab06 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -27,39 +27,40 @@ import tvm.testing -def test_resize_infer_type(): +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) size = relay.var("size", relay.TensorType((2,), "int8")) - z = relay.image.resize(x, size) + z = 
relay.image.resize2d(x, size) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") @tvm.testing.uses_gpu -def test_resize(): - def verify_resize(dshape, scale, method, layout): +def test_resize2d(): + def verify_resize2d(dshape, scale, method, layout): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) else: size = (dshape[2] * scale, dshape[3] * scale) size = np.array(size).astype("int64") x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) size_var = relay.var("size", relay.TensorType((2,), "int64")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) zz = run_infer_type(z) func = relay.Function([x, size_var], z) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) + for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) @@ -67,10 +68,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate()(x_data, size) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: - verify_resize((1, 4, 4, 4), 2, method, layout) - verify_resize((2, 8, 17, 20), 7, method, layout) + verify_resize2d((1, 4, 4, 4), 2, method, layout) + verify_resize2d((2, 8, 17, 20), 7, method, layout) if __name__ == "__main__": diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 13f5525bfee8..e94b5145ccc2 100644 --- 
a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1275,7 +1275,7 @@ def test_any_ndarray_size(): verify_any_ndarray_size((1, 2, 3, 4)) -def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape): +def verify_any_resize2d(data_shape, scale, layout, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) @@ -1283,7 +1283,7 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap size = (data_shape[1] * scale, data_shape[2] * scale) else: size = (data_shape[2] * scale, data_shape[3] * scale) - y = relay.image.resize(data, size, layout) + y = relay.image.resize2d(data, size, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) check_result([data_np], mod, ref_out_shape, assert_shape=True) @@ -1291,14 +1291,14 @@ def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shap @tvm.testing.uses_gpu def test_any_resize(): - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 4, 4, 4), scale=2, layout="NHWC", static_data_shape=(1, 4, 4, 4), ref_out_shape=(1, 8, 8, 4), ) - verify_any_resize( + verify_any_resize2d( data_shape=(relay.Any(), 8, 17, 20), scale=3, layout="NCHW", diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 50fc0622ee6e..f05c5054415d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1448,13 +1448,15 @@ def get_shape(): align_corners=align_corners, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - ref = tvm.topi.testing.upsampling_python(data, (scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.bilinear_resize_python( - data, (int(round(h * scale_h)), int(round(w * scale_w))), layout - ) + ref = tvm.topi.testing.resize2d_python( + data, + (scale_h, 
scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "align_corners" if align_corners else "asymmetric", + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) @@ -1518,15 +1520,15 @@ def get_shape(): coordinate_transformation_mode=coordinate_transformation_mode, ) func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) - if method == "nearest_neighbor": - ref = tvm.topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout) - else: - ref = tvm.topi.testing.trilinear_resize3d_python( - data, - (int(round(d * scale_d)), int(round(h * scale_h)), int(round(w * scale_w))), - layout, - ) + ref = tvm.topi.testing.resize3d_python( + data, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + coordinate_transformation_mode, + ) for target, dev in tvm.testing.enabled_targets(): executor = relay.create_executor("graph", device=dev, target=target) out = executor.evaluate(func)(data) @@ -1535,9 +1537,9 @@ def get_shape(): @tvm.testing.uses_gpu def test_upsampling3d(): - _test_upsampling3d("NCDHW", "nearest_neighbor") + _test_upsampling3d("NCDHW", "nearest_neighbor", "asymmetric") _test_upsampling3d("NCDHW", "trilinear", "align_corners") - _test_upsampling3d("NDHWC", "nearest_neighbor") + _test_upsampling3d("NDHWC", "nearest_neighbor", "asymmetric") _test_upsampling3d("NDHWC", "trilinear", "align_corners") diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e27520339f36..d93de5419f56 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -26,23 +26,72 @@ from tvm.relay.testing import run_infer_type -def test_resize_infer_type(): +def test_resize1d_infer_type(): + n, c, w = te.size_var("n"), te.size_var("c"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + tw = 
te.var("tw") + z = relay.image.resize1d(x, (tw,)) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, tw), "int8") + + x = relay.var("x", relay.TensorType((n, c, w), "int8")) + z = relay.image.resize1d(x, (200,), "NCW", "linear", "align_corners") + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, 200), "int8") + + +@tvm.testing.uses_gpu +def test_resize1d(): + def verify_resize(dshape, scale, method, layout, coord_trans): + if layout == "NWC": + size = (dshape[1] * scale,) + else: + size = (dshape[2] * scale,) + + x_data = np.random.uniform(size=dshape).astype("float32") + + ref_res = tvm.topi.testing.resize1d_python(x_data, (scale,), layout, method, coord_trans) + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.image.resize1d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([x], z) + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, device=dev, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) + + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NWC", "NCW"]: + verify_resize((1, 4, 4), 2, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((2, 8, 17), 3, method, layout, coord_trans) + verify_resize((3, 4, 5), 5, method, layout, coord_trans) + + +def test_resize2d_infer_type(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) th, tw = te.var("th"), te.var("tw") - z = relay.image.resize(x, (th, tw)) + z = 
relay.image.resize2d(x, (th, tw)) zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - z = relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners") + z = relay.image.resize2d(x, (100, 200), "NCHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): def verify_resize(dshape, scale, method, layout, coord_trans): if layout == "NHWC": size = (dshape[1] * scale, dshape[2] * scale) @@ -51,25 +100,25 @@ def verify_resize(dshape, scale, method, layout, coord_trans): x_data = np.random.uniform(size=dshape).astype("float32") - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout, coord_trans) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.image.resize(x, size, layout, method, coordinate_transformation_mode=coord_trans) + z = relay.image.resize2d( + x, size, layout, method, coordinate_transformation_mode=coord_trans + ) assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") func = relay.Function([x], z) - for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) - for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", 
"align_corners", "half_pixel"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout, coord_trans) verify_resize((2, 8, 17, 20), 3, method, layout, coord_trans) @@ -92,7 +141,7 @@ def test_resize3d_infer_type(): assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8") x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) - z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners") + z = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "linear", "align_corners") assert "size=" in z.astext() zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8") @@ -107,10 +156,9 @@ def verify_resize(dshape, scale, method, layout): size = (dshape[2] * scale, dshape[3] * scale, dshape[4] * scale) x_data = np.random.uniform(size=dshape).astype("float32") - if method == "trilinear": - ref_res = tvm.topi.testing.trilinear_resize3d_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling3d_python(x_data, (scale, scale, scale), layout) + ref_res = tvm.topi.testing.resize3d_python( + x_data, (scale, scale, scale), layout, method, "align_corners" + ) x = relay.var("x", relay.TensorType(dshape, "float32")) z = relay.image.resize3d(x, size, layout, method, "align_corners") assert "size=" in z.astext() @@ -123,9 +171,10 @@ def verify_resize(dshape, scale, method, layout): op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) - for method in ["trilinear", "nearest_neighbor"]: - for layout in ["NDHWC", "NCDHW"]: - verify_resize((1, 4, 4, 4, 4), 2, method, layout) + for method in ["nearest_neighbor", "linear", "cubic"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NDHWC", "NCDHW"]: + verify_resize((1, 4, 4, 4, 4), 2, method, layout) @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 
88590c946e88..fafab3ee3584 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,24 +1797,24 @@ def expected(): _test_conv_reduce_convert_layout2() -def test_image_resize_convert_layout(): +def test_image_resize2d_convert_layout(): def _test_image_resize_convert_layout_nchw_to_nhwc(): def before(): x = relay.var("x", shape=(1, 2, 4, 4)) - y = relay.image.resize(x, (8, 8)) + y = relay.image.resize2d(x, (8, 8)) y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 2, 4, 4)) x = relay.layout_transform(x, "NCHW", "NHWC") - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.layout_transform(y, "NHWC", "NCHW") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NHWC"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1822,20 +1822,20 @@ def expected(): def _test_image_resize_convert_layout_nhwc_to_nchw(): def before(): x = relay.var("x", shape=(1, 4, 4, 2)) - y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.image.resize2d(x, (8, 8), layout="NHWC") y = relay.Function([x], y) return y def expected(): x = relay.var("x", shape=(1, 4, 4, 2)) x = relay.layout_transform(x, "NHWC", "NCHW") - y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.image.resize2d(x, (8, 8), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(relay.analysis.free_vars(y), y) return y a = before() - a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NCHW"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -1844,7 
+1844,7 @@ def expected(): _test_image_resize_convert_layout_nhwc_to_nchw() -def test_conv_image_resize_convert_layout(): +def test_conv_image_resize2d_convert_layout(): """Check that layout transforms are propagated through image resize.""" def before(): @@ -1859,7 +1859,7 @@ def before(): data_layout="NHWC", kernel_layout="HWIO", ) - y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.image.resize2d(y, (112, 112), layout="NHWC") y = relay.Function(analysis.free_vars(y), y) return y @@ -1869,7 +1869,7 @@ def expected(): x = relay.layout_transform(x, "NHWC", "NCHW") w = relay.layout_transform(w, "HWIO", "OIHW") y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) - y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.image.resize2d(y, (112, 112), layout="NCHW") y = relay.layout_transform(y, "NCHW", "NHWC") y = relay.Function(analysis.free_vars(y), y) return y diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 9f7f3deebeb8..962b7bebb12b 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -248,7 +248,7 @@ def verify_ones_zeros(shape, dtype): @tvm.testing.uses_gpu -def test_dynamic_to_static_resize(): +def test_dynamic_to_static_resize2d(): def verify_resize(shape, scale, method, layout): if layout == "NHWC": size = (shape[1] * scale, shape[2] * scale) @@ -258,7 +258,7 @@ def verify_resize(shape, scale, method, layout): x = relay.var("x", relay.TensorType(shape, "float32")) size_var = relay.const(np.array(size).astype("float32")) coord_trans = "asymmetric" if method == "nearest_neighbor" else "align_corners" - z = relay.image.resize( + z = relay.image.resize2d( x, size_var, layout, method, coordinate_transformation_mode=coord_trans ) @@ -267,17 +267,14 @@ def verify_resize(shape, scale, method, layout): zz = func2.body assert isinstance(zz, relay.Call) - assert zz.op == 
relay.op.get("image.resize") + assert zz.op == relay.op.get("image.resize2d") x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale, scale), layout, method, coord_trans + ) - if method == "bilinear": - ref_res = tvm.topi.testing.bilinear_resize_python(x_data, size, layout) - else: - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale, scale), layout) - verify_func(func2, [x_data], ref_res, rtol=1e-4, atol=1e-6) - - for method in ["bilinear", "nearest_neighbor"]: + for method in ["linear", "nearest_neighbor"]: for layout in ["NCHW", "NHWC"]: verify_resize((1, 4, 4, 4), 2, method, layout) @@ -347,7 +344,9 @@ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype): assert zz.op == relay.op.get("nn.upsampling") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW") + ref_res = tvm.topi.testing.resize2d_python( + x_data, (scale_h_val, scale_w_val), "NCHW", "nearest_neighbor", "asymmetric" + ) verify_func(func2, [x_data], ref_res) verify_upsampling((1, 16, 32, 32), 2, 2, "int8") @@ -371,8 +370,12 @@ def verify_upsampling3d(data_shape, scale_d_val, scale_h_val, scale_w_val, dtype assert zz.op == relay.op.get("nn.upsampling3d") x_data = np.random.uniform(size=data_shape).astype(dtype) - ref_res = tvm.topi.testing.upsampling3d_python( - x_data, (scale_d_val, scale_h_val, scale_w_val), "NCDHW" + ref_res = tvm.topi.testing.resize3d_python( + x_data, + (scale_d_val, scale_h_val, scale_w_val), + "NCDHW", + "nearest_neighbor", + "asymmetric", ) verify_func(func2, [x_data], ref_res) diff --git a/tests/python/topi/python/test_topi_image.py b/tests/python/topi/python/test_topi_image.py index 2730783907fd..fe7fba52f1ee 100644 --- a/tests/python/topi/python/test_topi_image.py +++ b/tests/python/topi/python/test_topi_image.py @@ -24,7 +24,7 @@ from tvm.contrib.pickle_memoize import memoize 
-def verify_resize( +def verify_resize2d( batch, in_channel, in_height, @@ -33,7 +33,7 @@ def verify_resize( out_width, layout="NCHW", coord_trans="align_corners", - method="bilinear", + method="linear", ): if layout == "NCHW": A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="float32") @@ -47,24 +47,16 @@ def verify_resize( a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError("Layout not supported {} ".format(layout)) - B = topi.image.resize( + B = topi.image.resize2d( A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method, ) - if method == "bilinear": - b_np = tvm.topi.testing.bilinear_resize_python( - a_np, (out_height, out_width), layout, coord_trans - ) - else: - # TODO: Nearest neighbor case doesn't do anything with coordinate transform mode, and also - # nearest_neighbors and align_corners combination in topi doesn't match the output of this - # function. 
- scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize2d_python(a_np, (scale_h, scale_w), layout, method, coord_trans) def check_target(target, dev): print("Running on target: %s" % target) @@ -82,19 +74,21 @@ def check_target(target, dev): @tvm.testing.uses_gpu -def test_resize(): +def test_resize2d(): # Scale NCHW - verify_resize(4, 16, 32, 32, 50, 50, "NCHW") + verify_resize2d(4, 16, 32, 32, 50, 50, "NCHW") # Scale NCHW + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NCHW") + verify_resize2d(6, 32, 64, 64, 20, 20, "NCHW") # Scale NHWC - verify_resize(4, 16, 32, 32, 50, 50, "NHWC") + verify_resize2d(4, 16, 32, 32, 50, 50, "NHWC") # Scale NHWC + Align Corners - verify_resize(6, 32, 64, 64, 20, 20, "NHWC") - for method in ["nearest_neighbor", "bilinear"]: - for coord_trans in ["asymmetric"]: # TOPI testing function only support asymmetric - for layout in ["NCHW", "NHWC"]: - verify_resize(4, 16, 32, 32, 50, 50, layout, coord_trans, method=method) + verify_resize2d(6, 32, 64, 64, 20, 20, "NHWC") + for layout in ["NCHW", "NHWC"]: + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "align_corners", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="nearest_neighbor") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="linear") + verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") def verify_resize3d( @@ -107,8 +101,8 @@ def verify_resize3d( out_height, out_width, layout="NCDHW", - coordinate_transformation_mode="half_pixel", - method="trilinear", + coordinate_transformation_mode="asymmetric", + method="linear", ): if layout == "NCDHW": A = te.placeholder( @@ -139,18 +133,14 @@ def verify_resize3d( 
method=method, ) - if method == "trilinear": - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, (out_depth, out_height, out_width), layout, coordinate_transformation_mode - ) - else: - scale_d = out_depth / in_depth - scale_h = out_height / in_height - scale_w = out_width / in_width - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + scale_d = out_depth / in_depth + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = tvm.topi.testing.resize3d_python( + a_np, (scale_d, scale_h, scale_w), layout, method, coordinate_transformation_mode + ) def check_target(target, dev): - print("Running on target: %s" % target) with tvm.target.Target(target): s = tvm.topi.testing.get_injective_schedule(target)(B) a = tvm.nd.array(a_np, dev) @@ -167,16 +157,10 @@ def check_target(target, dev): @tvm.testing.uses_gpu def test_resize3d(): # Trilinear - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW") - verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "align_corners") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NCDHW", "asymmetric") - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, "NDHWC", "asymmetric") - - # Nearest neighbor - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NCDHW", method="nearest_neighbor") - verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, "NDHWC", method="nearest_neighbor") + for method in ["nearest_neighbor", "linear"]: + for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: + for layout in ["NCDHW", "NDHWC"]: + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, layout, coord_trans, method) @tvm.testing.uses_gpu diff --git a/tests/python/topi/python/test_topi_upsampling.py b/tests/python/topi/python/test_topi_upsampling.py index 0ab0e64af4c7..7793417a9a2b 100644 --- a/tests/python/topi/python/test_topi_upsampling.py +++ 
b/tests/python/topi/python/test_topi_upsampling.py @@ -78,11 +78,13 @@ def verify_upsampling( B = topi.nn.upsampling(A, scale_h, scale_w, layout=layout, method=method, align_corners=False) - if method == "bilinear": - out_size = (int(round(in_height * scale_h)), int(round(in_width * scale_w))) - b_np = tvm.topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric") - else: - b_np = tvm.topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize2d_python( + a_np, + (scale_h, scale_w), + layout, + method[2:] if method[0:2] == "bi" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target) @@ -213,20 +215,16 @@ def verify_upsampling3d( scale_w, layout=layout, method=method, - coordinate_transformation_mode="half_pixel", + coordinate_transformation_mode="asymmetric", ) - if method == "trilinear": - out_size = ( - int(round(in_depth * scale_d)), - int(round(in_height * scale_h)), - int(round(in_width * scale_w)), - ) - b_np = tvm.topi.testing.trilinear_resize3d_python( - a_np, out_size, layout, coordinate_transformation_mode="half_pixel" - ) - else: - b_np = tvm.topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + b_np = tvm.topi.testing.resize3d_python( + a_np, + (scale_d, scale_h, scale_w), + layout, + method[3:] if method[0:3] == "tri" else method, + "asymmetric", + ) def check_target(target, dev): print("Running on target: %s" % target)