Commit

initial
SiriusNEO committed Jan 17, 2023
1 parent 99d5afc commit 4216bb1
Showing 6 changed files with 63 additions and 80 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relax/op/unary.py
@@ -40,6 +40,23 @@ def cos(x: Expr) -> Expr:
    return _ffi_api.cos(x)  # type: ignore


def exp(x: Expr) -> Expr:
    """Compute element-wise exp of data.

    Parameters
    ----------
    x : relax.Expr
        The input data.

    Returns
    -------
    result : relax.Expr
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype.
    """
    return _ffi_api.exp(x)  # type: ignore


def log(x: Expr) -> Expr:
"""Compute element-wise natural logarithm of the input data.
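For orientation, here is a minimal usage sketch of the new operator through the BlockBuilder API (illustrative only, not part of this commit's diff; it assumes a TVM build containing this change and mirrors the test code below):

from tvm import relax
from tvm.script import relax as R

# Build a one-function module that applies the new exp op to a float tensor.
x = relax.Var("x", R.Tensor((2, 3), "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x]):
    gv = bb.emit(relax.op.exp(x))  # dispatches to _ffi_api.exp, per the diff above
    bb.emit_func_output(gv)
print(bb.get())  # the resulting IRModule contains a call to relax.exp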
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,7 @@
    divide,
    equal,
    ewise_fma,
    exp,
    expand_dims,
    flatten,
    floor_divide,
@@ -413,6 +414,7 @@ def tuple(*fields: List[Expr]) -> Expr:
"emit_match_cast",
"equal",
"ewise_fma",
"exp",
"expand_dims",
"flatten",
"floor_divide",
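Adding exp to the ir_builder namespace is what makes the op spellable as R.exp inside TVMScript. A sketch of the resulting surface syntax (illustrative, assuming a build containing this commit):

from tvm.script import relax as R

# After this change, the parser resolves R.exp to the relax.exp operator.
@R.function
def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
    gv: R.Tensor((2, 3), "float32") = R.exp(x)
    return gv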
1 change: 1 addition & 0 deletions src/relax/op/tensor/unary.cc
@@ -28,6 +28,7 @@ namespace tvm {
namespace relax {

RELAX_REGISTER_UNARY_OP_INTERFACE(cos, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_OP_INTERFACE(exp, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_OP_INTERFACE(log, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_OP_INTERFACE(negative, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_OP_INTERFACE(sigmoid, /*require_float_dtype=*/true);
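The registration above creates the operator under the name "relax.exp" and marks it as requiring a float input dtype; the updated tests below exercise exactly this. A quick sanity sketch (illustrative, not part of the diff):

from tvm import relax
from tvm.ir import Op
from tvm.script import relax as R

# The C++ registration makes "relax.exp" discoverable from Python.
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.exp(x).op == Op.get("relax.exp")
# require_float_dtype=true: struct-info inference is expected to reject
# non-float inputs (e.g. int32 tensors) for this op.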
3 changes: 3 additions & 0 deletions src/relax/op/tensor/unary.h
@@ -36,6 +36,9 @@ namespace relax {
*/
Expr cos(Expr x);

/*! \brief Compute element-wise exp of data. */
Expr exp(Expr x);

/*! \brief Compute element-wise natural logarithm of data. */
Expr log(Expr x);

3 changes: 2 additions & 1 deletion tests/python/relax/test_op_unary.py
@@ -31,6 +31,7 @@ def test_op_correctness():
    assert relax.op.tanh(x).op == Op.get("relax.tanh")
    assert relax.op.sqrt(x).op == Op.get("relax.sqrt")
    assert relax.op.log(x).op == Op.get("relax.log")
    assert relax.op.exp(x).op == Op.get("relax.exp")
    assert relax.op.sigmoid(x).op == Op.get("relax.sigmoid")
    assert relax.op.unique(x).op == Op.get("relax.unique")

@@ -74,7 +75,7 @@ def test_unary_arith_infer_struct_info_shape_var():
    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))

    _check_inference(bb, relax.op.log(x0), relax.TensorStructInfo(s0, "float32"))
    _check_inference(bb, relax.op.tanh(x1), relax.TensorStructInfo(s1, "float32"))
    _check_inference(bb, relax.op.exp(x1), relax.TensorStructInfo(s1, "float32"))


def test_unary_arith_infer_struct_info_more_input_dtype():
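The _check_inference helper used above is defined earlier in this test file and sits outside the diff; the following sketch of its behavior is an assumption based on how it is called:

import tvm

# Assumed helper: normalize the call through the BlockBuilder (which runs
# struct-info inference), then compare against the expected struct info.
def _check_inference(bb, call, expected_sinfo):
    ret = bb.normalize(call)
    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)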
117 changes: 38 additions & 79 deletions tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Union
from typing import Optional, Union, Callable

import tvm
import tvm.testing
@@ -34,124 +34,83 @@ def _check(
    tvm.ir.assert_structural_equal(parsed, expect)


def test_relax_add():
def _test_unary(op_func: Callable):
    @R.function
    def foo(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.add(x, y)
    def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = op_func(x)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((2, 1), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, y]):
        gv = bb.emit(relax.op.add(x, y))
    with bb.function("foo", [x]):
        gv = bb.emit(op_func(x))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_subtract():
def _test_binary_arith(op_func: Callable):
    @R.function
    def foo(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.subtract(x, y)
        gv: R.Tensor((2, 3), "float32") = op_func(x, y)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((2, 1), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, y]):
        gv = bb.emit(relax.op.subtract(x, y))
        gv = bb.emit(op_func(x, y))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_floor_divide():
    @R.function
    def foo(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.floor_divide(x, y)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((2, 1), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, y]):
        gv = bb.emit(relax.op.floor_divide(x, y))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_sin():
    @R.function
    def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.sin(x)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.sin(x))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_sigmoid():
    @R.function
    def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.sigmoid(x)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        gv = bb.emit(relax.op.sigmoid(x))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_equal():
def _test_binary_cmp(op_func: Callable):
    @R.function
    def foo(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")
    ) -> R.Tensor((2, 3), "bool"):
        gv: R.Tensor((2, 3), "bool") = R.equal(x, y)
        gv: R.Tensor((2, 3), "bool") = op_func(x, y)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((2, 1), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, y]):
        gv = bb.emit(relax.op.equal(x, y))
        gv = bb.emit(op_func(x, y))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


def test_relax_less():
    @R.function
    def foo(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32")
    ) -> R.Tensor((2, 3), "bool"):
        gv: R.Tensor((2, 3), "bool") = R.less(x, y)
        return gv

    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    y = relax.Var("y", R.Tensor((2, 1), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, y]):
        gv = bb.emit(relax.op.less(x, y))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])
def test_unary():
    _test_unary(relax.op.cos)
    _test_unary(relax.op.exp)
    _test_unary(relax.op.log)
    _test_unary(relax.op.negative)
    _test_unary(relax.op.sigmoid)
    _test_unary(relax.op.sin)
    _test_unary(relax.op.sqrt)
    _test_unary(relax.op.tanh)


def test_binary_arith():
    _test_binary_arith(relax.op.add)
    _test_binary_arith(relax.op.divide)
    _test_binary_arith(relax.op.floor_divide)
    _test_binary_arith(relax.op.multiply)
    _test_binary_arith(relax.op.subtract)


def test_binary_cmp():
    _test_binary_cmp(relax.op.equal)
    _test_binary_cmp(relax.op.greater)
    _test_binary_cmp(relax.op.greater_equal)
    _test_binary_cmp(relax.op.less)
    _test_binary_cmp(relax.op.less_equal)
    _test_binary_cmp(relax.op.not_equal)


def test_relax_ewise_fma():
