diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h
index 14b75ff1c0a8..d91b3594b5a3 100644
--- a/include/tvm/relay/attrs/reduce.h
+++ b/include/tvm/relay/attrs/reduce.h
@@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
   }
 };
 
+/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */
+struct ArgReduceAttrs : public tvm::AttrsNode<ArgReduceAttrs> {
+  Array<Integer> axis;
+  bool keepdims;
+  bool select_last_index;
+  bool exclude;
+
+  TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(NullValue<Array<Integer>>())
+        .describe(R"code(The axis or axes along which to perform the reduction.
+
+      The default, `axis=()`, will compute over all elements into a
+      scalar array with shape `(1,)`.
+
+      If `axis` is int, a reduction is performed on a particular axis.
+
+      If `axis` is a tuple of ints, a reduction is performed on all the axes
+      specified in the tuple.
+
+      If `exclude` is true, reduction will be performed on the axes that are
+      NOT in axis instead.)code");
+
+    TVM_ATTR_FIELD(keepdims).set_default(false).describe(
+        "If this is set to `True`, the reduced axes are left "
+        "in the result as dimension with size one.");
+    TVM_ATTR_FIELD(select_last_index)
+        .set_default(false)
+        .describe(
+            "Whether to select the last index if the target element appears multiple times, else "
+            "select the first index which the target element appears");
+    TVM_ATTR_FIELD(exclude).set_default(false).describe(
+        "Whether to perform reduction on axis that are NOT in axis instead.");
+  }
+};
+
 struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
   Array<Integer> axis;
   bool keepdims;
diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index 15d1455bb267..d4e420d80b02 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
   return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
 }
 
+inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
+  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
+  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
+    Array<PrimExpr> result;
+
+    // Casting to avoid operator ambiguity
+    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+    // These variables compare the actual values of the array
+    auto is_smaller = lhs_val < rhs_val;
+    auto is_same = lhs_val == rhs_val;
+
+    // This checks if the indices are correct for the reduction. E.g. for select_last_index
+    // it gives precedence for later indices of the same element and precedence for sooner
+    // indices if not select_last_index;
+    PrimExpr proper_index;
+    if (select_last_index) {
+      proper_index = lhs_idx > rhs_idx;
+    } else {
+      proper_index = lhs_idx < rhs_idx;
+    }
+
+    PrimExpr update_index = is_smaller || (is_same && proper_index);
+    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1]));    // val
+    return result;
+  };
+  auto fidentity = [&](std::vector<DataType> types) {
+    Array<PrimExpr> result;
+    result.push_back(tvm::tir::make_const(types[0], -1));  // idx
+    result.push_back(tvm::max_value(types[1]));            // val
+    return result;
+  };
+  return MakeCommReducer(fcombine, fidentity, "argmin");
+}
+
 /*!
  * \brief Creates an operation that finds the indices of the minimum
  * values over a given axis.
@@ -442,35 +481,48 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
  * \param atleast1d Whether the output need to be atleast1d.
+ * \param select_last_index Whether to select the last index if the minimum element
+ * appears multiple times, else select the first index.
  *
  * \return A Tensor whose op member is the argmin operation
  */
 inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
-                     bool atleast1d = false) {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
-    return result;
-  };
-  auto fidentity = [](std::vector<DataType> types) {
-    Array<PrimExpr> result;
-    result.push_back(tvm::tir::make_const(types[0], -1));  // idx
-    result.push_back(tvm::max_value(types[1]));            // val
-    return result;
-  };
-  auto func = MakeCommReducer(fcombine, fidentity, "argmin");
-  return CommReduceIdx(data, axis, func, keepdims, atleast1d);
+                     bool atleast1d = false, bool select_last_index = false) {
+  auto reducer = MakeArgminReducer(select_last_index);
+  return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
 }
 
-inline FCommReduce MakeArgmaxReducer() {
-  auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
+inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
+  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
+  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
     Array<PrimExpr> result;
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
+
+    // Casting to avoid operator ambiguity
+    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+    // These variables compare the actual values of the array
+    auto is_bigger = lhs_val > rhs_val;
+    auto is_same = lhs_val == rhs_val;
+
+    // This checks if the indices are correct for the reduction. E.g. for select_last_index
+    // it gives precedence for later indices of the same element and precedence for sooner
+    // indices if not select_last_index;
+    PrimExpr proper_index;
+    if (select_last_index) {
+      proper_index = lhs_idx > rhs_idx;
+    } else {
+      proper_index = lhs_idx < rhs_idx;
+    }
+
+    PrimExpr update_index = is_bigger || (is_same && proper_index);
+    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
     return result;
   };
-  auto fidentity = [](std::vector<DataType> types) {
+  auto fidentity = [&](std::vector<DataType> types) {
     Array<PrimExpr> result;
     result.push_back(tvm::tir::make_const(types[0], -1));  // idx
     result.push_back(tvm::min_value(types[1]));            // val
@@ -490,12 +542,13 @@ inline FCommReduce MakeArgmaxReducer() {
  * left in the result as dimensions with size one. This enables the result
  * to broadcast correctly against the input array.
  * \param atleast1d Whether the output need to be atleast1d.
- *
+ * \param select_last_index Whether to select the last index if the maximum element
+ * appears multiple times, else select the first index.
  * \return A Tensor whose op member is the argmax operation
  */
 inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
-                     bool atleast1d = false) {
-  auto reducer = MakeArgmaxReducer();
+                     bool atleast1d = false, bool select_last_index = false) {
+  auto reducer = MakeArgmaxReducer(select_last_index);
   return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
 }
 
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 9144d3e145c8..f9b49204b85e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -32,23 +32,23 @@
 from .. import loops as _loops
 from .. import op as _op
 from .. import qnn as _qnn
+from .. import random as _random
 from .. import ty as _ty
 from .. import vision as _vision
-from .. import random as _random
 from .common import (
     AttrCvt,
     Renamer,
     fold_constant,
     get_name,
     get_relay_op,
+    gru_cell,
     infer_channels,
     infer_shape,
     infer_type,
     infer_value,
+    lstm_cell,
     new_var,
     unbind,
-    gru_cell,
-    lstm_cell,
 )
 
 __all__ = ["from_onnx"]
@@ -1786,12 +1786,11 @@ class ArgMax(OnnxOpConverter):
     """Operator converter for ArgMax."""
 
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        if "select_last_index" in attr:
-            raise NotImplementedError("select_last_index not supported in ArgMax")
+    def _impl_v13(cls, inputs, attr, params):
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
-        attr = {"axis": axis, "keepdims": keepdims}
+        select_last_index = attr.get("select_last_index", False)
+        attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
         return _op.cast(AttrCvt("argmax")(inputs, attr), "int64")
 
 
@@ -1799,12 +1798,11 @@ class ArgMin(OnnxOpConverter):
     """Operator converter for ArgMin."""
 
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        if "select_last_index" in attr:
-            raise NotImplementedError("select_last_index not supported in ArgMin")
+    def _impl_v13(cls, inputs, attr, params):
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
-        attr = {"axis": axis, "keepdims": keepdims}
+        select_last_index = attr.get("select_last_index", False)
+        attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
         return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")
 
 
diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py
index 368ffb5ab0ca..23accebfd0ec 100644
--- a/python/tvm/relay/op/reduce.py
+++ b/python/tvm/relay/op/reduce.py
@@ -17,13 +17,13 @@
 """Reduce operators."""
 # pylint: disable=redefined-builtin
 
+from ..expr import Tuple, TupleWrapper
 from . import _make
-from .tensor import sqrt, log, exp
+from .tensor import exp, log, sqrt
 from .transform import squeeze
-from ..expr import Tuple, TupleWrapper
 
 
-def argmax(data, axis=None, keepdims=False, exclude=False):
+def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
     """Returns the indices of the maximum values along an axis.
 
     Parameters
@@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    select_last_index : bool
+        Whether to select the last index or the first index if the max element appears in
+        multiple indices, default is False (first index).
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
-    return _make.argmax(data, axis, keepdims, exclude)
+    return _make.argmax(data, axis, keepdims, exclude, select_last_index)
 
 
-def argmin(data, axis=None, keepdims=False, exclude=False):
+def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
     """Returns the indices of the minimum values along an axis.
 
     Parameters
@@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    select_last_index : bool
+        Whether to select the last index or the first index if the min element appears in
+        multiple indices, default is False (first index).
+
     Returns
     -------
     result : relay.Expr
         The computed result.
     """
     axis = [axis] if isinstance(axis, int) else axis
-    return _make.argmin(data, axis, keepdims, exclude)
+    return _make.argmin(data, axis, keepdims, exclude, select_last_index)
 
 
 def sum(data, axis=None, keepdims=False, exclude=False):
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 77f9ad447ed1..45d07af577a3 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False):
     return cpp.min(data, axis, keepdims)
 
 
-def argmax(data, axis=None, keepdims=False):
+def argmax(data, axis=None, keepdims=False, select_last_index=False):
     """Returns the indices of the maximum values along an axis.
 
     Parameters
@@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False):
         with size one.
         With this option, the result will broadcast correctly against the input array.
 
+    select_last_index: bool
+        Whether to select the last index if the maximum element appears multiple times, else
+        select the first index.
+
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    return cpp.argmax(data, axis, keepdims)
+    return cpp.argmax(data, axis, keepdims, select_last_index)
 
 
-def argmin(data, axis=None, keepdims=False):
+def argmin(data, axis=None, keepdims=False, select_last_index=False):
     """Returns the indices of the minimum values along an axis.
 
     Parameters
@@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False):
         with size one.
         With this option, the result will broadcast correctly against the input array.
 
+    select_last_index: bool
+        Whether to select the last index if the minimum element appears multiple times, else
+        select the first index.
+
     Returns
     -------
     ret : tvm.te.Tensor
     """
-    return cpp.argmin(data, axis, keepdims)
+    return cpp.argmin(data, axis, keepdims, select_last_index)
 
 
 def prod(data, axis=None, keepdims=False):
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index f08af1e7e4ad..693589fecfb4 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -38,6 +38,7 @@ namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(ReduceAttrs);
+TVM_REGISTER_NODE_TYPE(ArgReduceAttrs);
 TVM_REGISTER_NODE_TYPE(VarianceAttrs);
 
 /*!
@@ -207,9 +208,29 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inp
       return {topi::identity(inputs[0])};
     }
   }
+
   return {f(inputs[0], axes, param->keepdims, false)};
 }
 
+template <typename F>
+Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                   const Type& out_type, F f) {
+  const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>();
+  ICHECK(param != nullptr);
+  if (inputs[0]->shape.size() == 0) {
+    return {topi::identity(inputs[0])};
+  }
+  auto axes = param->axis;
+  if (param->exclude) {
+    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
+    if (axes.size() == 0) {
+      return {topi::identity(inputs[0])};
+    }
+  }
+
+  return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)};
+}
+
 /*!
  * \brief ReduceShapeImpl get the outshape for the reduction operator
  * \param in_shape Shape of input data.
@@ -269,22 +290,16 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_s
   }
 }
 
-/*!
- * \brief ArgReduceRel Output type and shape relation evaluation function.
- * \param num_inputs Number of input types in the args.
- * \param attrs The additional attributes of the operator.
- * \param reporter The reporter to report solution to.
- * \return false if This relation cannot be resolved. true if this relation has been resolved.
- */
-bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                  const TypeReporter& reporter) {
+template <typename T>
+bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
   ICHECK(static_cast<int>(data->shape.size()) != 0);
   std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
 
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const T* param = attrs.as<T>();
   ICHECK(param != nullptr);
 
   // assign output type and shape
@@ -292,6 +307,17 @@ bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
   return true;
 }
+/*!
+ * \brief ArgReduceRel Output type and shape relation evaluation function.
+ * \param num_inputs Number of input types in the args.
+ * \param attrs The additional attributes of the operator.
+ * \param reporter The reporter to report solution to.
+ * \return false if This relation cannot be resolved. true if this relation has been resolved.
+ */
+bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  return GenericReduceRel<ArgReduceAttrs>(types, num_inputs, attrs, reporter);
+}
 
 /*!
  * \brief ReduceRel Output type and shape relation evaluation function.
@@ -324,6 +350,16 @@ Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, Str
   return Call(Op::Get(op_name), {data}, Attrs(attrs), {});
 }
 
+Expr MakeOneElementReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude,
+                          bool select_last_index, String op_name) {
+  auto attrs = make_object<ArgReduceAttrs>();
+  attrs->axis = std::move(axis);
+  attrs->keepdims = keepdims;
+  attrs->exclude = exclude;
+  attrs->select_last_index = select_last_index;
+  return Call(Op::Get(op_name), {data}, Attrs(attrs), {});
+}
+
 #define RELAY_REGISTER_REDUCE_OP(OpName)                                                \
   TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                         \
       .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude) { \
@@ -331,35 +367,43 @@ Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, Str
       });                                                                               \
   RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
 
+#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName)                                           \
+  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                                \
+      .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude,          \
+                         bool select_last_index) {                                             \
+        return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index, OpName); \
+      });                                                                                      \
+  RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
+
 Array<te::Tensor> ArgMaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
-  return ReduceCompute(attrs, inputs, out_type, topi::argmax);
+  return ArgReduceCompute(attrs, inputs, out_type, topi::argmax);
 }
 
-RELAY_REGISTER_REDUCE_OP("argmax")
+RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax")
     .describe(R"code(Creates an operation that finds the indices of the maximum
 values over a given axis.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type<ReduceAttrs>()
+    .set_attrs_type<ArgReduceAttrs>()
    .set_support_level(4)
-    .add_type_rel("ArgReduce", ArgReduceRel)
+    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
-  return ReduceCompute(attrs, inputs, out_type, topi::argmin);
+  return ArgReduceCompute(attrs, inputs, out_type, topi::argmin);
 }
 
-RELAY_REGISTER_REDUCE_OP("argmin")
+RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin")
     .describe(R"code(Creates an operation that finds the indices of the minimum
 values over a given axis.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type<ReduceAttrs>()
+    .set_attrs_type<ArgReduceAttrs>()
     .set_support_level(4)
-    .add_type_rel("ArgReduce", ArgReduceRel)
+    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc
index 55c59162e68c..3d1c6f9f7d5b 100644
--- a/src/topi/reduction.cc
+++ b/src/topi/reduction.cc
@@ -45,11 +45,11 @@ TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) {
 });
 
 TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]);
+  *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2], false, args[3]);
 });
 
 TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
+  *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2], false, args[3]);
 });
 
 TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) {
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 9e0eb1f75217..a1d821686ed5 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -17,7 +17,6 @@
 import glob
 import os
 import re
-import glob
 
 import numpy as np
 import pytest
@@ -236,7 +235,7 @@ def verify_with_ort(
 
 
 def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev):
-    from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
+    from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static
 
     input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes]
@@ -4680,22 +4679,6 @@ def verify_eyelike(indata):
     "test_adagrad_multiple",
     "test_adam",
     "test_adam_multiple",
-    "test_argmax_default_axis_example_select_last_index",
-    "test_argmax_default_axis_random_select_last_index",
-    "test_argmax_keepdims_example_select_last_index",
-    "test_argmax_keepdims_random_select_last_index",
-    "test_argmax_negative_axis_keepdims_example_select_last_index",
-    "test_argmax_negative_axis_keepdims_random_select_last_index",
-    "test_argmax_no_keepdims_example_select_last_index",
-    "test_argmax_no_keepdims_random_select_last_index",
-    "test_argmin_default_axis_example_select_last_index",
-    "test_argmin_default_axis_random_select_last_index",
-    "test_argmin_keepdims_example_select_last_index",
-    "test_argmin_keepdims_random_select_last_index",
-    "test_argmin_negative_axis_keepdims_example_select_last_index",
-    "test_argmin_negative_axis_keepdims_random_select_last_index",
-    "test_argmin_no_keepdims_example_select_last_index",
-    "test_argmin_no_keepdims_random_select_last_index",
     "test_cast_BFLOAT16_to_FLOAT",
     "test_cast_DOUBLE_to_FLOAT16",
     "test_cast_FLOAT_to_BFLOAT16",
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index df77c33658de..6415976bfd59 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -14,14 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
-from tvm import te
 import numpy as np
-from tvm import relay
+import numpy.random
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relay, te
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-import tvm.topi.testing
-import tvm.testing
 
 
 @tvm.testing.uses_gpu
@@ -342,6 +342,34 @@ def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
     verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
 
 
+@tvm.testing.uses_gpu
+def test_argmin_argmax_get_last_elements():
+    def get_test_case(shape, gt_func, test_argmin=False):
+        total_ele = np.product(shape)
+        arr = np.zeros(total_ele)
+        target_value = -1 if test_argmin else 1
+        arr[: total_ele // 3] = target_value
+        np.random.shuffle(arr)
+        arr = arr.reshape(shape)
+        ans = gt_func(np.flip(arr))
+        return arr, len(arr) - ans - 1
+
+    funcs_and_gt_funcs = [(relay.argmax, np.argmax), (relay.argmin, np.argmin)]
+    lengths = [5, 10, 15]
+    for func, gt_func in funcs_and_gt_funcs:
+        for shape in lengths:
+            x_in = relay.var("x_in", shape=[shape])
+            output = func(x_in, select_last_index=True)
+            arr, ans = get_test_case(shape, gt_func, test_argmin=func == relay.argmin)
+
+            mod = tvm.IRModule.from_expr(output)
+            for target, dev in tvm.testing.enabled_targets():
+                op_res = relay.create_executor(
+                    "graph", mod=mod, device=dev, target=target
+                ).evaluate()(arr)
+                assert op_res.numpy().item() == ans
+
+
 def verify_mean_var_std(funcs, shape, axis, keepdims):
     test_func = funcs[0]
     ref_func = funcs[1]
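
---

End to end, the patch threads one flag through the stack: TOPI's reducers gain a `select_last_index` switch, Relay's `argmax`/`argmin` carry it in the new `ArgReduceAttrs`, and the ONNX frontend maps the opset-13 `select_last_index` attribute onto it. Below is a quick illustration of the new behavior at the Relay Python level — not part of the patch, just a minimal sketch assuming a TVM build that includes this change; the `llvm` target and the executor invocation mirror the new test in `test_op_level4.py`:

```python
import numpy as np
import tvm
from tvm import relay

# The maximum value 3.0 is tied across indices 1, 3, and 5.
data = np.array([0.0, 3.0, 1.0, 3.0, 2.0, 3.0], dtype="float32")

x = relay.var("x", shape=(6,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.argmax(x, select_last_index=True))

# select_last_index=True resolves ties to the last occurrence (index 5),
# matching ONNX select_last_index=1; the default False keeps the old
# first-occurrence behavior (index 1).
result = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(data)
assert result.numpy().item() == 5
```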