
LowerWarpMemory: remove unneeded shuffle when accessing from the same thread (apache#8681)

vinx13 authored and mehrdadh committed Aug 11, 2021
1 parent 4b7b32e commit 387836c
Showing 2 changed files with 33 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tir/transforms/lower_warp_memory.cc
@@ -241,7 +241,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
-      return Store(op->buffer_var, op->value, local_index, op->predicate);
+      PrimExpr new_value = VisitExpr(op->value);
+      return Store(op->buffer_var, new_value, local_index, op->predicate);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
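
On the store path, the rewriter now also visits the stored value, so any read of the warp buffer nested inside the right-hand side goes through the same rewriting (and benefits from the same-thread short-circuit added below) instead of being left untouched.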
@@ -256,6 +257,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
         << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
         << " local_index=" << local_index;
     PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
+    if (analyzer_->CanProveEqual(group, warp_index_)) {
+      return load_value;
+    }
     PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
     return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
                 {mask, load_value, group, width_, warp_size_});
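
The load-path change skips the shuffle when the index analysis proves that `group`, the source lane derived from the buffer index, equals `warp_index_`, the executing thread's own lane: the thread is only reading back a value it already holds, so the plain load is enough. The snippet below is an illustrative CUDA sketch of that identity, not TVM-generated code; it assumes the standard `__shfl_sync` intrinsic, which is what `tvm_warp_shuffle` maps to in the CUDA backend.

// Illustrative sketch only (not TVM output): a warp shuffle whose source lane
// is the caller's own lane returns the caller's value, so it can be elided.
#include <cstdio>

__global__ void same_lane_shuffle(float* out) {
  int lane = threadIdx.x % warpSize;  // this thread's lane id within its warp
  float v = 2.0f * lane;              // value held in this thread's register
  // Shuffling from our own lane is just an identity read of v ...
  float via_shuffle = __shfl_sync(0xffffffffu, v, lane);
  // ... so the intrinsic adds nothing and the register can be read directly.
  out[threadIdx.x] = via_shuffle - v;  // always 0.0f
}

int main() {
  float* out = nullptr;
  cudaMalloc(&out, 32 * sizeof(float));
  same_lane_shuffle<<<1, 32>>>(out);

  float host[32];
  cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
  printf("lane 0 difference: %f\n", host[0]);  // prints 0.000000
  cudaFree(out);
  return 0;
}

In the pass itself this identity shows up as the `analyzer_->CanProveEqual(group, warp_index_)` early return added above.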
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -283,10 +283,38 @@ def check(device, m):
         check(device, m=65)
 
 
+@tvm.testing.requires_cuda
+def test_lower_warp_memory_same_thread():
+    m = n = 128
+    A = te.placeholder((m, n), name="A")
+    k = te.reduce_axis((0, n), name="k")
+    B = te.compute((m,), lambda i: te.sum(A[i, k], axis=[k]))
+
+    s = te.create_schedule(B.op)
+    BB = s.cache_write(B, "warp")
+    tx = te.thread_axis("threadIdx.x")
+    xo, xi = s[B].split(B.op.axis[0], factor=32)
+    s[B].bind(xi, tx)
+    s[B].bind(xo, te.thread_axis("blockIdx.x"))
+    s[BB].compute_at(s[B], xo)
+    xo, xi = s[BB].split(s[BB].op.axis[0], factor=32)
+    s[BB].bind(xi, tx)
+
+    cuda_target = tvm.target.Target("cuda")
+    assert cuda_target.thread_warp_size == 32
+    mod = tvm.lower(s, [A, B], name="f")
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
+    mod = tvm.IRModule.from_expr(fdevice)
+    fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
+    assert "tvm_warp_shuffle" not in fdevice.astext()
+
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_correct_indices()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
     test_lower_warp_memory_cuda_2_buffers()
     test_lower_warp_memory_roundup()
+    test_lower_warp_memory_same_thread()
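
In this schedule the warp-scope cache `BB` is written and read under the same `threadIdx.x` binding: both stages are split by a factor of 32 and bound to `tx`, and `BB` is computed at `xo` of `B`, so every warp-memory access stays within the thread that produced the value. The lowered device function therefore needs no cross-lane exchange, which the final assertion checks by confirming that `tvm_warp_shuffle` does not appear in the printed TIR.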
