diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 56558272f2a3..a36f8e6c71cf 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -108,6 +108,11 @@ This level enables additional math and transform operators. tvm.relay.where tvm.relay.argmax tvm.relay.argmin + tvm.relay.sum + tvm.relay.max + tvm.relay.min + tvm.relay.mean + tvm.relay.prod **Level 5: Vision/Image Operators** @@ -187,6 +192,11 @@ Level 4 Definitions .. autofunction:: tvm.relay.where .. autofunction:: tvm.relay.argmax .. autofunction:: tvm.relay.argmin +.. autofunction:: tvm.relay.sum +.. autofunction:: tvm.relay.max +.. autofunction:: tvm.relay.min +.. autofunction:: tvm.relay.mean +.. autofunction:: tvm.relay.prod Level 5 Definitions diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index a2a4519512ea..73c5f270e8bf 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False): result : relay.Expr The computed result. """ - return _make.argmax(data, axis, keepdims, exclude) def argmin(data, axis=None, keepdims=False, exclude=False): @@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False): result : relay.Expr The computed result. """ - return _make.argmin(data, axis, keepdims, exclude) + + +def sum(data, axis=None, keepdims=False, exclude=False): + """Computes the sum of array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a argmin operation is performed. + The default, axis=None, will find the indices of minimum element all of the elements of + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.sum(data, axis, keepdims, exclude) + + +def max(data, axis=None, keepdims=False, exclude=False): + """ Computes the max of array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a argmin operation is performed. + The default, axis=None, will find the indices of minimum element all of the elements of + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.max(data, axis, keepdims, exclude) + + +def min(data, axis=None, keepdims=False, exclude=False): + """Computes the min of array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a argmin operation is performed. + The default, axis=None, will find the indices of minimum element all of the elements of + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.min(data, axis, keepdims, exclude) + + +def mean(data, axis=None, keepdims=False, exclude=False): + """Computes the mean of array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a argmin operation is performed. + The default, axis=None, will find the indices of minimum element all of the elements of + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.mean(data, axis, keepdims, exclude) + + +def prod(data, axis=None, keepdims=False, exclude=False): + """Computes the products of array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a argmin operation is performed. + The default, axis=None, will find the indices of minimum element all of the elements of + the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.prod(data, axis, keepdims, exclude) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 017ef1e5dfec..0a955fad631b 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -7,6 +7,7 @@ #include #include #include +#include "../op_common.h" #include "../type_relations.h" namespace tvm { @@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode { bool exclude; TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(Array({})) + TVM_ATTR_FIELD(axis).set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. The default, `axis=()`, will compute over all elements into a @@ -158,10 +159,7 @@ bool ArgReduceRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) return false; CHECK(static_cast(data->shape.size()) != 0); - std::vector in_shape; - for (auto i : data->shape) { - in_shape.push_back(i); - } + std::vector&& in_shape = AsVector(data->shape); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -172,6 +170,31 @@ bool ArgReduceRel(const Array& types, return true; } +/*! +* \brief ReduceRel Output type and shape relation evaluation function. +* \param num_inputs Number of input types in the args. +* \param attrs The additional attributes of the operator. +* \param reporter The reporter to report solution to. +* \return false if This relation cannot be resolved. true if this relation has been resolved. +*/ +bool ReduceRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + CHECK(static_cast(data->shape.size()) != 0); + std::vector&& in_shape = AsVector(data->shape); + + const ReduceAttrs* param = attrs.as(); + CHECK(param != nullptr); + + // assign output type and shape + auto oshape = ReduceShapeImpl(in_shape, param, reporter); + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \ @@ -213,5 +236,88 @@ values over a given axis. .set_support_level(4) .add_type_rel("ArgReduce", ArgReduceRel); + +RELAY_REGISTER_REDUCE_OP("sum") +.describe(R"code(Computes the sum of array elements over given axes. + +Example:: + + data = [[[1,2],[2,3],[1,3]], + [[1,4],[4,3],[5,2]], + [[7,1],[7,2],[7,3]]] + + sum(data, axis=1) + [[ 4. 8.] + [ 10. 9.] + [ 21. 6.]] + + sum(data, axis=[1,2]) + [ 12. 19. 27.] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel); + + +RELAY_REGISTER_REDUCE_OP("max") +.describe(R"code(Computes the max of array elements over given axes. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel); + + +RELAY_REGISTER_REDUCE_OP("min") +.describe(R"code(Computes the min of array elements over given axes. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel); + + +RELAY_REGISTER_REDUCE_OP("mean") +.describe(R"code(Computes the mean of array elements over given axes. + +Example:: + + data = [[[1,2],[2,3],[1,3]], + [[1,4],[4,3],[5,2]], + [[7,1],[7,2],[7,3]]] + + mean(data) + [3.22] + + mean(data, axis=[1,2]) + [ 2. 3.16666667 4.5] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel); + + +RELAY_REGISTER_REDUCE_OP("prod") +.describe(R"code(Computes the products of array elements over given axes. + +Example:: + + data = [[[1,2],[2,3],[1,3]], + [[1,4],[4,3],[5,2]], + [[7,1],[7,2],[7,3]]] + + mean(data, axis=1) + [35562240] + + mean(data, axis=[1,2]) + [ 36 480 2058] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c2b685affab4..2dc643cfd7e4 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -46,27 +46,6 @@ def test_binary_int_broadcast(): assert zz.checked_type == relay.TensorType((5, 10, 4), "int32") -def test_arg_reduce(): - for op in [relay.argmax, relay.argmin]: - n, c , h, w = 10, 20, 3, 4 - x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) - z = relay.argmax(x, axis=(1,)) - "axis=" in z.astext() - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType((n, h, w), "int32") - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) - z = relay.argmax(x, axis=(2,), keepdims=True) - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType((n, c , 1, w), "int32") - - n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.ty.TensorType((n, c , h, w), "float32")) - z = relay.argmax(x, axis=(2,), keepdims=True, exclude=True) - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType((1, 1 , h, 1), "int32") - - def test_where(): cond = relay.var("cond", relay.TensorType((3, 4), "float32")) x = relay.var("x", relay.TensorType((3, 4), "float32")) @@ -76,9 +55,45 @@ def test_where(): assert zz.checked_type == relay.TensorType((3, 4), "float32") +def verify_reduce(test_func, data, axis, keepdims, exclude, output): + x = relay.var("x", relay.TensorType(data, "float32")) + z = test_func(x, axis, keepdims, exclude) + zz = relay.ir_pass.infer_type(z) + if axis: + assert "axis=" in z.astext() + if keepdims: + assert "keepdims=" in z.astext() + if exclude: + assert "exclude=" in z.astext() + out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32" + assert zz.checked_type == relay.ty.TensorType(output, out_type) + +def test_reduce_functions(): + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + for func in [relay.sum, + relay.max, + relay.min, + relay.mean, + relay.prod, + relay.argmin, + relay.argmax]: + verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4)) + verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3)) + verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) + verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) + verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) + verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) + verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1)) + verify_reduce(func, (4, 4, 3), None, False, True, ()) + verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) + verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,)) + verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,)) + verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128)) + verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1)) + if __name__ == "__main__": test_binary_op() test_cmp_type() test_binary_int_broadcast() test_where() - test_arg_reduce() + test_reduce_functions()