
Commit

fix stuff
Andrew Zhao Luo authored and Andrew Zhao Luo committed Mar 24, 2021
1 parent 5fae7f0 commit a37657d
Showing 3 changed files with 67 additions and 16 deletions.
44 changes: 44 additions & 0 deletions python/tvm/topi/scan.py
@@ -190,3 +190,47 @@ def cumsum(
        dtype=dtype,
        exclusive=exclusive,
    )


def cumprod(
    data: tvm.te.Tensor,
    axis: Optional[int] = None,
    dtype: Optional[str] = None,
    exclusive: Optional[bool] = None,
) -> tvm.te.Tensor:
    """Numpy style cumprod op. Return the cumulative product of the elements along a given axis.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input data to the operator.

    axis : int, optional
        Axis along which the cumulative product is computed. The default (None) is to compute
        the cumulative product over the flattened array.

    dtype : string, optional
        Type of the returned array and of the accumulator in which the elements are multiplied.
        If dtype is not specified, it defaults to the dtype of data.

    exclusive : bool, optional
        If True, will return an exclusive product in which the first element is not
        included. In other terms, if True, the j-th output element would be
        the product of the first (j-1) elements. Otherwise, it would be the product of
        the first j elements.

    Returns
    -------
    result : tvm.te.Tensor
        The result has the same size as data, and the same shape as data if axis is not None.
        If axis is None, the result is a 1-d array.
    """
    return scanop(
        data=data,
        binop=generic.multiply,
        identity_value=1,
        op_name="cumprod_generic",
        axis=axis,
        dtype=dtype,
        exclusive=exclusive,
    )
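The exclusive semantics documented above are what the updated tests further down exercise: they build the reference result by rolling NumPy's inclusive scan one position and seeding the front with the operator's identity value (1 for cumprod, 0 for cumsum). A minimal NumPy sketch of that construction, not part of this commit:

import numpy as np

a = np.array([2.0, 3.0, 4.0])

inclusive = np.cumprod(a)          # [2., 6., 24.]
exclusive = np.roll(inclusive, 1)  # shift everything one position to the right
exclusive[0] = 1                   # seed with the multiplicative identity
# exclusive is now [1., 2., 6.]: each output element is the running product of
# all earlier input elements, i.e. the exclusive=True behaviour described above.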
2 changes: 1 addition & 1 deletion python/tvm/topi/unique.py
@@ -18,7 +18,7 @@
"""Unique operator"""
from tvm import te, tir
from ..te import hybrid
from .cumsum import cumsum
from .scan import cumsum
from .sort import sort, argsort


37 changes: 22 additions & 15 deletions tests/python/topi/python/test_topi_scan.py
@@ -14,22 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable

import numpy as np
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import topi

topi_funcs = {
    'cumsum': {'generic': topi.cumsum, 'cuda': topi.cuda.cumsum},
    'cumprod' = {'generic': topi.cumprod, 'cuda': topi.cuda.cumprod}
    "cumsum": {"generic": topi.cumsum, "cuda": topi.cuda.cumsum},
    "cumprod": {"generic": topi.cumprod, "cuda": topi.cuda.cumprod},
}

identity_value = {'cumsum': 0, 'cumprod': 1}
identity_value = {"cumsum": 0, "cumprod": 1}


def get_implementations(name):
    topi_func_generic = topi_funcs[name]['generic'],
    topi_func_cuda = topi_funcs[name]['cuda']
def get_implementations(name, axis, dtype, exclusive):
    topi_func_generic = topi_funcs[name]["generic"]
    topi_func_cuda = topi_funcs[name]["cuda"]

    return {
        "generic": (
@@ -54,14 +57,15 @@ def get_implementations(name):
        ),
    }

def test_scan_helper(

def _run_tests(
    ctx,
    target,
    op_name: str = 'cumsum',
    gt_func: Callable[..., np.array] = gt_func,
    op_name: str = "cumsum",
    gt_func: Callable[..., np.array] = np.cumsum,
):
    def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False):
        implementations = get_implementations(name)
        implementations = get_implementations(op_name, axis, dtype, exclusive)
        fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
        tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)

@@ -104,24 +108,27 @@ def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False):
    data = np.random.randint(-100, 100, size=(100, 100)).astype("int64")

    expected_result = np.roll(gt_func(data), 1)
    expected_result[0] = identity_value[name]
    expected_result[0] = identity_value[op_name]
    check_scan(expected_result, data, dtype="int64", exclusive=True)

    expected_result = np.roll(gt_func(data, axis=0, dtype=in_dtype), 1, axis=0)
    expected_result[0, :] = identity_value[name]
    expected_result[0, :] = identity_value[op_name]
    check_scan(expected_result, data, axis=0, exclusive=True)

    expected_result = np.roll(gt_func(data, axis=1, dtype=in_dtype), 1, axis=1)
    expected_result[:, 0] = identity_value[name]
    expected_result[:, 0] = identity_value[op_name]
    check_scan(gt_func(data, axis=1, dtype=in_dtype), data, axis=1)


@tvm.testing.parametrize_targets
def test_cumsum(ctx, target):
    test_scan_helper(ctx, target, op_name='cumsum', gt_func=np.cumsum)
    _run_tests(ctx, target, op_name="cumsum", gt_func=np.cumsum)


@tvm.testing.parametrize_targets
def test_cumprod(ctx, target):
    test_scan_helper(ctx, target, op_name='cumprod', gt_func=np.cumprod)
    _run_tests(ctx, target, op_name="cumprod", gt_func=np.cumprod)


if __name__ == "__main__":
    test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
