Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY]Reduce ops sum/max/min/mean/prod #1927

Merged
merged 4 commits into from
Oct 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ This level enables additional math and transform operators.
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
tvm.relay.sum
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.prod


**Level 5: Vision/Image Operators**
Expand Down Expand Up @@ -187,6 +192,11 @@ Level 4 Definitions
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
.. autofunction:: tvm.relay.sum
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod


Level 5 Definitions
Expand Down
152 changes: 150 additions & 2 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""

return _make.argmax(data, axis, keepdims, exclude)

def argmin(data, axis=None, keepdims=False, exclude=False):
Expand Down Expand Up @@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""

return _make.argmin(data, axis, keepdims, exclude)


def sum(data, axis=None, keepdims=False, exclude=False):
"""Computes the sum of array elements over given axes.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sum(data, axis, keepdims, exclude)


def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.max(data, axis, keepdims, exclude)


def min(data, axis=None, keepdims=False, exclude=False):
"""Computes the min of array elements over given axes.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.min(data, axis, keepdims, exclude)


def mean(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean of array elements over given axes.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.mean(data, axis, keepdims, exclude)


def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.

exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.prod(data, axis, keepdims, exclude)
116 changes: 111 additions & 5 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/relay/op.h>
#include <numeric>
#include <limits>
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
Expand All @@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
bool exclude;

TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({}))
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<IndexExpr>>())
.describe(R"code(The axis or axes along which to perform the reduction.

The default, `axis=()`, will compute over all elements into a
Expand Down Expand Up @@ -158,10 +159,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape;
for (auto i : data->shape) {
in_shape.push_back(i);
}
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
Expand All @@ -172,6 +170,31 @@ bool ArgReduceRel(const Array<Type>& types,
return true;
}

/*!
* \brief ReduceRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool ReduceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);

// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}

#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
Expand Down Expand Up @@ -213,5 +236,88 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);


RELAY_REGISTER_REDUCE_OP("sum")
.describe(R"code(Computes the sum of array elements over given axes.

Example::

data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]

sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]

sum(data, axis=[1,2])
[ 12. 19. 27.]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("min")
.describe(R"code(Computes the min of array elements over given axes.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.

Example::

data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]

mean(data)
[3.22]

mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.

Example::

data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]

mean(data, axis=1)
[35562240]

mean(data, axis=[1,2])
[ 36 480 2058]

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);

} // namespace relay
} // namespace tvm
Loading