Skip to content

Commit

Permalink
[ONNX][TOPI] Support select_last_index for argmin/max (apache#8816)
Browse files Browse the repository at this point in the history
* support select_last_index for argmin/max

* reverse conditions which made on accident

* forward args in reduce.py

* make proper nodes for reduction ops

* remove complicated nested lambdas

* fix lambda capture for conversion

* forward more arguments

* forward more args

* enable onnx tests

* wrapping casts to remove ambiguity

* revert changes extraneous

* correct incorrect attrs being used for ops

* change attributes

* remove old impl

* register new attribute node

* clean up test

* reformat

* reformat

* coolio

* stable comparison

* casts to avoid ambiguity

* casting more

* correct arg passing

* support select_last_index for argmin/max

* reverse conditions which made on accident

* forward args in reduce.py

* make proper nodes for reduction ops

* remove complicated nested lambdas

* fix lambda capture for conversion

* forward more arguments

* forward more args

* enable onnx tests

* wrapping casts to remove ambiguity

* revert changes extraneous

* correct incorrect attrs being used for ops

* change attributes

* remove old impl

* register new attribute node

* clean up test

* reformat

* reformat

* coolio

* stable comparison

* casts to avoid ambiguity

* casting more

* correct arg passing

* fix broken input

* OneElementReduceAttrs-->ArgReduceAttrs"

* reduce boilerplate

* change names

* remove log statement

* jostle ci

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
2 people authored and ylc committed Sep 29, 2021
1 parent e57e0e0 commit bb27c0f
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 87 deletions.
36 changes: 36 additions & 0 deletions include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,42 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
}
};

/*! \brief Attributes for Reduce operators which reduce by finding a single element. E.g. argmin */
struct ArgReduceAttrs : public tvm::AttrsNode<ArgReduceAttrs> {
Array<Integer> axis;
bool keepdims;
bool select_last_index;
bool exclude;

TVM_DECLARE_ATTRS(ArgReduceAttrs, "relay.attrs.ArgReduceAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");

TVM_ATTR_FIELD(keepdims).set_default(false).describe(
"If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(select_last_index)
.set_default(false)
.describe(
"Whether to select the last index if the target element appears multiple times, else "
"select the first index which the target element appears");
TVM_ATTR_FIELD(exclude).set_default(false).describe(
"Whether to perform reduction on axis that are NOT in axis instead.");
}
};

struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
Array<Integer> axis;
bool keepdims;
Expand Down
99 changes: 76 additions & 23 deletions include/tvm/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,45 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
return CommReduce(data, axis, MaxOp, keepdims, atleast1d);
}

inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
// Create a Commutative Reducer with a comparison operation, and method to get the initial value.
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;

// Casting to avoid operator ambiguity
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

// These variables compare the actual values of the array
auto is_smaller = lhs_val < rhs_val;
auto is_same = lhs_val == rhs_val;

// This checks if the indices are correct for the reduction. E.g. for select_last_index
// it gives precedence for later indices of the same element and precedence for sooner
// indices if not select_last_index;
PrimExpr proper_index;
if (select_last_index) {
proper_index = lhs_idx > rhs_idx;
} else {
proper_index = lhs_idx < rhs_idx;
}

PrimExpr update_index = is_smaller || (is_same && proper_index);
result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [&](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::max_value(types[1])); // val
return result;
};
return MakeCommReducer(fcombine, fidentity, "argmin");
}

/*!
* \brief Creates an operation that finds the indices of the minimum
* values over a given axis.
Expand All @@ -442,35 +481,48 @@ inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
* \param select_last_index Whether to select the last index if the minimum element
* appears multiple times, else select the first index.
*
* \return A Tensor whose op member is the argmin operation
*/
inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::max_value(types[1])); // val
return result;
};
auto func = MakeCommReducer(fcombine, fidentity, "argmin");
return CommReduceIdx(data, axis, func, keepdims, atleast1d);
bool atleast1d = false, bool select_last_index = false) {
auto reducer = MakeArgminReducer(select_last_index);
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
// Create a Commutative Reducer with a comparison operation, and method to get the initial value.
auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val

// Casting to avoid operator ambiguity
PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

// These variables compare the actual values of the array
auto is_bigger = lhs_val > rhs_val;
auto is_same = lhs_val == rhs_val;

// This checks if the indices are correct for the reduction. E.g. for select_last_index
// it gives precedence for later indices of the same element and precedence for sooner
// indices if not select_last_index;
PrimExpr proper_index;
if (select_last_index) {
proper_index = lhs_idx > rhs_idx;
} else {
proper_index = lhs_idx < rhs_idx;
}

PrimExpr update_index = is_bigger || (is_same && proper_index);
result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
auto fidentity = [&](std::vector<DataType> types) {
Array<PrimExpr> result;
result.push_back(tvm::tir::make_const(types[0], -1)); // idx
result.push_back(tvm::min_value(types[1])); // val
Expand All @@ -490,12 +542,13 @@ inline FCommReduce MakeArgmaxReducer() {
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
*
* \param select_last_index Whether to select the last index if the maximum element
* appears multiple times, else select the first index.
* \return A Tensor whose op member is the argmax operation
*/
inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false,
bool atleast1d = false) {
auto reducer = MakeArgmaxReducer();
bool atleast1d = false, bool select_last_index = false) {
auto reducer = MakeArgmaxReducer(select_last_index);
return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}

Expand Down
20 changes: 9 additions & 11 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@
from .. import loops as _loops
from .. import op as _op
from .. import qnn as _qnn
from .. import random as _random
from .. import ty as _ty
from .. import vision as _vision
from .. import random as _random
from .common import (
AttrCvt,
Renamer,
fold_constant,
get_name,
get_relay_op,
gru_cell,
infer_channels,
infer_shape,
infer_type,
infer_value,
lstm_cell,
new_var,
unbind,
gru_cell,
lstm_cell,
)

__all__ = ["from_onnx"]
Expand Down Expand Up @@ -1786,25 +1786,23 @@ class ArgMax(OnnxOpConverter):
"""Operator converter for ArgMax."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMax")
def _impl_v13(cls, inputs, attr, params):
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
select_last_index = attr.get("select_last_index", False)
attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
return _op.cast(AttrCvt("argmax")(inputs, attr), "int64")


class ArgMin(OnnxOpConverter):
"""Operator converter for ArgMin."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "select_last_index" in attr:
raise NotImplementedError("select_last_index not supported in ArgMin")
def _impl_v13(cls, inputs, attr, params):
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
select_last_index = attr.get("select_last_index", False)
attr = {"axis": axis, "keepdims": keepdims, "select_last_index": select_last_index}
return _op.cast(AttrCvt("argmin")(inputs, attr), "int64")


Expand Down
20 changes: 14 additions & 6 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"""Reduce operators."""
# pylint: disable=redefined-builtin

from ..expr import Tuple, TupleWrapper
from . import _make
from .tensor import sqrt, log, exp
from .tensor import exp, log, sqrt
from .transform import squeeze
from ..expr import Tuple, TupleWrapper


def argmax(data, axis=None, keepdims=False, exclude=False):
def argmax(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
"""Returns the indices of the maximum values along an axis.
Parameters
Expand All @@ -45,16 +45,20 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
select_last_index : bool
Whether to select the last index or the first index if the max element appears in
multiple indices, default is False (first index).
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmax(data, axis, keepdims, exclude)
return _make.argmax(data, axis, keepdims, exclude, select_last_index)


def argmin(data, axis=None, keepdims=False, exclude=False):
def argmin(data, axis=None, keepdims=False, exclude=False, select_last_index=False):
"""Returns the indices of the minimum values along an axis.
Parameters
Expand All @@ -76,13 +80,17 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
select_last_index : bool
Whether to select the last index or the first index if the min element appears in
multiple indices, default is False (first index).
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.argmin(data, axis, keepdims, exclude)
return _make.argmin(data, axis, keepdims, exclude, select_last_index)


def sum(data, axis=None, keepdims=False, exclude=False):
Expand Down
16 changes: 12 additions & 4 deletions python/tvm/topi/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def min(data, axis=None, keepdims=False):
return cpp.min(data, axis, keepdims)


def argmax(data, axis=None, keepdims=False):
def argmax(data, axis=None, keepdims=False, select_last_index=False):
"""Returns the indices of the maximum values along an axis.
Parameters
Expand All @@ -185,14 +185,18 @@ def argmax(data, axis=None, keepdims=False):
with size one.
With this option, the result will broadcast correctly against the input array.
select_last_index: bool
Whether to select the last index if the maximum element appears multiple times, else
select the first index.
Returns
-------
ret : tvm.te.Tensor
"""
return cpp.argmax(data, axis, keepdims)
return cpp.argmax(data, axis, keepdims, select_last_index)


def argmin(data, axis=None, keepdims=False):
def argmin(data, axis=None, keepdims=False, select_last_index=False):
"""Returns the indices of the minimum values along an axis.
Parameters
Expand All @@ -210,11 +214,15 @@ def argmin(data, axis=None, keepdims=False):
with size one.
With this option, the result will broadcast correctly against the input array.
select_last_index: bool
Whether to select the last index if the minimum element appears multiple times, else
select the first index.
Returns
-------
ret : tvm.te.Tensor
"""
return cpp.argmin(data, axis, keepdims)
return cpp.argmin(data, axis, keepdims, select_last_index)


def prod(data, axis=None, keepdims=False):
Expand Down
Loading

0 comments on commit bb27c0f

Please sign in to comment.