Add an SDDMM correctness test for sparse TIR and port the binary-search
intrinsic test to the T.prim_func TVMScript syntax.

* src/target/source/literal/cuda_binary_search.h
  __lower_bound / __upper_bound take and return plain `int` instead of
  `int32_t`.

* tests/python/sparsetir/test_tir_sparse_correctness.py
  New `sddmm_tir` kernel and `test_sddmm`, registered in the __main__ list.
  - C is stored in CSR form, so the row that owns nonzero `vij` is the last
    `i` with C_indptr[i] <= vij. That is upper_bound(C_indptr, vij) - 1,
    assuming T.upper_bound returns the first index whose element is > val
    (as the loop invariant in __upper_bound suggests). lower_bound yields
    that row only for the first nonzero of each row.
  - The function attribute key is spelled "tir.noalias".
  - The result of np.allclose(...) is asserted, as test_ellmm already does;
    a bare call discards the comparison and the test could never fail.

* tests/python/unittest/test_tir_intrin.py
  `binary_search` rewritten with T.prim_func / T.block / T.axis.S.

NOTE(review): leading whitespace inside the hunks and the
`template <typename DType>` context lines were reconstructed from the code
structure; confirm against the target tree before applying.

diff --git a/src/target/source/literal/cuda_binary_search.h b/src/target/source/literal/cuda_binary_search.h
index becd3e33c0d6..2c7a2b6a770f 100644
--- a/src/target/source/literal/cuda_binary_search.h
+++ b/src/target/source/literal/cuda_binary_search.h
@@ -26,15 +26,15 @@
 
 static constexpr const char* _cuda_binary_search_def = R"(
 template <typename DType>
-__forceinline__ __device__ int32_t __lower_bound(
+__forceinline__ __device__ int __lower_bound(
     const DType* __restrict__ arr,
     DType val,
-    int32_t l,
-    int32_t r) {
-  int32_t low = l - 1, high = r;
+    int l,
+    int r) {
+  int low = l - 1, high = r;
   /* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
   while (low + 1 < high) {
-    int32_t mid = (low + high) >> 1;
+    int mid = (low + high) >> 1;
     if (arr[mid] < val) {
       low = mid;
     } else {
@@ -46,15 +46,15 @@ __forceinline__ __device__ int32_t __lower_bound(
 }
 
 template <typename DType>
-__forceinline__ __device__ int32_t __upper_bound(
+__forceinline__ __device__ int __upper_bound(
     const DType* __restrict__ arr,
     DType val,
-    int32_t l,
-    int32_t r) {
-  int32_t low = l - 1, high = r;
+    int l,
+    int r) {
+  int low = l - 1, high = r;
   /* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
   while (low + 1 < high) {
-    int32_t mid = (low + high) >> 1;
+    int mid = (low + high) >> 1;
     if (arr[mid] > val) {
       high = mid;
     } else {
diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py
index eacd369f43d3..b412c62ce1b2 100644
--- a/tests/python/sparsetir/test_tir_sparse_correctness.py
+++ b/tests/python/sparsetir/test_tir_sparse_correctness.py
@@ -92,6 +92,25 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
                 B[A_indices[vi * NNZ_COLS + vj] * K + vk]
 
 
+@T.prim_func
+def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    A = T.match_buffer(a, (M * K,), "float32")
+    B = T.match_buffer(b, (N * K,), "float32")
+    C_data = T.match_buffer(c, (NNZ,), "float32")
+    C_indptr = T.match_buffer(indptr, (M + 1,), "int32")
+    C_indices = T.match_buffer(indices, (NNZ,), "int32")
+    for ij, k in T.grid(NNZ, K):
+        with T.block("sddmm"):
+            vij, vk = T.axis.remap("SR", [ij, k])
+            T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
+            T.writes([C_data[vij]])
+            with T.init():
+                C_data[vij] = 0.
+            C_data[vij] = C_data[vij] + \
+                A[(T.upper_bound(C_indptr.data, vij, 0, M + 1) - 1) * K + vk] * B[C_indices[vij] * K + vk]
+
+
 def test_csrmm():
     # generate random input
     m = 4096
@@ -219,6 +238,50 @@ def test_ellmm():
     assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
 
 
+def test_sddmm():
+    # generate random input
+    m = 4096
+    n = 4096
+    k = 256
+    C = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
+    indptr = C.indptr
+    indices = C.indices
+    C_coo = C.tocoo()
+    nnz = C.nnz
+    x = np.random.rand(m, k).astype("float32")
+    y = np.random.rand(n, k).astype("float32")
+    z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col]
+    z = np.zeros((nnz,)).astype("float32")
+
+    # specialize function
+    _, _, _, _, _, M, N, K, NNZ = sddmm_tir.params
+    sch = tir.Schedule(
+        sddmm_tir.specialize(
+            {M: m, N: n, K: k, NNZ: nnz}
+        )
+    )
+    blk = sch.get_block("sddmm")
+    ij, k = sch.get_loops(blk)
+    #sch.decompose_reduction(blk, ij)
+    sch.bind(ij, "blockIdx.x")
+    ko, ki = sch.split(k, [None, 1])
+    sch.bind(ki, "threadIdx.x")
+
+    # convert numpy tensor to tvm ndarray
+    C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
+    C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
+    X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
+    Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
+    C_data = tvm.nd.array(z, device=tvm.cuda(0))
+
+    # build function
+    f = tvm.build(sch.mod['main'], target="cuda")
+    f(X_nd, Y_nd, C_data, C_indptr, C_indices)
+
+    # assertion
+    assert np.allclose(z_ground_truth, C_data.numpy())
+
+
 def test_bmm():
     # TODO(zihao)
     pass
@@ -228,4 +291,5 @@ def test_bmm():
     test_csrmm()
     test_bsrmm()
     test_ellmm()
+    test_sddmm()
     test_bmm()
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 036db96b1c2c..16680bfbff2a 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -253,19 +253,21 @@ def test_fma():
     assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"
 
 
-@tvm.script.tir
-def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
-    n = tir.var('int32')
-    m = tir.var('int32')
-    A = tir.match_buffer(a, (n,), dtype='int32')
-    B = tir.match_buffer(b, (m,), dtype='int32')
-    C = tir.match_buffer(c, (m,), dtype='int32')
-    D = tir.match_buffer(d, (m,), dtype='int32')
-    with tir.block([m], 'search') as [vi]:
-        tir.reads([A[0:n], B[vi]])
-        tir.writes([C[vi], D[vi]])
-        C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
-        D[vi] = tir.upper_bound(A.data, B[vi], 0, n)
+@T.prim_func
+def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
+    n = T.var('int32')
+    m = T.var('int32')
+    A = T.match_buffer(a, (n,), dtype='int32')
+    B = T.match_buffer(b, (m,), dtype='int32')
+    C = T.match_buffer(c, (m,), dtype='int32')
+    D = T.match_buffer(d, (m,), dtype='int32')
+    for i in T.serial(0, m):
+        with T.block('search'):
+            vi = T.axis.S(m, i)
+            T.reads([A[0:n], B[vi]])
+            T.writes([C[vi], D[vi]])
+            C[vi] = T.lower_bound(A.data, B[vi], 0, n)
+            D[vi] = T.upper_bound(A.data, B[vi], 0, n)
 
 
 def test_binary_search():