[TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size (apache#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
roastduck authored and trevor-m committed Jun 18, 2020
1 parent 224252c commit c9d15f9
Showing 5 changed files with 119 additions and 34 deletions.
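In short: lower_warp_memory previously required the extent of threadIdx.x to equal the warp size; after this commit, any extent that is a factor of the warp size is accepted, and the generated shuffles operate on sub-warps of that width. A minimal sketch of the device-code pattern this enables, assuming a CUDA target with warp size 32; the kernel below is hand-written for illustration (it mirrors the new half-a-warp test added in this commit) and is not actual TVM output:

__global__ void rotate_half_warp(const float* A, float* B) {
  // Only 16 threads are bound to threadIdx.x: half a warp.
  float a = A[threadIdx.x];  // each thread keeps one element in a register
  // Sub-warp shuffle of width 16: thread i reads the register of thread (i + 1) % 16.
  // Before this commit, lower_warp_memory could only emit full-warp (width 32) shuffles.
  float b = __shfl(a, (threadIdx.x + 1) % 16, 16);
  B[threadIdx.x] = b;
}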
10 changes: 9 additions & 1 deletion include/tvm/tir/expr.h
@@ -1228,9 +1228,17 @@ constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief See pseudo code
*
* Type tvm_warp_shuffle(Type value, warp_id) {
* Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
* return (value passed in by warp indicated by warp_id);
* }
*
* Parameter warp_id indicates the source thread ID in a warp.
*
* Parameter width indicates the number of threads involved in one
* shuffle. See the CUDA documentation for __shfl.
*
* Parameter warp_size is the size of a warp, which helps a backend
* to determine whether the width parameter is legal.
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
16 changes: 10 additions & 6 deletions src/target/source/intrin_rule_cuda.cc
@@ -81,11 +81,15 @@ struct CUDAPopcount {
}
};

struct CUDAShuffle {
std::string operator()(DataType t, std::string name) const {
return "__shfl";
}
};
static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
*rv = CallNode::make(
call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
}
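// Illustrative mapping (a sketch, not part of this diff): the four-argument
// TIR call
//   tvm_warp_shuffle(value, warp_id, 16, 32)
// is rewritten by this dispatcher to the three-argument CUDA extern call
//   __shfl(value, warp_id, 16)
// The trailing warp_size argument is dropped: CUDA itself accepts any
// power-of-two width up to warpSize, so no extra validation is needed here.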

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);
@@ -154,7 +158,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);
.set_body(DispatchCUDAShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);
20 changes: 14 additions & 6 deletions src/target/source/intrin_rule_opencl.cc
@@ -21,6 +21,7 @@
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include <tvm/arith/analyzer.h>
#include "../intrin_rule.h"

namespace tvm {
@@ -89,14 +90,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")

// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is Intel's shuffle extension
struct IntelShuffle {
std::string operator()(DataType t, std::string name) const {
return "intel_sub_group_shuffle";
}
};
static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
<< "Intel warp shuffle dose not support width != warp_size";
Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
*rv = CallNode::make(
call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
}
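// Illustrative mapping (a sketch, not part of this diff):
//   tvm_warp_shuffle(value, warp_id, 16, 16)  ->  intel_sub_group_shuffle(value, warp_id)
// The Intel extension shuffles across the whole sub-group and has no per-call
// width parameter, which is why the CanProve check above rejects width != warp_size.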

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
.set_body(DispatchExtern<IntelShuffle>);
.set_body(DispatchIntelShuffle);

} // namespace intrin
} // namespace codegen
65 changes: 44 additions & 21 deletions src/tir/transforms/lower_warp_memory.cc
@@ -60,28 +60,45 @@ namespace tir {
//
// Before rewrite,
//
// alloc warp warp_mem[n * warp_size * m]
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
// load warp_mem[m * z + (warp_size * m) * y + x]
// alloc warp warp_mem[n * width * m]
// store warp_mem[m * warp_index + (width * m) * y + x]
// load warp_mem[m * z + (width * m) * y + x]
// subject to x \in [0, m), y \in [0, n)
//
// where width equals the extent of threadIdx.x, which should
// be no larger than the warp size
//
// After rewrite:
//
// alloc local local_mem[n * m]
// store warp_mem[m * y + x]
// warp_shuffle(load warp_mem[m * y + x], z)
// subject to (m * y + x) is invariant to warp_index
//
// If width == warp size, we are shuffling on full warps.
// Otherwise, we are virtually shuffling on sub-warps,
// whose size equals width. In this case, you can imagine
// a warp that consists of only `width` threads. Width is passed
// as an argument to the shuffle primitive, and will be
// lowered to the device code if the target supports it.
//
// A limitation of this sub-warp approach is that users
// cannot shuffle across the sub-warp boundary (i.e. shuffle
// with threadIdx.y or threadIdx.z indices). This can be solved
// by fusing threadIdx.x up to the warp size, or by improving the
// analyzer to detect all three thread axes, which is left for
// future work.
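//
// A worked instance (illustrative, with hypothetical sizes): let
// warp_size = 32, width = 16, m = 2, n = 1. Before the rewrite,
//
//   alloc warp  warp_mem[1 * 16 * 2]        // 32 elements spread over 16 lanes
//   store warp_mem[2 * warp_index + x]      // y == 0, x in [0, 2)
//   load  warp_mem[2 * z + x]
//
// and after the rewrite,
//
//   alloc local local_mem[1 * 2]            // 2 elements per lane
//   store local_mem[x]
//   tvm_warp_shuffle(load local_mem[x], z, /*width=*/16, /*warp_size=*/32)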

// Algorithm
//
// To implement this rewrite rule, we take the following steps:
// For each warp memory alloc
// - Use linear pattern detector on load index to find m
// - Deduce n given warp_size and alloc size
// - Now that we have m, n, warp_size, we can proceed with the rewrite
// - Deduce n given width and alloc size
// - Now that we have m, n, width, we can proceed with the rewrite
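//
// Continuing the worked instance above: the linear pattern detector sees the
// store index 2 * warp_index + x and reports m = 2; with alloc size 32 and
// width = 16 we deduce n = 32 / (16 * 2) = 1, so each thread's local buffer
// holds n * m = 2 elements.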

// Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
// store warp_mem[m * warp_index + (width * m) * y + x]
class WarpStoreCoeffFinder : private StmtVisitor {
public:
WarpStoreCoeffFinder(const VarNode* buffer,
@@ -153,12 +170,12 @@ class WarpIndexFinder : private StmtVisitor {
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
// find the warp co-efficient in the statement given the warp size
IterVar Find(const Stmt& stmt) {
// find the warp co-efficient and the shuffle width in the statement
std::pair<Var, int> Find(const Stmt& stmt) {
this->VisitStmt(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_;
return std::make_pair(warp_index_->var, width_);
}

private:
@@ -167,11 +184,12 @@
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
int value;
int value = 0;
CHECK(arith::GetConstInt(op->value, &value) &&
value == warp_size_)
<< "Expect threadIdx.x 's size to be equal to warp size("
<< warp_size_ << ")" << " to enable warp memory"
value <= warp_size_ &&
warp_size_ % value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor of"
<< " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
if (warp_index_.defined()) {
CHECK(warp_index_.same_as(iv))
@@ -180,6 +198,7 @@
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
width_ = value;
warp_index_ = iv;
}
}
@@ -188,6 +207,8 @@
}
// warp size
int warp_size_{0};
// number of threads involved in one shuffle
int width_{0};
// the warp index
IterVar warp_index_{nullptr};
};
@@ -204,16 +225,16 @@ class WarpAccessRewriter : protected StmtExprMutator {
CHECK_GT(alloc_size, 0)
<< "warp memory only support constant alloc size";
alloc_size *= op->dtype.lanes();
warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(
buffer_, warp_index_, analyzer_).Find(op->body);
CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
<< "Warp memory must be multiple of warp size";
warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
<< "Warp memory must be multiple of the extent of threadIdx.x";
warp_group_ = alloc_size / (width_ * warp_coeff_);
return AllocateNode::make(
op->buffer_var,
op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)},
{make_const(DataType::Int(32), alloc_size / width_)},
op->condition,
this->VisitStmt(op->body));
}
@@ -247,7 +268,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
op->dtype, op->buffer_var, local_index, op->predicate);
return CallNode::make(load_value.dtype(),
intrinsic::tvm_warp_shuffle,
{load_value, group},
{load_value, group, width_, warp_size_},
CallNode::Intrinsic);
} else {
return StmtExprMutator::VisitExpr_(op);
@@ -276,9 +297,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
return std::make_pair(x, z);
} else {
PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
y = y * m + x;
PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
@@ -290,6 +311,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
int warp_size_{0};
// The buffer variable
const VarNode* buffer_;
// number of threads involved in one shuffle
int width_{0};
// Warp index
Var warp_index_;
// the coefficient m
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -91,6 +91,48 @@ def check_cuda(dtype):
check_cuda("float32")
check_cuda("float16")

def test_lower_warp_memory_cuda_half_a_warp():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return

m = 16
A = te.placeholder((m,), name='A', dtype=dtype)
B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')

cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 2 * m
with cuda_target:
s = te.create_schedule(B.op)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")

AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], nparts=1)
s[B].bind(xi, tx)
s[B].bind(xo, bx)
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
s[AA].bind(xo, bx)
s[AA].bind(xi, tx)

ctx = tvm.gpu(0)
func = tvm.build(s, [A, B], "cuda")
A_np = np.array(list(range(m)), dtype=dtype)
B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
A_nd = tvm.nd.array(A_np, ctx)
B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
func(A_nd, B_nd)
tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)

check_cuda("float32")
check_cuda("float16")

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
