fix lower_warp_memory (apache#5247)
roastduck authored and Trevor Morris committed Apr 16, 2020
1 parent d8963bb commit 6a1dc8d
Showing 2 changed files with 52 additions and 5 deletions.
src/tir/transforms/lower_warp_memory.cc (6 changes: 3 additions & 3 deletions)
@@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
-  PrimExpr Mutate_(const VarNode* op) {
+  PrimExpr VisitExpr_(const VarNode* op) override {
     CHECK(op != buffer_)
         << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const StoreNode* op) {
+  Stmt VisitStmt_(const StoreNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
@@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     }
   }
 
-  PrimExpr Mutate_(const LoadNode* op) {
+  PrimExpr VisitExpr_(const LoadNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
tests/python/unittest/test_tir_transform_lower_warp_memory.py (51 changes: 49 additions & 2 deletions)
@@ -16,8 +16,11 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.contrib.nvcc import have_fp16
 
-def test_lower_warp_mem():
+import numpy as np
+
+def test_lower_warp_memory_local_scope():
     m = 128
     A = te.placeholder((m,), name='A')
     B = te.compute((m,), lambda i: A[i] + 3, name='B')
@@ -44,6 +47,50 @@ def test_lower_warp_mem():
         assert(fdevice.body.body.value.value == "local")
         assert(fdevice.body.body.body.extents[0].value == 2)
 
+def test_lower_warp_memory_cuda_end_to_end():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 128
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 32
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], 64)
+            xi0, xi1 = s[B].split(xi, factor=32)
+            tx = te.thread_axis("threadIdx.x")
+            s[B].bind(xi1, tx)
+            s[B].bind(xo, te.thread_axis("blockIdx.x"))
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], 32)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(
+                list(range(1, 32)) + [0] +
+                list(range(33, 64)) + [32] +
+                list(range(65, 96)) + [64] +
+                list(range(97, 128)) + [96],
+                dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
-    test_lower_warp_mem()
+    test_lower_warp_memory_local_scope()
+    test_lower_warp_memory_cuda_end_to_end()
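Note on the expected output in the new test: B is defined as B[i] = A[i // 32 * 32 + (i + 1) % 32], i.e. each aligned group of 32 elements is rotated left by one lane, which forces every thread to read a neighbouring lane's value through the "warp"-scope cache stage. The hand-written B_np list is that permutation spelled out. A minimal sketch of the same expected values computed directly with numpy (not part of the commit, shown only to make the pattern explicit):

import numpy as np

m = 128
A_np = np.arange(m, dtype="float32")
idx = np.arange(m)
# Apply the same index expression as the te.compute above: within each
# group of 32, element i reads element (i + 1) % 32 of its group.
B_expected = A_np[idx // 32 * 32 + (idx + 1) % 32]
# B_expected is [1..31, 0, 33..63, 32, 65..95, 64, 97..127, 96],
# matching the hand-written B_np in the test.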
