Skip to content

Commit

Permalink
[Unity] Relax op: arithmetic, comparison
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the unary, binary and ternary arithmetic and
comparison operators.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Chaofan Lin <1713833595@qq.com>
  • Loading branch information
3 people committed Feb 14, 2023
1 parent 4ee1bde commit d57d0f8
Show file tree
Hide file tree
Showing 19 changed files with 2,027 additions and 26 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,18 @@
from .base import *
from .binary import *
from .manipulate import *
from .ternary import *
from .unary import *
from . import builtin
from . import memory


def _register_op_make():
# pylint: disable=import-outside-toplevel
from . import _ffi_api
from .. import expr

expr._op_ffi_api = _ffi_api # type: ignore


_register_op_make()
165 changes: 165 additions & 0 deletions python/tvm/relax/op/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,42 @@ def add(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.add(x1, x2) # type: ignore


def divide(x1: Expr, x2: Expr) -> Expr:
"""Division with numpy-style broadcasting.
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.divide(x1, x2) # type: ignore


def floor_divide(x1: Expr, x2: Expr) -> Expr:
"""Floor division with numpy-style broadcasting.
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.floor_divide(x1, x2) # type: ignore


def multiply(x1: Expr, x2: Expr) -> Expr:
"""Multiplication with numpy-style broadcasting.
Expand All @@ -65,3 +101,132 @@ def multiply(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.multiply(x1, x2) # type: ignore


def subtract(x1: Expr, x2: Expr) -> Expr:
"""Subtraction with numpy-style broadcasting.
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.subtract(x1, x2) # type: ignore


###################### Comparison operators ######################


def equal(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs == rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.equal(x1, x2) # type: ignore


def greater(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs > rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.greater(x1, x2) # type: ignore


def greater_equal(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs >= rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.greater_equal(x1, x2) # type: ignore


def less(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs < rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.less(x1, x2) # type: ignore


def less_equal(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs <= rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.less_equal(x1, x2) # type: ignore


def not_equal(x1: Expr, x2: Expr) -> Expr:
"""Broadcasted element-wise test for (lhs != rhs).
Parameters
----------
x1 : relax.Expr
The first input tensor.
x2 : relax.Expr
The second input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.not_equal(x1, x2) # type: ignore
43 changes: 43 additions & 0 deletions python/tvm/relax/op/ternary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Relax ternary arithmetic operators."""
from . import _ffi_api
from ..expr import Expr


def ewise_fma(x1: Expr, x2: Expr, x3: Expr) -> Expr:
"""Elementwise fused multiply-add operator
Returns elementwise result of :math:`x1 * x2 + x3`
Parameters
----------
x1 : relax.Expr
The left hand operand of the multiplication
x2 : relax.Expr
The right hand operand of the multiplication
x3 : relax.Expr
The operand of the addition
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.ewise_fma(x1, x2, x3) # type: ignore
Loading

0 comments on commit d57d0f8

Please sign in to comment.