Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Mar 4, 2021
1 parent 47d45b2 commit 15189ca
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/python/topi/python/test_topi_cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
"cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
"metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
}
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
Expand All @@ -47,6 +48,9 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")

for in_dtype in ["float32", "float64"]:
if str(target.kind) == 'metal' and in_dtype == 'float64':
# float64 is not supported in metal
continue
data = np.random.randn(10, 10).astype(in_dtype)
check_cumsum(np.cumsum(data), data)
check_cumsum(np.cumsum(data, axis=0), data, axis=0)
Expand Down Expand Up @@ -74,3 +78,4 @@ def check_cumsum(np_ref, data, axis=None, dtype=None):
test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))

0 comments on commit 15189ca

Please sign in to comment.