Skip to content

Commit

Permalink
[Relay][Topi][TensorFlow][ONNX][Lang] Add support for Any op (#4205)
Browse files Browse the repository at this point in the history
* Add support for Any op

* Support ONNX frontend

* Add doc

* Add to relay docs

* Dummy change to retrigger CI
  • Loading branch information
soiferj authored and jroesch committed Oct 30, 2019
1 parent 156aa59 commit b07b195
Show file tree
Hide file tree
Showing 17 changed files with 256 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ List of operators
topi.greater_equal
topi.less_equal
topi.all
topi.any
topi.logical_and
topi.logical_or
topi.logical_not
Expand Down Expand Up @@ -151,6 +152,7 @@ topi
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.all
.. autofunction:: topi.any
.. autofunction:: topi.max
.. autofunction:: topi.sum
.. autofunction:: topi.min
Expand Down
1 change: 1 addition & 0 deletions docs/frontend/tensorflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ Supported Ops
- Abs
- Add
- All
- Any
- ArgMax
- ArgMin
- AvgPool
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ This level enables additional math and transform operators.
tvm.relay.less
tvm.relay.less_equal
tvm.relay.all
tvm.relay.any
tvm.relay.logical_and
tvm.relay.logical_or
tvm.relay.logical_not
Expand Down Expand Up @@ -300,6 +301,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.any
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,13 @@ TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr all(Expr source, Array<IterVar> axis);

/*!
* \brief logical Or of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
TVM_DLL Expr any(Expr source, Array<IterVar> axis);

/*!
* \brief max of of source expression over axis
* \param source The source expression.
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,12 @@ class Where(OnnxOpConverter):
def _impl_v9(cls, inputs, attr, params):
return _op.where(inputs[0], inputs[1], inputs[2])

class Or(Elemwise):
""" Operator converter for Or.
"""
@classmethod
def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1])

# compatible operators that do NOT require any conversion.
_identity_list = []
Expand Down Expand Up @@ -1111,7 +1117,8 @@ def _get_convert_map(opset):
'And': And.get_converter(opset),
'Tile': Tile.get_converter(opset),
'Erf': Erf.get_converter(opset),
'Where': Where.get_converter(opset)
'Where': Where.get_converter(opset),
'Or': Or.get_converter(opset)
}


Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,7 @@ def _impl(inputs, attr, params):
'Abs' : AttrCvt('abs'),
'Add' : _elemwise('add'),
'All' : _reduce('all'),
'Any' : _reduce('any'),
'ArgMax' : _argx(_op.argmax, 'argmax'),
'ArgMin' : _argx(_op.argmin, 'argmin'),
'Assert' : _assert(),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _schedule_reduce(_, outs, target):
_reg.register_schedule("argmin", _schedule_reduce)
_reg.register_schedule("sum", _schedule_reduce)
_reg.register_schedule("all", _schedule_reduce)
_reg.register_schedule("any", _schedule_reduce)
_reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
Expand Down
52 changes: 52 additions & 0 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,58 @@ def all(data, axis=None, keepdims=False, exclude=False):
return _make.all(data, axis, keepdims, exclude)


def any(data, axis=None, keepdims=False, exclude=False):
"""Computes the logical OR of boolean array elements over given axes.
Parameters
----------
data : relay.Expr
The input boolean tensor
axis : None or int or tuple of int
Axis or axes along which a sum is performed. The default, axis=None,
will sum all of the elements of the input array. If axis is
negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one. With this option, the result will broadcast
correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = relay.Constant(tvm.nd.array([[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]))
relay.any(data, axis=1)
# [[True, True, True],
# [True, True, True]]
relay.any(data, axis=0)
# [[ True, True, True],
# [ True, True, True],
# [False, True, True]]
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.any(data, axis, keepdims, exclude)


def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
Expand Down
10 changes: 10 additions & 0 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,16 @@ Expr all(Expr source, Array<IterVar> rdom) {
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr any(Expr source, Array<IterVar> rdom) {
CHECK(source.type().is_bool());
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Or::make(x, y);
Expr identity_element = make_const(source.type(), false);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
}

Expr max(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Max::make(x, y);
Expand Down
37 changes: 37 additions & 0 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,43 @@ Example::
.set_attr<TOpPattern>("TOpPattern", kCommReduce);


Array<Tensor> AnyCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return ReduceCompute(attrs, inputs, out_type, target, topi::any);
}


RELAY_REGISTER_REDUCE_OP("any")
.describe(R"code(Computes the logical OR of boolean array elements over given axes.
Example::
data = [[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]
any(data, axis=1)
[[True, True, True],
[True, True, True]]
any(data, axis=0)
[[ True, True, True],
[ True, True, True],
[False, True, True]]
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReduceAttrs>()
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);


Array<Tensor> MaxCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
Expand Down
48 changes: 48 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,53 @@ def test_where():
verify_where(condition, x, y, TensorProto.FLOAT, outdata)


def verify_or(indata, dtype):
x = indata[0].astype(dtype)
y = indata[1].astype(dtype)
outdata = np.logical_or(x, y)

node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'], )

graph = helper.make_graph([node],
'or_test',
inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)),
helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))],
outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))])

model = helper.make_model(graph, producer_name='or_test')

for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
tvm.testing.assert_allclose(outdata, tvm_out)


def test_or():
# 2d
x = (np.random.randn(3, 4) > 0)
y = (np.random.randn(3, 4) > 0)
verify_or(indata=[x, y], dtype=bool)

# 3d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(3, 4, 5) > 0)
verify_or(indata=[x, y], dtype=bool)

# 4d
x = (np.random.randn(3, 4, 5, 6) > 0)
y = (np.random.randn(3, 4, 5, 6) > 0)
verify_or(indata=[x, y], dtype=bool)

# 3d vs 1d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(5) > 0)
verify_or(indata=[x, y], dtype=bool)

# 3d vs 2d
x = (np.random.randn(3, 4, 5) > 0)
y = (np.random.randn(4, 5) > 0)
verify_or(indata=[x, y], dtype=bool)


if __name__ == '__main__':
test_flatten()
test_reshape()
Expand Down Expand Up @@ -1651,3 +1698,4 @@ def test_where():
test_tile()
test_erf()
test_where()
test_or()
12 changes: 10 additions & 2 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,7 @@ def check_size(ishape):
check_size((10,))

#######################################################################
# All, Max, Min
# All, Any, Max, Min
# -------------
def test_forward_reduce_all():
"""Test the All operator."""
Expand All @@ -2208,6 +2208,14 @@ def test_forward_reduce_all():
tf.reduce_all(in_data, name="all")
compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')

def test_forward_reduce_any():
"""Test the Any operator."""
np_data = np.random.choice([True, False], size=(5, 7, 11))
tf.reset_default_graph()
in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
tf.reduce_any(in_data, name="any")
compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0')

def test_forward_reduce_max():
def check_max(ishape, axis, keepdims, dtype):
tf.reset_default_graph()
Expand Down Expand Up @@ -2432,7 +2440,7 @@ def test_forward_one_hot():
test_forward_mean()
test_forward_reduce_prod()
test_forward_reduce_all()
test_forward_reduce_max()
test_forward_reduce_any()
test_forward_reduce_min()

# General
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_where():
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "bool" if ref_func in [np.all] else dtype
dtype = "bool" if ref_func in [np.all, np.any] else dtype

x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
Expand Down Expand Up @@ -207,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.std, np.std],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.any, np.any],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
Expand Down
21 changes: 21 additions & 0 deletions topi/include/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,27 @@ inline Tensor all(const Tensor& data,
return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
}

/*!
* \brief Creates an operation that computes the logical OR of elements
* over a given axis
*
* \param data The input boolean tensor
* \param axis The axes to reduce. If axis is empty, the operation will
* perform logical OR over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
*
* \return A Tensor whose op member is the all operation
*/
inline Tensor any(const Tensor& data,
const Array<Integer>& axis,
bool keepdims = false,
bool atleast1d = false) {
return CommReduce(data, axis, tvm::any, keepdims, atleast1d);
}

/*!
* \brief Creates an operation that finds the minimum of elements over
* a given axis.
Expand Down
25 changes: 25 additions & 0 deletions topi/python/topi/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ def all(data, axis=None, keepdims=False):
return cpp.all(data, axis, keepdims)


def any(data, axis=None, keepdims=False):
"""Logical OR of array elements over a given axis or a list of axes
Parameters
----------
data : tvm.Tensor
The input tvm boolean tensor
axis : None or int or tuple of int
Axis or axes along which a logical OR is performed.
The default, axis=None, will perform logical OR over all elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
return cpp.any(data, axis, keepdims)


def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ TVM_REGISTER_GLOBAL("topi.all")
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});

TVM_REGISTER_GLOBAL("topi.any")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});

/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
Expand Down
Loading

0 comments on commit b07b195

Please sign in to comment.