Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LowerWarpMemory: remove unneeded shuffle when accessing from the same thread #8681

Merged
merged 1 commit into from
Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store(op->buffer_var, op->value, local_index, op->predicate);
PrimExpr new_value = VisitExpr(op->value);
return Store(op->buffer_var, new_value, local_index, op->predicate);
} else {
return StmtExprMutator::VisitStmt_(op);
}
Expand All @@ -256,6 +257,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
<< " local_index=" << local_index;
PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
if (analyzer_->CanProveEqual(group, warp_index_)) {
return load_value;
}
PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
return Call(load_value.dtype(), builtin::tvm_warp_shuffle(),
{mask, load_value, group, width_, warp_size_});
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,38 @@ def check(device, m):
check(device, m=65)


@tvm.testing.requires_cuda
def test_lower_warp_memory_same_thread():
m = n = 128
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=[k]))

s = te.create_schedule(B.op)
BB = s.cache_write(B, "warp")
tx = te.thread_axis("threadIdx.x")
xo, xi = s[B].split(B.op.axis[0], factor=32)
s[B].bind(xi, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[BB].compute_at(s[B], xo)
xo, xi = s[BB].split(s[BB].op.axis[0], factor=32)
s[BB].bind(xi, tx)

cuda_target = tvm.target.Target("cuda")
assert cuda_target.thread_warp_size == 32
mod = tvm.lower(s, [A, B], name="f")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
assert "tvm_warp_shuffle" not in fdevice.astext()


if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()
test_lower_warp_memory_roundup()
test_lower_warp_memory_same_thread()