Skip to content

Commit

Permalink
fix mode case
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 27, 2019
1 parent ef69736 commit 17f2db6
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/python/unittest/test_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def check_cuda(dtype, n, lanes):
print("skip because gpu does not support int8")
return
A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
Expand Down Expand Up @@ -165,9 +165,10 @@ def test_cuda_shuffle():
print("skip because cuda is not enabled..")
return

idxm = tvm.indexmod
a = tvm.placeholder((64, ), 'int32')
b = tvm.placeholder((64, ), 'int32')
c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
sch = tvm.create_schedule(c.op)
x = c.op.axis[0]
xo, xi = sch[c].split(x, 4)
Expand Down

0 comments on commit 17f2db6

Please sign in to comment.