Skip to content

Commit

Permalink
[Unity] Relax op: statistical (#13991)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the statistical operators.
  • Loading branch information
MasterJH5574 authored and tqchen committed Feb 24, 2023
1 parent 27dde56 commit d4a7cfc
Show file tree
Hide file tree
Showing 9 changed files with 856 additions and 0 deletions.
48 changes: 48 additions & 0 deletions include/tvm/relax/attrs/statistical.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/statistical.h
* \brief Attributes for statistical operators.
*/
#ifndef TVM_RELAX_ATTRS_STATISTICAL_H_
#define TVM_RELAX_ATTRS_STATISTICAL_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes for statistical operators */
struct StatisticalAttrs : public tvm::AttrsNode<StatisticalAttrs> {
Optional<Array<Integer>> axis;
bool keepdims;

TVM_DECLARE_ATTRS(StatisticalAttrs, "relax.attrs.StatisticalAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis or axes along which to perform the reduction.");
TVM_ATTR_FIELD(keepdims).describe(
"If this is set to `True`, the reduced axes are left in the result as dimension with size "
"one.");
}
}; // struct StatisticalAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_STATISTICAL_H_
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .index import *
from .manipulate import *
from .op_attrs import *
from .statistical import *
from .set import *
from .ternary import *
from .unary import *
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs):
"""Attributes used in strided_slice operator"""


@tvm._ffi.register_object("relax.attrs.StatisticalAttrs")
class StatisticalAttrs(Attrs):
"""Attributes used in statistical operator"""


@tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
class Resize2DAttrs(Attrs):
"""Attributes used in image resize2d operator"""
Expand Down
218 changes: 218 additions & 0 deletions python/tvm/relax/op/statistical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""Statistical operators."""
from typing import List, Optional, Union

from . import _ffi_api
from ..expr import Expr


def max(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the max of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a max operation is performed.
The default, axis=None, will compute the max of all elements in the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.max(x, axis, keepdims) # type: ignore


def mean(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the mean of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a mean operation is performed.
The default, axis=None, will compute the mean of all elements in the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.mean(x, axis, keepdims) # type: ignore


def min(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the min of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a min operation is performed.
The default, axis=None, will compute the min of all elements in the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.min(x, axis, keepdims) # type: ignore


def prod(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the product of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a product is performed.
The default, axis=None, will compute the product of all elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.prod(x, axis, keepdims) # type: ignore


def std(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the standard deviation of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a standard deviation is performed.
The default, axis=None, will compute the std of all elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.std(x, axis, keepdims) # type: ignore


def sum(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the sum of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a sum is performed.
The default, axis=None, will sum all of the elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.sum(x, axis, keepdims) # type: ignore


def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
"""Computes the variance of tensor elements over given axes.
Parameters
----------
x : relax.Expr
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a variance operation is performed.
The default, axis=None, will compute the variance of all elements in the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input tensor.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axis, int):
axis = [axis]
return _ffi_api.variance(x, axis, keepdims) # type: ignore
18 changes: 18 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,24 @@
less_equal,
log,
make_closure,
max,
mean,
memory,
min,
multiply,
negative,
not_equal,
null_value,
print,
prod,
reshape,
round,
shape_of,
std,
strided_slice,
sum,
take,
variance,
sigmoid,
sign,
sin,
Expand Down Expand Up @@ -486,18 +495,26 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"less_equal",
"log",
"make_closure",
"max",
"mean",
"memory",
"min",
"multiply",
"negative",
"not_equal",
"null_value",
"output",
"prim_value",
"print",
"prod",
"reshape",
"round",
"shape",
"shape_of",
"std",
"str",
"strided_slice",
"sum",
"sigmoid",
"sign",
"sin",
Expand All @@ -511,5 +528,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"tan",
"tanh",
"tuple",
"variance",
"unique",
]
Loading

0 comments on commit d4a7cfc

Please sign in to comment.