Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Fix lower_warp_memory #5247

Merged
merged 1 commit into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
}

protected:
PrimExpr Mutate_(const VarNode* op) {
PrimExpr VisitExpr_(const VarNode* op) override {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
}

Stmt VisitStmt_(const StoreNode* op) {
Stmt VisitStmt_(const StoreNode* op) override {
if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
Expand All @@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
}
}

PrimExpr Mutate_(const LoadNode* op) {
PrimExpr VisitExpr_(const LoadNode* op) override {
if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
Expand Down
51 changes: 49 additions & 2 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
# under the License.
import tvm
from tvm import te
from tvm.contrib.nvcc import have_fp16

def test_lower_warp_mem():
import numpy as np

def test_lower_warp_memory_local_scope():
m = 128
A = te.placeholder((m,), name='A')
B = te.compute((m,), lambda i: A[i] + 3, name='B')
Expand All @@ -44,6 +47,50 @@ def test_lower_warp_mem():
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)

def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return

m = 128
A = te.placeholder((m,), name='A', dtype=dtype)
B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')

cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
with cuda_target:
s = te.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], 64)
xi0, xi1 = s[B].split(xi, factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)

ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.array(list(range(m)), dtype=dtype)
B_np = np.array(
list(range(1, 32)) + [0] +
list(range(33, 64)) + [32] +
list(range(65, 96)) + [64] +
list(range(97, 128)) + [96],
dtype=dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
func(A_nd, B_nd)
tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)

check_cuda("float32")
check_cuda("float16")

if __name__ == "__main__":
test_lower_warp_mem()
test_lower_warp_memory_local_scope()
test_lower_warp_memory_cuda_end_to_end()