From 5723ebb2fd578e3ee95b2ca79ba678a7481d844b Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 14:57:05 -0500 Subject: [PATCH] [Unity] Relax op: image (#13994) This PR is about the high-level tensor computation operators in Relax. This PR includes the image operators. --- include/tvm/relax/attrs/image.h | 81 ++++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/image/__init__.py | 19 ++ python/tvm/relax/op/image/_ffi_api.py | 20 ++ python/tvm/relax/op/image/image.py | 128 +++++++++ python/tvm/relax/op/op_attrs.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/image/resize.cc | 113 ++++++++ src/relax/op/image/resize.h | 43 +++ tests/python/relax/test_op_image.py | 245 ++++++++++++++++++ .../relax/test_tvmscript_parser_op_image.py | 54 ++++ 11 files changed, 711 insertions(+) create mode 100644 include/tvm/relax/attrs/image.h create mode 100644 python/tvm/relax/op/image/__init__.py create mode 100644 python/tvm/relax/op/image/_ffi_api.py create mode 100644 python/tvm/relax/op/image/image.py create mode 100644 src/relax/op/image/resize.cc create mode 100644 src/relax/op/image/resize.h create mode 100644 tests/python/relax/test_op_image.py create mode 100644 tests/python/relax/test_tvmscript_parser_op_image.py diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h new file mode 100644 index 000000000000..13463aaa4849 --- /dev/null +++ b/include/tvm/relax/attrs/image.h @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/image.h + * \brief Attributes for image operators. + */ +#ifndef TVM_RELAX_ATTRS_IMAGE_H_ +#define TVM_RELAX_ATTRS_IMAGE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array roi; + String layout; + String method; + String coordinate_transformation_mode; + String rounding_method; + double cubic_alpha; + int cubic_exclude; + double extrapolation_value; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") { + TVM_ATTR_FIELD(roi).describe( + "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'"); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .describe("Value to return when roi is outside of the image"); + TVM_ATTR_FIELD(out_dtype).describe( + "The dtype of the output tensor. It it is not specified, the output will have the same " + "dtype as input if not specified."); + } +}; // struct Resize2dAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_IMAGE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index da29c3715dec..38573512691c 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -26,4 +26,5 @@ from .op_attrs import * from .set import * from . import builtin +from . import image from . import memory diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py new file mode 100644 index 000000000000..f2552ad6ac51 --- /dev/null +++ b/python/tvm/relax/op/image/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Image operators.""" +from .image import * diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py new file mode 100644 index 000000000000..e666203ae7ff --- /dev/null +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py new file mode 100644 index 000000000000..562de5021d53 --- /dev/null +++ b/python/tvm/relax/op/image/image.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Image operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ...expr import Expr, ShapeExpr + + +PrimExprLike = Union[int, PrimExpr] + + +def resize2d( + data: Expr, + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]], + roi: Optional[Union[float, Tuple[float]]] = None, + layout: str = "NCHW", + method: str = "linear", + coordinate_transformation_mode: str = "half_pixel", + rounding_method: str = "round", + cubic_alpha: float = -0.5, + cubic_exclude: int = 0, + extrapolation_value: float = 0.0, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Image resize2d operator. + + This operator takes data as input and does 2D scaling to the given scale factor. + In the default case, where the data_layout is `NCHW` + with data of shape (n, c, h, w) + out will have a shape (n, c, size[0], size[1]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]] + The out size to which the image will be resized. + If specified as a list, it is required to have length either 1 or 2. + If specified as an Expr, it is required to have ndim 2. + + roi: Optional[Union[float, Tuple[float]]] + The region of interest for cropping the input image. Expected to be of + size 4, and format [start_h, start_w, end_h, end_w]. + Only used if coordinate_transformation_mode is tf_crop_and_resize. + + layout : str + Layout of the input. + + method : str + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : str + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. Definitions can be found + in topi/image/resize.py. + [half_pixel, align_corners, asymmetric, pytorch_half_pixel, + tf_half_pixel_for_nn, and tf_crop_and_resize]. + + rounding_method: str + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for bicubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + + extrapolation_value: float + Fill value to use when roi is outside of the image + + out_dtype : Optional[Union[str, DataType]] + The dtype of the output tensor. + It it is not specified, the output will have the same dtype as input if not specified. + + Returns + ------- + result: relax.Expr + The resized result. + """ + if roi is None: + roi = (0.0, 0.0, 0.0, 0.0) # type: ignore + elif isinstance(roi, float): + roi = (roi, roi, roi, roi) # type: ignore + + if isinstance(size, (int, PrimExpr)): + size = (size, size) + if isinstance(size, tuple): + if len(size) == 1: + size = ShapeExpr([size[0], size[0]]) + else: + size = ShapeExpr(size) + + return _ffi_api.resize2d( # type: ignore + data, + size, + roi, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + extrapolation_value, + out_dtype, + ) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 47c3b2879878..fb64443b7e09 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes used in image resize2d operator""" + + @tvm._ffi.register_object("relax.attrs.UniqueAttrs") class UniqueAttrs(Attrs): """Attributes used for the unique operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 537adec6154c..22b85f6f402f 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -35,6 +35,7 @@ builtin, call_builtin_with_ctx, call_tir, + image, invoke_closure, make_closure, memory, @@ -420,6 +421,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "func_ret_struct_info", "func_ret_value", "function", + "image", "invoke_closure", "make_closure", "memory", diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc new file mode 100644 index 000000000000..2711b3cc45f5 --- /dev/null +++ b/src/relax/op/image/resize.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.cc + * \brief Image resize operators. + */ + +#include "resize.h" + +#include + +namespace tvm { +namespace relax { + +/* relax.resize2d */ +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype) { + ObjectPtr attrs = make_object(); + attrs->roi = std::move(roi); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode); + attrs->rounding_method = std::move(rounding_method); + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->extrapolation_value = extrapolation_value; + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.image.resize2d"); + return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); + +StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1 && call->args.size() != 2) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects either one or two arguments, while the given number of arguments is " + << call->args.size()); + } + + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* size_sinfo = GetStructInfoAs(call->args[1]); + const auto* size_value = call->args[1].as(); + if (data_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Resize2D expects the input data to be a Tensor, while the given data is " + << call->args[0]->GetTypeKey()); + } + if (size_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Resize2D expects the given output image size to be a Shape, while the given one is " + << call->args[1]->GetTypeKey()); + } + if (size_sinfo->ndim != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to " + "be a 2-dim shape, while the given one has ndim " + << size_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // + /*tgt_layout=*/"NCHW", // + /*tensor_name=*/"data"); + + DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype; + + Optional data_shape = + CheckNdimPerLayoutAndGetShape(call, ctx, GetRef(data_sinfo), data_layout); + if (!data_shape.defined() || size_value == nullptr) { + return TensorStructInfo(out_dtype, data_layout.ndim()); + } + + Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); + Array out_NCHW_shape(data_NCHW_shape); + out_NCHW_shape.Set(2, size_value->values[0]); + out_NCHW_shape.Set(3, size_value->values[1]); + + Array out_shape = data2NCHW.BackwardShape(out_NCHW_shape); + return TensorStructInfo(ShapeExpr(out_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.image.resize2d") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("size", "Shape", "The output image shape.") + .set_attr("FInferStructInfo", InferStructInfoResize2D); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h new file mode 100644 index 000000000000..085a1cbc5d5f --- /dev/null +++ b/src/relax/op/image/resize.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file resize.h + * \brief The functions to make Relax image resize operator calls. + */ + +#ifndef TVM_RELAX_OP_IMAGE_RESIZE_H_ +#define TVM_RELAX_OP_IMAGE_RESIZE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief Image resize2d operator. */ +Expr resize2d(Expr data, Expr size, Array roi, String layout, String method, + String coordinate_transformation_mode, String rounding_method, double cubic_alpha, + int cubic_exclude, double extrapolation_value, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_IMAGE_RESIZE_H_ diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py new file mode 100644 index 000000000000..b06b51a2a198 --- /dev/null +++ b/tests/python/relax/test_op_image.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + assert relax.op.image.resize2d(x, (28, 28)).op == Op.get("relax.image.resize2d") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_resize2d_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=4)) + x4 = relax.Var("x", R.Tensor("float32", ndim=5)) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor(ndim=4)) + x7 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.image.resize2d(x0, (28, 28)), relax.TensorStructInfo((2, 3, 28, 28), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28), + relax.TensorStructInfo((2, 3, 28, 28), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=28, layout="NHWC"), + relax.TensorStructInfo((2, 28, 28, 3), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=28, out_dtype="float16"), + relax.TensorStructInfo((2, 3, 28, 28), "float16"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=28, layout="NCHW16c"), + relax.TensorStructInfo((2, 4, 28, 28, 16), "float32"), + ) + _check_inference( + bb, relax.op.image.resize2d(x3, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x4, size=28, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, relax.op.image.resize2d(x5, size=28), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.image.resize2d(x6, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x6, size=28, out_dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x7, size=28), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_resize2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + oh = tir.Var("oh", "int64") + ow = tir.Var("ow", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=oh), relax.TensorStructInfo((n, c, oh, oh), "float32") + ) + _check_inference( + bb, + relax.op.image.resize2d(x0, size=(oh, ow)), + relax.TensorStructInfo((n, c, oh, ow), "float32"), + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=(oh, ow), layout="NCHW16c"), + relax.TensorStructInfo((n, c, oh, ow, 16), "float32"), + ) + + +def test_resize2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.image.resize2d(x0, size=32), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.image.resize2d(x1, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.image.resize2d(x2, size=32, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_resize2d_infer_struct_info_pool_size_var(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((30, 30))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + + _check_inference( + bb, + relax.op.image.resize2d(x0, s0), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, relax.op.image.resize2d(x0, s1), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + + +def test_resize2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.image.resize2d(x0, size=28), relax.TensorStructInfo((2, 3, 28, 28), "float16") + ) + _check_inference( + bb, relax.op.image.resize2d(x1, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int8") + ) + _check_inference( + bb, relax.op.image.resize2d(x2, size=28), relax.TensorStructInfo((2, 3, 28, 28), "int64") + ) + + +def test_resize2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x, size=28, layout="OIHW")) + + +def test_resize2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=28, layout="NCHW16c")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=28, layout="NCHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, size=28)) + + +def test_resize2d_wrong_pool_size_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + s0 = relax.ShapeExpr((3,)) + s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s5 = relax.Var("s", relax.ShapeStructInfo()) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, (3, 3, 3))) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, s5)) + + +def test_resize2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + s0 = relax.Var("s", R.Tensor((3, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x0, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x1, size=32)) + with pytest.raises(TVMError): + bb.normalize(relax.op.image.resize2d(x2, s0)) + with pytest.raises(TVMError): + relax.op.image.resize2d(x2, [30, 30]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py b/tests/python/relax/test_tvmscript_parser_op_image.py new file mode 100644 index 000000000000..a90da37812ef --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_image.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_resize2d(): + @R.function + def foo(x: R.Tensor((2, 14, 14, 3), "float32")) -> R.Tensor((2, 28, 28, 3), "float32"): + gv: R.Tensor((2, 28, 28, 3), "float32") = R.image.resize2d(x, size=(28, 28), layout="NHWC") + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 14, 14, 3), "float32")) + with bb.function("foo", [x]): + gv = bb.emit(relax.op.image.resize2d(x, (28, 28), layout="NHWC")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()