Skip to content

Commit

Permalink
[Unity] Relax op: image (#13994)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the image operators.
  • Loading branch information
MasterJH5574 authored and tqchen committed Mar 13, 2023
1 parent 1653505 commit 2d676d4
Show file tree
Hide file tree
Showing 11 changed files with 711 additions and 0 deletions.
81 changes: 81 additions & 0 deletions include/tvm/relax/attrs/image.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/image.h
* \brief Attributes for image operators.
*/
#ifndef TVM_RELAX_ATTRS_IMAGE_H_
#define TVM_RELAX_ATTRS_IMAGE_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in image resize2d operator */
struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
Array<FloatImm> roi;
String layout;
String method;
String coordinate_transformation_mode;
String rounding_method;
double cubic_alpha;
int cubic_exclude;
double extrapolation_value;
DataType out_dtype;

TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") {
TVM_ATTR_FIELD(roi).describe(
"Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).describe(
"Specify the mode to use for scaling."
"nearest_neighbor - Nearest Neighbor"
"linear - Bilinear Interpolation"
"cubic - Bicubic Interpolation");
TVM_ATTR_FIELD(coordinate_transformation_mode)
.describe(
"Describes how to transform the coordinate in the resized tensor"
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
TVM_ATTR_FIELD(rounding_method)
.describe(
"indicates how to find the \"nearest\" pixel in nearest_neighbor method"
"Available options are round, floor, and ceil.");
TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation");
TVM_ATTR_FIELD(cubic_exclude)
.describe("Flag to exclude exterior of the image during bicubic interpolation");
TVM_ATTR_FIELD(extrapolation_value)
.describe("Value to return when roi is outside of the image");
TVM_ATTR_FIELD(out_dtype).describe(
"The dtype of the output tensor. It it is not specified, the output will have the same "
"dtype as input if not specified.");
}
}; // struct Resize2dAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_IMAGE_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@
from .op_attrs import *
from .set import *
from . import builtin
from . import image
from . import memory
19 changes: 19 additions & 0 deletions python/tvm/relax/op/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Image operators."""
from .image import *
20 changes: 20 additions & 0 deletions python/tvm/relax/op/image/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relax.op.image", __name__)
128 changes: 128 additions & 0 deletions python/tvm/relax/op/image/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Image operators."""
from typing import Optional, Tuple, Union

from tvm import DataType
from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ...expr import Expr, ShapeExpr


PrimExprLike = Union[int, PrimExpr]


def resize2d(
data: Expr,
size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]],
roi: Optional[Union[float, Tuple[float]]] = None,
layout: str = "NCHW",
method: str = "linear",
coordinate_transformation_mode: str = "half_pixel",
rounding_method: str = "round",
cubic_alpha: float = -0.5,
cubic_exclude: int = 0,
extrapolation_value: float = 0.0,
out_dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
"""Image resize2d operator.
This operator takes data as input and does 2D scaling to the given scale factor.
In the default case, where the data_layout is `NCHW`
with data of shape (n, c, h, w)
out will have a shape (n, c, size[0], size[1])
method indicates the algorithm to be used while calculating the out value
and method can be one of ("linear", "nearest_neighbor", "cubic")
Parameters
----------
data : relax.Expr
The input data to the operator.
size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]]
The out size to which the image will be resized.
If specified as a list, it is required to have length either 1 or 2.
If specified as an Expr, it is required to have ndim 2.
roi: Optional[Union[float, Tuple[float]]]
The region of interest for cropping the input image. Expected to be of
size 4, and format [start_h, start_w, end_h, end_w].
Only used if coordinate_transformation_mode is tf_crop_and_resize.
layout : str
Layout of the input.
method : str
Scale method to used [nearest_neighbor, linear, cubic].
coordinate_transformation_mode : str
Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor. Definitions can be found
in topi/image/resize.py.
[half_pixel, align_corners, asymmetric, pytorch_half_pixel,
tf_half_pixel_for_nn, and tf_crop_and_resize].
rounding_method: str
indicates how to find the "nearest" pixel in nearest_neighbor method
[round, floor, ceil]
cubic_alpha: float
Spline Coefficient for bicubic interpolation
cubic_exclude: int
Flag to exclude exterior of the image during bicubic interpolation
extrapolation_value: float
Fill value to use when roi is outside of the image
out_dtype : Optional[Union[str, DataType]]
The dtype of the output tensor.
It it is not specified, the output will have the same dtype as input if not specified.
Returns
-------
result: relax.Expr
The resized result.
"""
if roi is None:
roi = (0.0, 0.0, 0.0, 0.0) # type: ignore
elif isinstance(roi, float):
roi = (roi, roi, roi, roi) # type: ignore

if isinstance(size, (int, PrimExpr)):
size = (size, size)
if isinstance(size, tuple):
if len(size) == 1:
size = ShapeExpr([size[0], size[0]])
else:
size = ShapeExpr(size)

return _ffi_api.resize2d( # type: ignore
data,
size,
roi,
layout,
method,
coordinate_transformation_mode,
rounding_method,
cubic_alpha,
cubic_exclude,
extrapolation_value,
out_dtype,
)
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs):
"""Attributes used in strided_slice operator"""


@tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
class Resize2DAttrs(Attrs):
"""Attributes used in image resize2d operator"""


@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
class UniqueAttrs(Attrs):
"""Attributes used for the unique operator"""
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
builtin,
call_builtin_with_ctx,
call_tir,
image,
invoke_closure,
make_closure,
memory,
Expand Down Expand Up @@ -420,6 +421,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"func_ret_struct_info",
"func_ret_value",
"function",
"image",
"invoke_closure",
"make_closure",
"memory",
Expand Down
113 changes: 113 additions & 0 deletions src/relax/op/image/resize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file resize.cc
* \brief Image resize operators.
*/

#include "resize.h"

#include <utility>

namespace tvm {
namespace relax {

/* relax.resize2d */
TVM_REGISTER_NODE_TYPE(Resize2DAttrs);

Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method,
String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
int cubic_exclude, double extrapolation_value, DataType out_dtype) {
ObjectPtr<Resize2DAttrs> attrs = make_object<Resize2DAttrs>();
attrs->roi = std::move(roi);
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
attrs->rounding_method = std::move(rounding_method);
attrs->cubic_alpha = cubic_alpha;
attrs->cubic_exclude = cubic_exclude;
attrs->extrapolation_value = extrapolation_value;
attrs->out_dtype = out_dtype;

static const Op& op = Op::Get("relax.image.resize2d");
return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d);

StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 1 && call->args.size() != 2) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "Resize2D expects either one or two arguments, while the given number of arguments is "
<< call->args.size());
}

const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
const auto* size_value = call->args[1].as<ShapeExprNode>();
if (data_sinfo == nullptr) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Resize2D expects the input data to be a Tensor, while the given data is "
<< call->args[0]->GetTypeKey());
}
if (size_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "Resize2D expects the given output image size to be a Shape, while the given one is "
<< call->args[1]->GetTypeKey());
}
if (size_sinfo->ndim != 2) {
ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to "
"be a 2-dim shape, while the given one has ndim "
<< size_sinfo->ndim);
}

const auto* attrs = call->attrs.as<Resize2DAttrs>();
auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, //
/*tgt_layout=*/"NCHW", //
/*tensor_name=*/"data");

DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype;

Optional<ShapeExpr> data_shape =
CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout);
if (!data_shape.defined() || size_value == nullptr) {
return TensorStructInfo(out_dtype, data_layout.ndim());
}

Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
out_NCHW_shape.Set(2, size_value->values[0]);
out_NCHW_shape.Set(3, size_value->values[1]);

Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}

TVM_REGISTER_OP("relax.image.resize2d")
.set_attrs_type<Resize2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("size", "Shape", "The output image shape.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D);

} // namespace relax
} // namespace tvm
Loading

0 comments on commit 2d676d4

Please sign in to comment.