Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jul 13, 2020
1 parent d056e4c commit a1f95a7
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,46 @@ def test_vectorized_cooperative_fetching_xy():

vcf_check_common(s, [A, B, C])

def test_unrolled_vectorization():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

dtype = 'float32'
target = 'cuda'

## Compute declaration
N = 128
A = te.placeholder((N, N), name='A')
B = te.placeholder((N, N), name='B')
k = te.reduce_axis((0, N), name='k')
C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')

## Schedule
s = te.create_schedule([C.op])
CC = s.cache_write(C, "local")
i, j = s[C].op.axis
bx, tx, ii, ji = s[C].tile(i, j, 1, 2)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(ji)
s[CC].compute_at(s[C], tx)
i, j = s[CC].op.axis
k = s[CC].op.reduce_axis[0]
ko, ki = s[CC].split(k, 2)
s[CC].unroll(ki)
s[CC].vectorize(j)

## Check correctness
ctx = tvm.context(target)
a_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
b_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
c_tvm = tvm.nd.empty((N, N), ctx=ctx)
func_tvm = tvm.build(s, [A, B, C], target=target)
func_tvm(a_tvm, b_tvm, c_tvm)
c_np = c_tvm.asnumpy()
tvm.testing.assert_allclose(c_np, 128 * np.ones((N, N)))

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
Expand All @@ -897,3 +937,5 @@ def test_vectorized_cooperative_fetching_xy():
test_cuda_vectorize_load_permute_pad()
test_vectorized_cooperative_fetching_x()
test_vectorized_cooperative_fetching_xy()
test_unrolled_vectorization()

0 comments on commit a1f95a7

Please sign in to comment.