Skip to content

Commit

Permalink
merge cumsum and cumprod to scan, merge tests
Browse files Browse the repository at this point in the history
fix stuff
  • Loading branch information
Andrew Zhao Luo authored and Andrew Zhao Luo committed Mar 24, 2021
1 parent 4ead969 commit 23d4325
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 322 deletions.
12 changes: 6 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,21 +1464,21 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_cumbinop(topi_compute):
"""Wrap cumbinop style topi compute"""
def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

def _compute_cumbinop(attrs, inputs, _):
def _compute_scanop(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

return _compute_cumbinop
return _compute_scanop


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumbinop(topi.cumsum),
wrap_compute_scanop(topi.cumsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumsum.generic",
)
Expand All @@ -1490,7 +1490,7 @@ def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumbinop(topi.cumprod),
wrap_compute_scanop(topi.cumprod),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumprod.generic",
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
from .sparse_reshape import *
from .scatter_add import *
from .argwhere import *
from .cumsum import *
from .cumprod import *
from .scan import *
from .einsum import *
from .unique import *
from . import generic
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,15 +521,15 @@ def traverse(op):
return s


def cumbinop(
def scanop(
data: tvm.te.Tensor,
binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"],
identity_value: Union[float, int],
axis: Optional[int] = None,
dtype: Optional[str] = None,
exclusive: Optional[bool] = None,
) -> tvm.te.Tensor:
"""Cumulative binary operator with similar axis behavior as np.cumsum and np.cumprod.
"""Cumulative binary operator (scan) with similar axis behavior as np.cumsum and np.cumprod.
See cumprod and cumsum for an example of use.
Expand Down Expand Up @@ -616,7 +616,7 @@ def cumsum(
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
"""
return cumbinop(
return scanop(
data=data,
binop=tvm.tir.generic.add,
identity_value=0,
Expand Down Expand Up @@ -659,7 +659,7 @@ def cumprod(
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
"""
return cumbinop(
return scanop(
data=data,
binop=tvm.tir.generic.multiply,
identity_value=1,
Expand Down
68 changes: 0 additions & 68 deletions python/tvm/topi/cumprod.py

This file was deleted.

52 changes: 48 additions & 4 deletions python/tvm/topi/cumsum.py → python/tvm/topi/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Cumsum operator"""
"""Scan (cumulative binary) operators"""
from typing import Callable, Optional

import tvm
Expand All @@ -26,7 +26,7 @@
from .utils import get_const_int, prod


def cumbinop(
def scanop(
data: tvm.te.Tensor,
binop: Callable[["tvm.Expr", "tvm.Expr"], "tvm.Expr"],
identity_value: "tvm.Expr",
Expand All @@ -35,7 +35,7 @@ def cumbinop(
dtype: Optional[str] = None,
exclusive: Optional[bool] = None,
) -> tvm.te.Tensor:
"""Cumulative binary operator with similar axis behavior as np.cumsum and np.cumprod.
"""Cumulative binary operator (scan) with similar axis behavior as np.cumsum and np.cumprod.
See cumprod and cumsum for an example of use.
Expand Down Expand Up @@ -181,7 +181,7 @@ def cumsum(
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
"""
return cumbinop(
return scanop(
data=data,
binop=generic.add,
identity_value=0,
Expand All @@ -190,3 +190,47 @@ def cumsum(
dtype=dtype,
exclusive=exclusive,
)


def cumprod(
data: tvm.te.Tensor,
axis: Optional[int] = None,
dtype: Optional[int] = None,
exclusive: Optional[bool] = None,
) -> tvm.te.Tensor:
"""Numpy style cumprod op. Return the cumulative product of the elements along a given axis.
Parameters
----------
data : tvm.te.Tensor
The input data to the operator.
axis : int, optional
Axis along which the cumulative product is computed. The default (None) is to compute
the cumproduct over the flattened array.
dtype : string, optional
Type of the returned array and of the accumulator in which the elements are multiplied.
If dtype is not specified, it defaults to the dtype of data.
exclusive : bool, optional
If True, will return exclusive product in which the first element is not
included. In other terms, if True, the j-th output element would be
the product of the first (j-1) elements. Otherwise, it would be the product of
the first j elements.
Returns
-------
result : tvm.te.Tensor
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
"""
return scanop(
data=data,
binop=generic.multiply,
identity_value=1,
op_name="cumprod_generic",
axis=axis,
dtype=dtype,
exclusive=exclusive,
)
2 changes: 1 addition & 1 deletion python/tvm/topi/unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Unique operator"""
from tvm import te, tir
from ..te import hybrid
from .cumsum import cumsum
from .scan import cumsum
from .sort import sort, argsort


Expand Down
40 changes: 19 additions & 21 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,13 +1759,13 @@ def verify_adv_index(data_shape, index_shapes):


# Helper for testing binop functions
cumbinops_supported = {"cumsum": relay.op.cumsum, "cumprod": relay.op.cumprod}
scanops_supported = {"cumsum": relay.op.cumsum, "cumprod": relay.op.cumprod}


def run_binop_tests(
target, ctx, binop_type: str, gt_func: Callable[..., np.array], identity_value: int
):
def assert_relay_cumbinop(
def assert_relay_scanop(
data_np: np.array,
np_out: np.array,
axis: int = None,
Expand All @@ -1776,11 +1776,9 @@ def assert_relay_cumbinop(
):
inp = relay.var("data", relay.TensorType(data_np.shape, str(data_np.dtype)))

if binop_type not in cumbinops_supported.keys():
raise ValueError(
f"Unknown function {binop_type}. Options: {cumbinops_supported.keys()}"
)
out = cumbinops_supported[binop_type](inp, axis, out_dtype, exclusive=exclusive)
if binop_type not in scanops_supported.keys():
raise ValueError(f"Unknown function {binop_type}. Options: {scanops_supported.keys()}")
out = scanops_supported[binop_type](inp, axis, out_dtype, exclusive=exclusive)
func = relay.Function([inp], out)

for kind in ["graph", "debug"]:
Expand All @@ -1789,38 +1787,38 @@ def assert_relay_cumbinop(
tvm.testing.assert_allclose(op_res.asnumpy(), np_out, rtol=rtol, atol=atol)

data = np.array([2, 3, 0])
assert_relay_cumbinop(data, gt_func(data))
assert_relay_cumbinop(data, gt_func(data), out_dtype="int64")
assert_relay_scanop(data, gt_func(data))
assert_relay_scanop(data, gt_func(data), out_dtype="int64")

data = np.random.randn(10, 10)
assert_relay_cumbinop(data, gt_func(data))
assert_relay_cumbinop(data, gt_func(data, axis=0), axis=0)
assert_relay_cumbinop(data, gt_func(data, axis=1), axis=1)
assert_relay_scanop(data, gt_func(data))
assert_relay_scanop(data, gt_func(data, axis=0), axis=0)
assert_relay_scanop(data, gt_func(data, axis=1), axis=1)

data = np.random.randn(10, 5, 10).astype("float32")
assert_relay_cumbinop(data, gt_func(data), rtol=1e-4, atol=1e-4)
assert_relay_cumbinop(data, gt_func(data, axis=0), axis=0, rtol=1e-4, atol=1e-4)
assert_relay_cumbinop(data, gt_func(data, axis=1), axis=1, rtol=1e-4, atol=1e-4)
assert_relay_cumbinop(data, gt_func(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4)
assert_relay_scanop(data, gt_func(data), rtol=1e-4, atol=1e-4)
assert_relay_scanop(data, gt_func(data, axis=0), axis=0, rtol=1e-4, atol=1e-4)
assert_relay_scanop(data, gt_func(data, axis=1), axis=1, rtol=1e-4, atol=1e-4)
assert_relay_scanop(data, gt_func(data, axis=-1), axis=-1, rtol=1e-4, atol=1e-4)

data = np.random.rand(10) > 0.5
data = data.astype(np.int32)
assert_relay_cumbinop(data, gt_func(data, dtype=np.int32))
assert_relay_cumbinop(data, gt_func(data, dtype="int64"), out_dtype="int64")
assert_relay_scanop(data, gt_func(data, dtype=np.int32))
assert_relay_scanop(data, gt_func(data, dtype="int64"), out_dtype="int64")

# Test exclusivity operations
data = np.random.randint(-100, 100, size=(10, 10)).astype("int64")
expected_result = np.roll(gt_func(data), 1)
expected_result[0] = identity_value
assert_relay_cumbinop(data, expected_result, exclusive=True)
assert_relay_scanop(data, expected_result, exclusive=True)

expected_result = np.roll(gt_func(data, axis=0), 1, axis=0)
expected_result[0, :] = identity_value
assert_relay_cumbinop(data, expected_result, exclusive=True, axis=0)
assert_relay_scanop(data, expected_result, exclusive=True, axis=0)

expected_result = np.roll(gt_func(data, axis=1), 1, axis=1)
expected_result[:, 0] = identity_value
assert_relay_cumbinop(data, expected_result, exclusive=True, axis=1)
assert_relay_scanop(data, expected_result, exclusive=True, axis=1)


@tvm.testing.parametrize_targets
Expand Down
Loading

0 comments on commit 23d4325

Please sign in to comment.