-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR is about the high-level tensor computation operators in Relax. This PR includes the image operators.
- Loading branch information
1 parent
1653505
commit 2d676d4
Showing
11 changed files
with
711 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/relax/attrs/image.h | ||
* \brief Attributes for image operators. | ||
*/ | ||
#ifndef TVM_RELAX_ATTRS_IMAGE_H_ | ||
#define TVM_RELAX_ATTRS_IMAGE_H_ | ||
|
||
#include <tvm/relax/expr.h> | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
/*! \brief Attributes used in image resize2d operator */ | ||
struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> { | ||
Array<FloatImm> roi; | ||
String layout; | ||
String method; | ||
String coordinate_transformation_mode; | ||
String rounding_method; | ||
double cubic_alpha; | ||
int cubic_exclude; | ||
double extrapolation_value; | ||
DataType out_dtype; | ||
|
||
TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") { | ||
TVM_ATTR_FIELD(roi).describe( | ||
"Region of Interest for coordinate transformation mode 'tf_crop_and_resize'"); | ||
TVM_ATTR_FIELD(layout).describe( | ||
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." | ||
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" | ||
"dimensions respectively. Resize is applied on the 'H' and" | ||
"'W' dimensions."); | ||
TVM_ATTR_FIELD(method).describe( | ||
"Specify the mode to use for scaling." | ||
"nearest_neighbor - Nearest Neighbor" | ||
"linear - Bilinear Interpolation" | ||
"cubic - Bicubic Interpolation"); | ||
TVM_ATTR_FIELD(coordinate_transformation_mode) | ||
.describe( | ||
"Describes how to transform the coordinate in the resized tensor" | ||
"to the coordinate in the original tensor." | ||
"Refer to the ONNX Resize operator specification for details" | ||
"Available options are half_pixel, align_corners and asymmetric"); | ||
TVM_ATTR_FIELD(rounding_method) | ||
.describe( | ||
"indicates how to find the \"nearest\" pixel in nearest_neighbor method" | ||
"Available options are round, floor, and ceil."); | ||
TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation"); | ||
TVM_ATTR_FIELD(cubic_exclude) | ||
.describe("Flag to exclude exterior of the image during bicubic interpolation"); | ||
TVM_ATTR_FIELD(extrapolation_value) | ||
.describe("Value to return when roi is outside of the image"); | ||
TVM_ATTR_FIELD(out_dtype).describe( | ||
"The dtype of the output tensor. It it is not specified, the output will have the same " | ||
"dtype as input if not specified."); | ||
} | ||
}; // struct Resize2dAttrs | ||
|
||
} // namespace relax | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAX_ATTRS_IMAGE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,4 +26,5 @@ | |
from .op_attrs import * | ||
from .set import * | ||
from . import builtin | ||
from . import image | ||
from . import memory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=wildcard-import | ||
"""Image operators.""" | ||
from .image import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Constructor APIs""" | ||
import tvm._ffi | ||
|
||
tvm._ffi._init_api("relax.op.image", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Image operators.""" | ||
from typing import Optional, Tuple, Union | ||
|
||
from tvm import DataType | ||
from tvm.ir.expr import PrimExpr | ||
|
||
from . import _ffi_api | ||
from ...expr import Expr, ShapeExpr | ||
|
||
|
||
PrimExprLike = Union[int, PrimExpr] | ||
|
||
|
||
def resize2d(
    data: Expr,
    size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]],
    roi: Optional[Union[float, Tuple[float]]] = None,
    layout: str = "NCHW",
    method: str = "linear",
    coordinate_transformation_mode: str = "half_pixel",
    rounding_method: str = "round",
    cubic_alpha: float = -0.5,
    cubic_exclude: int = 0,
    extrapolation_value: float = 0.0,
    out_dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
    """Image resize2d operator.

    This operator takes data as input and resizes its two spatial dimensions
    to the given output size.

    In the default case, where the data_layout is `NCHW`
    with data of shape (n, c, h, w)
    out will have a shape (n, c, size[0], size[1])

    method indicates the algorithm to be used while calculating the out value
    and method can be one of ("linear", "nearest_neighbor", "cubic")

    Parameters
    ----------
    data : relax.Expr
        The input data to the operator.

    size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]]
        The out size to which the image will be resized.
        If specified as a single int/PrimExpr, it is used for both dimensions.
        If specified as a tuple or list, it is required to have length either 1 or 2;
        a length-1 sequence is used for both dimensions.
        If specified as an Expr, it is required to have ndim 2.

    roi: Optional[Union[float, Tuple[float]]]
        The region of interest for cropping the input image. Expected to be of
        size 4, and format [start_h, start_w, end_h, end_w]. A single number is
        repeated four times; None is replaced by four zeros.
        Only used if coordinate_transformation_mode is tf_crop_and_resize.

    layout : str
        Layout of the input.

    method : str
        Scale method to use [nearest_neighbor, linear, cubic].

    coordinate_transformation_mode : str
        Describes how to transform the coordinate in the resized tensor
        to the coordinate in the original tensor. Definitions can be found
        in topi/image/resize.py.
        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
        tf_half_pixel_for_nn, and tf_crop_and_resize].

    rounding_method: str
        indicates how to find the "nearest" pixel in nearest_neighbor method
        [round, floor, ceil]

    cubic_alpha: float
        Spline Coefficient for bicubic interpolation

    cubic_exclude: int
        Flag to exclude exterior of the image during bicubic interpolation

    extrapolation_value: float
        Fill value to use when roi is outside of the image

    out_dtype : Optional[Union[str, DataType]]
        The dtype of the output tensor.
        If it is not specified, the output will have the same dtype as input.

    Returns
    -------
    result: relax.Expr
        The resized result.
    """
    if roi is None:
        roi = (0.0, 0.0, 0.0, 0.0)  # type: ignore
    elif isinstance(roi, (int, float)):
        # Broadcast a scalar (int scalars included) to the 4-element float form.
        roi = (float(roi),) * 4  # type: ignore

    if isinstance(size, (int, PrimExpr)):
        size = (size, size)
    # Lists are accepted alongside tuples, as the documentation promises;
    # anything else is assumed to already be a relax Expr and is passed through.
    if isinstance(size, (tuple, list)):
        if len(size) == 1:
            size = ShapeExpr([size[0], size[0]])
        else:
            size = ShapeExpr(size)

    return _ffi_api.resize2d(  # type: ignore
        data,
        size,
        roi,
        layout,
        method,
        coordinate_transformation_mode,
        rounding_method,
        cubic_alpha,
        cubic_exclude,
        extrapolation_value,
        out_dtype,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file resize.cc | ||
* \brief Image resize operators. | ||
*/ | ||
|
||
#include "resize.h" | ||
|
||
#include <utility> | ||
|
||
namespace tvm { | ||
namespace relax { | ||
|
||
/* relax.resize2d */ | ||
TVM_REGISTER_NODE_TYPE(Resize2DAttrs); | ||
|
||
Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method, | ||
String coordinate_transformation_mode, String rounding_method, double cubic_alpha, | ||
int cubic_exclude, double extrapolation_value, DataType out_dtype) { | ||
ObjectPtr<Resize2DAttrs> attrs = make_object<Resize2DAttrs>(); | ||
attrs->roi = std::move(roi); | ||
attrs->layout = std::move(layout); | ||
attrs->method = std::move(method); | ||
attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode); | ||
attrs->rounding_method = std::move(rounding_method); | ||
attrs->cubic_alpha = cubic_alpha; | ||
attrs->cubic_exclude = cubic_exclude; | ||
attrs->extrapolation_value = extrapolation_value; | ||
attrs->out_dtype = out_dtype; | ||
|
||
static const Op& op = Op::Get("relax.image.resize2d"); | ||
return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); | ||
|
||
StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { | ||
if (call->args.size() != 1 && call->args.size() != 2) { | ||
ctx->ReportFatal( | ||
Diagnostic::Error(call) | ||
<< "Resize2D expects either one or two arguments, while the given number of arguments is " | ||
<< call->args.size()); | ||
} | ||
|
||
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]); | ||
const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]); | ||
const auto* size_value = call->args[1].as<ShapeExprNode>(); | ||
if (data_sinfo == nullptr) { | ||
ctx->ReportFatal(Diagnostic::Error(call) | ||
<< "Resize2D expects the input data to be a Tensor, while the given data is " | ||
<< call->args[0]->GetTypeKey()); | ||
} | ||
if (size_sinfo == nullptr) { | ||
ctx->ReportFatal( | ||
Diagnostic::Error(call) | ||
<< "Resize2D expects the given output image size to be a Shape, while the given one is " | ||
<< call->args[1]->GetTypeKey()); | ||
} | ||
if (size_sinfo->ndim != 2) { | ||
ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to " | ||
"be a 2-dim shape, while the given one has ndim " | ||
<< size_sinfo->ndim); | ||
} | ||
|
||
const auto* attrs = call->attrs.as<Resize2DAttrs>(); | ||
auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout, // | ||
/*tgt_layout=*/"NCHW", // | ||
/*tensor_name=*/"data"); | ||
|
||
DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype; | ||
|
||
Optional<ShapeExpr> data_shape = | ||
CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout); | ||
if (!data_shape.defined() || size_value == nullptr) { | ||
return TensorStructInfo(out_dtype, data_layout.ndim()); | ||
} | ||
|
||
Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); | ||
Array<PrimExpr> out_NCHW_shape(data_NCHW_shape); | ||
out_NCHW_shape.Set(2, size_value->values[0]); | ||
out_NCHW_shape.Set(3, size_value->values[1]); | ||
|
||
Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape); | ||
return TensorStructInfo(ShapeExpr(out_shape), out_dtype); | ||
} | ||
|
||
TVM_REGISTER_OP("relax.image.resize2d") | ||
.set_attrs_type<Resize2DAttrs>() | ||
.set_num_inputs(2) | ||
.add_argument("data", "Tensor", "The input tensor.") | ||
.add_argument("size", "Shape", "The output image shape.") | ||
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D); | ||
|
||
} // namespace relax | ||
} // namespace tvm |
Oops, something went wrong.