[ONNX] Add CumSum operator to ONNX frontend (apache#7391)
* [ONNX] Add CumSum operator to ONNX frontend

* Fix lint and add attributes to CumSum

* Fix CumSum test

* Add support for the exclusive attribute

* Add support for the reverse attribute

* Fix clang-format

* Fix lint

* Move reverse calculation to ONNX frontend and add exclusive to GPU

* Add test for int type
echuraev authored and trevor-m committed Mar 2, 2021
1 parent f36063d commit c73dac3
Showing 9 changed files with 144 additions and 10 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -450,9 +450,13 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
Integer axis;
DataType dtype;
Integer exclusive;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());
TVM_ATTR_FIELD(exclusive)
.describe("The first element is not included")
.set_default(NullValue<Integer>());
}
};

25 changes: 24 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -34,7 +34,7 @@
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value
from .common import infer_type, get_name


@@ -1075,6 +1075,28 @@ def _impl_v1(cls, inputs, attr, params):
return _op.shape_of(inputs[0], "int64")


class CumSum(OnnxOpConverter):
"""Operator converter for CumSum."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
dim = inputs[1]

if dim is not None:
dim = int(infer_value(dim, params).asnumpy())

exclusive = attr.get("exclusive", 0)
reverse = attr.get("reverse", 0)

if reverse != 0:
out = _op.reverse(data, axis=dim)
out = _op.cumsum(out, axis=dim, exclusive=exclusive)
return _op.reverse(out, axis=dim)

return _op.cumsum(data, axis=dim, exclusive=exclusive)


class Cast(OnnxOpConverter):
"""Operator converter for Cast."""

@@ -2736,6 +2758,7 @@ def _get_convert_map(opset):
"Resize": Resize.get_converter(opset),
"NonZero": NonZero.get_converter(opset),
"Range": Range.get_converter(opset),
"CumSum": CumSum.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
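Note on the converter above: the Relay cumsum op introduced in this change carries no reverse flag, so the ONNX reverse attribute is emulated by flipping the input along the axis, taking the cumulative sum, and flipping the result back. A minimal NumPy sketch of that equivalence (illustrative only, not TVM code):

import numpy as np

def reverse_cumsum(data, axis=0):
    # flip -> inclusive cumsum -> flip back, mirroring the Relay expression emitted above
    flipped = np.flip(data, axis=axis)
    return np.flip(np.cumsum(flipped, axis=axis), axis=axis)

x = np.array([1.0, 2.0, 3.0, 4.0])
print(np.cumsum(x))       # [ 1.  3.  6. 10.]  forward, inclusive
print(reverse_cumsum(x))  # [10.  9.  7.  4.]  ONNX reverse=1 behaviour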
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -125,7 +125,7 @@ def compute_scatter_nd(attrs, inputs, output_type):
@_reg.register_compute("cumsum")
def compute_cumsum(attrs, inputs, output_type):
"""Compute definition of cumsum"""
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype)]
return [topi.cumsum(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumsum", strategy.cumsum_strategy)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
@@ -1375,7 +1375,7 @@ def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""

def _compute_cumsum(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype)]
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

return _compute_cumsum

10 changes: 8 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -1361,7 +1361,7 @@ def adv_index(inputs):
return _make.adv_index(Tuple(inputs))


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
@@ -1378,6 +1378,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
exclusive : int, optional
If set to 1, returns the exclusive sum, in which the first element is not
included. In other words, if set to 1 the j-th output element is the sum of
the first (j-1) elements; otherwise it is the sum of the first j elements.
Returns
-------
result : relay.Expr
@@ -1407,4 +1413,4 @@ def cumsum(data, axis=None, dtype=None):
cumsum(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return _make.cumsum(data, axis, dtype)
return _make.cumsum(data, axis, dtype, exclusive)
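To make the exclusive semantics in the docstring concrete, a small worked example (illustrative values only, not from the test suite): the exclusive result is the inclusive result shifted right by one position, with a zero in front.

import numpy as np

a = np.array([1, 2, 3, 4])
inclusive = np.cumsum(a)                           # [ 1  3  6 10]
exclusive = np.concatenate(([0], inclusive[:-1]))  # [ 0  1  3  6]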
10 changes: 9 additions & 1 deletion python/tvm/topi/cuda/scan.py
@@ -488,7 +488,7 @@ def traverse(op):
return s


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
Parameters
@@ -504,6 +504,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
exclusive : int, optional
If set to 1, returns the exclusive sum, in which the first element is not
included. In other words, if set to 1 the j-th output element is the sum of
the first (j-1) elements; otherwise it is the sum of the first j elements.
Returns
-------
result : tvm.te.Tensor
@@ -514,4 +520,6 @@ def cumsum(data, axis=None, dtype=None):
axis = 0
data = reshape(data, (prod(data.shape),))
axis = get_const_int(axis)
if exclusive is not None and exclusive != 0:
return exclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
return inclusive_scan(data, axis, output_dtype=dtype, binop=tvm.tir.generic.add)
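The CUDA path above routes exclusive mode to a dedicated scan kernel rather than post-processing the inclusive result. For reference, the two flavours are related elementwise (plain NumPy, not the CUDA kernels themselves):

import numpy as np

x = np.array([1, 2, 3, 4], dtype="float32")
inclusive = np.cumsum(x)   # [ 1.  3.  6. 10.]
exclusive = inclusive - x  # [ 0.  1.  3.  6.]  what an exclusive scan is expected to yield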
21 changes: 18 additions & 3 deletions python/tvm/topi/cumsum.py
@@ -22,7 +22,7 @@
from .math import cast


def cumsum(data, axis=None, dtype=None):
def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative sum of the elements along a given axis.
Parameters
@@ -38,6 +38,12 @@ def cumsum(data, axis=None, dtype=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
exclusive : int, optional
If set to 1, returns the exclusive sum, in which the first element is not
included. In other words, if set to 1 the j-th output element is the sum of
the first (j-1) elements; otherwise it is the sum of the first j elements.
Returns
-------
result : tvm.te.Tensor
@@ -75,6 +81,9 @@ def maybe_cast(x):
elif i > axis:
axis_mul_after *= value

if exclusive is None:
exclusive = 0

def gen_ir(data_buf, out_buf):
ib = ir_builder.create()
data_buf = ib.buffer_ptr(data_buf)
@@ -84,12 +93,18 @@ def gen_ir(data_buf, out_buf):
i = fused // axis_mul_after
j = fused % axis_mul_after
base_idx = i * cumsum_axis_len * axis_mul_after + j
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
if exclusive == 0:
out_buf[base_idx] = maybe_cast(data_buf[base_idx])
else:
out_buf[base_idx] = cast(0, dtype)
with ib.for_range(0, cumsum_axis_len - 1, "_k") as _k:
k = _k + 1
cur_idx = base_idx + k * axis_mul_after
prev_idx = base_idx + (k - 1) * axis_mul_after
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
if exclusive == 0:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[cur_idx])
else:
out_buf[cur_idx] = out_buf[prev_idx] + maybe_cast(data_buf[prev_idx])

return ib.get()
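The IR above implements exclusive mode by seeding the running total with zero and accumulating the element at prev_idx instead of cur_idx. A pure-Python rendering of the same loop for the 1-D case (a sketch, not the generated TIR):

import numpy as np

def cumsum_ref(data, exclusive=0):
    out = np.empty_like(data)
    out[0] = data[0] if exclusive == 0 else 0
    for k in range(1, len(data)):
        # exclusive mode accumulates the previous input element instead of the current one
        out[k] = out[k - 1] + (data[k] if exclusive == 0 else data[k - 1])
    return out

x = np.array([1, 2, 3, 4])
print(cumsum_ref(x))               # [ 1  3  6 10]
print(cumsum_ref(x, exclusive=1))  # [0 1 3 6]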

3 changes: 2 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -3757,10 +3757,11 @@ bool CumsumRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Expr MakeCumsum(Expr data, Integer axis, DataType dtype) {
Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Integer exclusive) {
auto attrs = make_object<CumsumAttrs>();
attrs->dtype = dtype;
attrs->axis = axis;
attrs->exclusive = exclusive;
static const Op& op = Op::Get("cumsum");
return Call(op, {data}, Attrs(attrs), {});
}
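With the attribute plumbed through MakeCumsum, the flag can be exercised from Python. A hedged sketch (the executor helper names follow TVM releases of this era and may differ slightly):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.cumsum(x, axis=0, exclusive=1)))
out = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(
    np.array([1, 2, 3, 4], dtype="int32")
)
print(out)  # expected: [0 1 3 6]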
77 changes: 77 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -3964,6 +3964,82 @@ def verify_softplus(indata):
verify_softplus(input_data)


def test_cumsum():
def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
cumsum_node = onnx.helper.make_node(
"CumSum",
inputs=["X", "axis"],
outputs=["Y"],
)
if exclusive != 0:
exclusive_attr = helper.make_attribute("exclusive", exclusive)
cumsum_node.attribute.append(exclusive_attr)
if reverse != 0:
reverse_attr = helper.make_attribute("reverse", reverse)
cumsum_node.attribute.append(reverse_attr)
nodes = [
make_constant_node("axis", onnx.TensorProto.INT32, [1], [axis]),
cumsum_node,
]
if type == "float32":
tensor_type = TensorProto.FLOAT
else:
tensor_type = TensorProto.INT32
type = "int32"

graph = helper.make_graph(
nodes,
"cumsum_test",
inputs=[
helper.make_tensor_value_info("X", tensor_type, list(indata.shape)),
],
outputs=[helper.make_tensor_value_info("Y", tensor_type, list(indata.shape))],
)

model = helper.make_model(graph, producer_name="cumsum_test")

verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11)

data = (
np.array(
[
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
11.0,
12.0,
]
)
.astype(np.float32)
.reshape((3, 4))
)

verify_cumsum(data, 0)
verify_cumsum(data, 1)
verify_cumsum(data, 0, 1, 0)
verify_cumsum(data, 1, 1, 0)
verify_cumsum(data, 0, 0, 1)
verify_cumsum(data, 1, 0, 1)
verify_cumsum(data, 1, 1, 1)
data = np.random.randn(1, 32, 32, 3).astype("float32")
verify_cumsum(data, 1)
data = np.random.randn(1, 32, 32, 3).astype("int32")
verify_cumsum(data, 0, type="int32")
verify_cumsum(data, 1, type="int32")
verify_cumsum(data, 0, 1, 0, type="int32")
verify_cumsum(data, 1, 1, 0, type="int32")
verify_cumsum(data, 0, 0, 1, type="int32")
verify_cumsum(data, 1, 0, 1, type="int32")
verify_cumsum(data, 1, 1, 1, type="int32")


if __name__ == "__main__":
test_flatten()
test_reshape()
@@ -4040,3 +4116,4 @@ def verify_softplus(indata):
test_size()
test_maxunpool()
test_softplus()
test_cumsum()
