[Op] Add relax.clip Op (apache#408)
Hzfengsy authored and junrushao committed Feb 7, 2023
1 parent e4ae04e commit e810127
Showing 9 changed files with 140 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relax/block_builder.py
@@ -34,7 +34,7 @@
BaseFunc,
Binding,
)
from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo
from .struct_info import ShapeStructInfo, StructInfo, TensorStructInfo, PrimStructInfo
from .op.base import call_tir
from . import _ffi_api

@@ -256,6 +256,8 @@ def _convert_te_arg_helper(arg):
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
elif isinstance(arg.struct_info, PrimStructInfo):
return arg.value
elif isinstance(arg, (list, tvm.ir.Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
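For context, this new PrimStructInfo branch lets emit_te accept relax.PrimValue arguments and unwrap them to their underlying PrimExpr before invoking the TE function. A minimal sketch of that flow (illustrative; the function name, shapes, and variable names are placeholders, not part of the change):

    import tvm
    from tvm import relax as rx, tir, topi
    from tvm.script import relax as R

    bb = rx.BlockBuilder()
    n = tir.Var("n", "int64")
    x = rx.Var("x", R.Tensor((n,), "float32"))
    with bb.function("clip_fn", [x]):
        # The PrimValue arguments carry PrimStructInfo, so the branch above
        # passes their .value (a tir.PrimExpr) straight through to topi.clip.
        out = bb.emit_te(topi.clip, x, rx.PrimValue(0), rx.PrimValue(6))
        bb.emit_func_output(out)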
24 changes: 24 additions & 0 deletions python/tvm/relax/op/unary.py
@@ -18,6 +18,7 @@
"""Relax unary arithmetic operators."""
from . import _ffi_api
from ..expr import Expr
from ..utils import args_converter

###################### Arithmetic operators ######################

@@ -454,6 +455,29 @@ def tanh(x: Expr) -> Expr:
return _ffi_api.tanh(x) # type: ignore


@args_converter.auto
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
"""Clips tensor values to a specified min and max.
Parameters
----------
x : relax.Expr
The input data
min : relax.Expr
The minimum value
max : relax.Expr
The maximum value
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.clip(x, min, max) # type: ignore


###################### Check operators ######################


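Because clip is decorated with @args_converter.auto, Python scalars passed as min/max are converted to relax.PrimValue automatically; this is what lets the tests below call clip(x, 0, 6) with plain ints. A usage sketch (variable names are illustrative):

    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    # Plain Python ints are wrapped into relax.PrimValue by the converter.
    y = relax.op.clip(x, 0, 6)
    # Equivalent explicit form:
    y = relax.op.clip(x, relax.PrimValue(0), relax.PrimValue(6))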
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -44,6 +44,7 @@
call_builtin_with_ctx,
call_tir,
ceil,
clip,
concat,
cos,
cosh,
@@ -465,6 +466,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"call_tir",
"call_builtin_with_ctx",
"ceil",
"clip",
"cos",
"cosh",
"concat",
17 changes: 13 additions & 4 deletions python/tvm/topi/math.py
@@ -18,8 +18,9 @@
# pylint: disable=redefined-builtin,unused-argument
import tvm
from tvm import te
from . import tag
from . import cpp
from tvm.tir import PrimExpr

from . import cpp, tag
from .utils import get_const_tuple


@@ -633,8 +634,16 @@ def clip(x, a_min, a_max):

def _compute(*indices):
value = x(*indices)
const_min = tvm.tir.const(a_min, value.dtype)
const_max = tvm.tir.const(a_max, value.dtype)
const_min = (
tvm.tir.Cast(value.dtype, a_min)
if isinstance(a_min, PrimExpr)
else tvm.tir.const(a_min, value.dtype)
)
const_max = (
tvm.tir.Cast(value.dtype, a_max)
if isinstance(a_max, PrimExpr)
else tvm.tir.const(a_max, value.dtype)
)
return tvm.te.max(tvm.te.min(value, const_max), const_min)

return te.compute(x.shape, _compute)
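With this change, topi.clip accepts PrimExpr bounds (cast to the tensor's dtype) in addition to Python constants, which is what emit_te now forwards from PrimValue arguments. A minimal sketch, assuming a plain TE context:

    from tvm import te, tir, topi

    x = te.placeholder((4,), dtype="float32", name="x")
    # Constant bounds still go through tvm.tir.const, as before.
    y0 = topi.clip(x, 0.0, 6.0)
    # PrimExpr bounds are now cast to the value dtype instead of rejected.
    y1 = topi.clip(x, tir.IntImm("int64", 0), tir.IntImm("int64", 6))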
25 changes: 24 additions & 1 deletion src/relax/op/op_common.h
@@ -115,7 +115,30 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx
}

/*!
* \brief Infer the struct info for unary arithmetic elementwise ops. It's also
* \brief Infer the struct info by returning the struct info of the input argument.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
* \tparam arg_index The index of the argument to infer the output dtype from.
* \return The inferred struct info.
*/
template <int arg_index>
StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);
int n_input = op->arguments.size();
if (static_cast<int>(call->args.size()) != n_input) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " op should have " << n_input << " arguments");
}
if (arg_index >= n_input) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " op has only " << n_input
<< " arguments, but tried to get the argument with index " << arg_index);
}
return GetStructInfo(call->args[arg_index]);
}

/*!
* \brief Infer the struct info for unary arithmetic elementwise ops. It's also
* used in some NN operators.
* \param call The context Call to the operator.
* \param ctx The error reporting context.
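ReturnStructInfoFromArg<0> simply propagates the struct info of the first argument, so clip's result keeps the exact shape and dtype of its input. An illustrative Python-side sketch of that behavior (assumed equivalent to the inference tests at the end of this commit):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = bb.normalize(relax.op.clip(x, 0, 6))
    # The inferred struct info mirrors the input: R.Tensor((2, 3), "float32").
    assert isinstance(y.struct_info, relax.TensorStructInfo)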
23 changes: 23 additions & 0 deletions src/relax/op/tensor/unary.cc
@@ -24,6 +24,8 @@

#include "unary.h"

#include <utility>

namespace tvm {
namespace relax {

@@ -58,6 +60,27 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);

// relax.clip
TVM_REGISTER_OP("relax.clip")
.set_num_inputs(3)
.add_argument("x", "Tensor", "The input tensor.")
.add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to")
.add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnStructInfoFromArg<0>);

Expr clip(Expr x, Expr min, Expr max) {
CHECK(min->IsInstance<PrimValueNode>())
<< "The argument `min` of relax.clip is expected to be a PrimValue, but got "
<< min->GetTypeKey();
CHECK(max->IsInstance<PrimValueNode>())
<< "The argument `max` of relax.clip is expected to be a PrimValue, but got "
<< max->GetTypeKey();
static const Op& op = Op::Get("relax.clip");
return Call(op, {std::move(x), std::move(min), std::move(max)});
}

TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip);

/***************** Check operators *****************/

RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite);
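The two CHECKs make the PrimValue requirement explicit at construction time: anything else passed for min/max fails before struct info inference runs. An illustrative sketch of the failure mode (bad_min is a hypothetical tensor argument, not from the commit):

    import pytest
    from tvm import relax
    from tvm.script import relax as R

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    bad_min = relax.Var("m", R.Tensor((), "float32"))
    # A tensor min trips the IsInstance<PrimValueNode> CHECK in clip().
    with pytest.raises(Exception):
        relax.op.clip(x, bad_min, relax.PrimValue(6))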
3 changes: 3 additions & 0 deletions src/relax/op/tensor/unary.h
@@ -124,6 +124,9 @@ Expr tan(Expr x);
/*! \brief Compute element-wise tanh of data. */
Expr tanh(Expr x);

/*! \brief Clips tensor values to a specified min and max. */
Expr clip(Expr x, Expr min, Expr max);

/***************** Check operators *****************/

/*! \brief Check if input value is finite. */
23 changes: 23 additions & 0 deletions tests/python/relax/test_blockbuilder.py
@@ -617,6 +617,29 @@ def test_emit_te_extern():
assert call_node.sinfo_args[0].shape[1] == n


def test_emit_te_prim_value():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
x = rx.Var("x", R.Tensor([n, m], "float32"))
a_min = rx.PrimValue(0)
a_max = rx.PrimValue(6)

with bb.function("rx_clip", [x]):
out = bb.emit_te(topi.clip, x, a_min, a_max)
bb.emit_func_output(out)

rx_func = bb.get()["rx_clip"]

# check that the Relax function calls the generated TIR function through call_tir
assert rx_func.params[0] == x
assert len(rx_func.body.blocks) == 1
call_node = rx_func.body.blocks[0].bindings[0].value
assert isinstance(call_node, rx.Call)
assert call_node.op == relay.op.get("relax.call_tir")
assert len(call_node.args) == 2
assert call_node.args[1][0] == x


def test_emit_tuple_get_item():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
25 changes: 25 additions & 0 deletions tests/python/relax/test_op_unary.py
@@ -51,6 +51,7 @@ def test_op_correctness():
assert relax.op.sqrt(x).op == Op.get("relax.sqrt")
assert relax.op.tan(x).op == Op.get("relax.tan")
assert relax.op.tanh(x).op == Op.get("relax.tanh")
assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip")


def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -174,5 +175,29 @@ def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable
bb.normalize(unary_arith_op(x1))


def test_clip_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=3))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3)))
x4 = relax.Var("x", R.Tensor())

_check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3))
_check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32"))
_check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype=""))
_check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype=""))

# Symbolic
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x5 = relax.Var("x", R.Tensor((m, n), "float32"))
x6 = relax.Var("x", R.Tensor((4, n), "float32"))

_check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((m, n), "float32"))
_check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorStructInfo((4, n), "float32"))


if __name__ == "__main__":
tvm.testing.main()
