Skip to content

Commit

Permalink
adding masa's fix from apache#7669
Browse files Browse the repository at this point in the history
Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
  • Loading branch information
tmoreau89 committed Mar 16, 2021
1 parent e0e6104 commit 44d6e50
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-statements
"Scan related operators"
import logging

import tvm
from tvm import te
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
Expand Down Expand Up @@ -109,7 +111,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, size_cast_dtype))),
"int64",
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def check_mod(mod, x_np, res_np):

check_mod(mod, x_np, res_np)

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.cumsum(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
res_np = np.cumsum(x_np)

check_mod(mod, x_np, res_np)


if __name__ == "__main__":
test_bool_load()
Expand Down

0 comments on commit 44d6e50

Please sign in to comment.