From 8b57f01e75170a0020e7553f0038bd29ad5fadbb Mon Sep 17 00:00:00 2001
From: "Tang, Shizhi"
Date: Fri, 17 Apr 2020 22:35:11 +0800
Subject: [PATCH] [TIR] Make lower_warp_memory support extent(threadIdx.x) < warp_size (#5307)

* support extent(threadIdx.x) < warp_size in lower_warp_memory

* more docs for lower_warp_memory
---
 include/tvm/tir/expr.h                      | 10 ++-
 src/target/source/intrin_rule_cuda.cc       | 16 +++--
 src/target/source/intrin_rule_opencl.cc     | 20 ++++--
 src/tir/transforms/lower_warp_memory.cc     | 65 +++++++++++++------
 .../test_tir_transform_lower_warp_memory.py | 42 ++++++++++++
 5 files changed, 119 insertions(+), 34 deletions(-)

diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index a1603d5e7bda..6764178cc23c 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -1228,9 +1228,17 @@ constexpr const char* tvm_storage_sync = "tvm_storage_sync";
 /*!
  * \brief See pseudo code
  *
- * Type tvm_warp_shuffle(Type value, warp_id) {
+ * Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
  *   return (value passed in by warp indicated by warp_id);
  * }
+ *
+ * Parameter warp_id indicates the source thread ID in a warp.
+ *
+ * Parameter width indicates the number of threads involved in one
+ * shuffle. See the CUDA documentation for __shfl.
+ *
+ * Parameter warp_size is the size of a warp, which helps a backend
+ * to determine whether the width parameter is legal.
  */
 constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
 /*!
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index d9441203edc0..f40dd5e86bad 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -81,11 +81,15 @@ struct CUDAPopcount {
   }
 };
 
-struct CUDAShuffle {
-  std::string operator()(DataType t, std::string name) const {
-    return "__shfl";
-  }
-};
+static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
+  PrimExpr e = args[0];
+  const CallNode* call = e.as<CallNode>();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  *rv = CallNode::make(
+      call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
+}
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
 .set_body(DispatchExtern<CUDAMath>);
@@ -154,7 +158,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
 .set_body(DispatchExtern<CUDAPopcount>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-.set_body(DispatchExtern<CUDAShuffle>);
+.set_body(DispatchCUDAShuffle);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
 .set_body(DispatchExtern<CUDAMath>);
diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 1a4f52e4dfd1..7374e6d40032 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -21,6 +21,7 @@
  * \file intrin_rule_opencl.cc
 * \brief OpenCL intrinsic rules.
 */
+#include <tvm/arith/analyzer.h>
 #include "../intrin_rule.h"
 
 namespace tvm {
@@ -89,14 +90,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh")
 
 // There is no warp shuffle instruction in standard OpenCL
 // When shuffle is used, we assume it is intel's shuffle extension
-struct IntelShuffle {
-  std::string operator()(DataType t, std::string name) const {
-    return "intel_sub_group_shuffle";
-  }
-};
+static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
+  PrimExpr e = args[0];
+  const CallNode* call = e.as<CallNode>();
+  CHECK(call != nullptr);
+  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  arith::Analyzer analyzer;
+  CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
+      << "Intel warp shuffle does not support width != warp_size";
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
+  *rv = CallNode::make(
+      call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+}
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
-.set_body(DispatchExtern<IntelShuffle>);
+.set_body(DispatchIntelShuffle);
 
 }  // namespace intrin
 }  // namespace codegen
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 612a8f4d9eef..71e7cfaf4832 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -60,28 +60,45 @@ namespace tir {
 //
 // Before rewrite,
 //
-// alloc warp warp_mem[n * warp_size * m]
-// store warp_mem[m * warp_index + (warp_size * m) * y + x]
-// load warp_mem[m * z + (warp_size * m) * y + x]
+// alloc warp warp_mem[n * width * m]
+// store warp_mem[m * warp_index + (width * m) * y + x]
+// load warp_mem[m * z + (width * m) * y + x]
 // subject to x \in [0, m), y \in [0, n)
 //
+// where width equals the extent of threadIdx.x, which should
+// be no larger than the warp size
+//
 // After rewrite:
 //
 // alloc local local_mem[n * m]
 // store warp_mem[m * y + x]
 // warp_shuffle(load warp_mem[m * y + x], z)
 // subject to (m * y + x) is invariant to warp_index
+//
+// If width == warp size, we are shuffling on full warps.
+// Otherwise, we are virtually shuffling on sub-warps,
+// whose size equals width. In this case, you can imagine
+// that a warp consists of only `width` threads. Width is passed
+// as an argument to the shuffle primitive, and will be
+// lowered to the device code if the target supports it.
+//
+// A limitation of this sub-warp approach is that users
+// cannot shuffle across the sub-warp boundary (i.e. shuffle
+// with threadIdx.y or threadIdx.z indices). This could be solved
+// by fusing threadIdx.x up to the warp size, or by improving the
+// analyzer to handle all 3 thread axes, which is left as
+// future work.
 //
 // Algorithm
 //
 // To implement this rewrite rule, we can do the following steps:
 // For each warp memory alloc
 // - Use linear pattern detector on load index to find m
-// - Deduce n given warp_size and alloc size
-// - Now that we have m, n, warp_size, we can proceed with the rewrite
+// - Deduce n given width and alloc size
+// - Now that we have m, n, width, we can proceed with the rewrite
 
 // Visitor to find m in pattern
-// store warp_mem[m * warp_index + (warp_size * m) * y + x]
+// store warp_mem[m * warp_index + (width * m) * y + x]
 class WarpStoreCoeffFinder : private StmtVisitor {
  public:
   WarpStoreCoeffFinder(const VarNode* buffer,
@@ -153,12 +170,12 @@ class WarpIndexFinder : private StmtVisitor {
   explicit WarpIndexFinder(int warp_size)
       : warp_size_(warp_size) {
   }
-  // find the warp co-efficient in the statement given the warp size
-  IterVar Find(const Stmt& stmt) {
+  // find the warp co-efficient and the shuffle width in the statement
+  std::pair<Var, int> Find(const Stmt& stmt) {
     this->VisitStmt(stmt);
     CHECK(warp_index_.defined())
         << "Cannot find warp index(threadIdx.x) within the scope of warp memory";
-    return warp_index_;
+    return std::make_pair(warp_index_->var, width_);
   }
 
  private:
@@ -167,11 +184,12 @@ class WarpIndexFinder : private StmtVisitor {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
-        int value;
+        int value = 0;
         CHECK(arith::GetConstInt(op->value, &value) &&
-              value == warp_size_)
-            << "Expect threadIdx.x 's size to be equal to warp size("
-            << warp_size_ << ")" << " to enable warp memory"
+              value <= warp_size_ &&
+              warp_size_ % value == 0)
+            << "Expect threadIdx.x's size to be no larger than, and a factor of,"
+            << " warp size(" << warp_size_ << ")" << " to enable warp memory"
             << " but get " << op->value << " instead";
         if (warp_index_.defined()) {
           CHECK(warp_index_.same_as(iv))
@@ -180,6 +198,7 @@ class WarpIndexFinder : private StmtVisitor {
             << "Please create it using thread_axis once and reuse the axis "
            << "across multiple binds in the same kernel";
         } else {
+          width_ = value;
           warp_index_ = iv;
         }
       }
@@ -188,6 +207,8 @@ class WarpIndexFinder : private StmtVisitor {
   }
   // warp size
   int warp_size_{0};
+  // number of threads involved in one shuffle
+  int width_{0};
   // the warp index
   IterVar warp_index_{nullptr};
 };
@@ -204,16 +225,16 @@ class WarpAccessRewriter : protected StmtExprMutator {
     CHECK_GT(alloc_size, 0)
         << "warp memory only support constant alloc size";
     alloc_size *= op->dtype.lanes();
-    warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
+    std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
     warp_coeff_ = WarpStoreCoeffFinder(
         buffer_, warp_index_, analyzer_).Find(op->body);
-    CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
-        << "Warp memory must be multiple of warp size";
-    warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
+    CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
+        << "Warp memory must be a multiple of the extent of threadIdx.x";
+    warp_group_ = alloc_size / (width_ * warp_coeff_);
     return AllocateNode::make(
         op->buffer_var,
         op->dtype,
-        {make_const(DataType::Int(32), alloc_size / warp_size_)},
+        {make_const(DataType::Int(32), alloc_size / width_)},
         op->condition,
         this->VisitStmt(op->body));
   }
@@ -247,7 +268,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
           op->dtype, op->buffer_var, local_index, op->predicate);
       return CallNode::make(load_value.dtype(),
                             intrinsic::tvm_warp_shuffle,
-                            {load_value, group},
+                            {load_value, group, width_, warp_size_},
                             CallNode::Intrinsic);
     } else {
       return StmtExprMutator::VisitExpr_(op);
@@ -276,9 +297,9 @@ class WarpAccessRewriter : protected StmtExprMutator {
       return std::make_pair(x, z);
     } else {
       PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
-      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
+      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
       y = y * m + x;
-      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
+      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)),
                             m);
       return std::make_pair(analyzer_->canonical_simplify(y),
                             analyzer_->canonical_simplify(z));
@@ -290,6 +311,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
   int warp_size_{0};
   // The buffer variable
   const VarNode* buffer_;
+  // number of threads involved in one shuffle
+  int width_{0};
   // Warp index
   Var warp_index_;
   // the coefficient m
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 8a31a1537ca2..a761cf1a95d8 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -91,6 +91,48 @@ def check_cuda(dtype):
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_cuda_half_a_warp():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 16
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 2 * m
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], nparts=1)
+            s[B].bind(xi, tx)
+            s[B].bind(xo, bx)
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_cuda_end_to_end()
+    test_lower_warp_memory_cuda_half_a_warp()
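
For illustration, the following is a rough, hypothetical sketch of the device code that the half-a-warp test above could lower to on a CUDA target (m = 16, extent(threadIdx.x) = 16, warp size = 32). It is not part of the patch: the kernel name, variable names, and exact expression structure are illustrative assumptions rather than the literal TVM codegen output. The point is that the warp-scope buffer shrinks to one element per thread and the load turns into a __shfl call with an explicit width of 16.

// Hypothetical sketch (assumed names, not generated by TVM verbatim):
// roughly what test_lower_warp_memory_cuda_half_a_warp lowers to after this pass.
extern "C" __global__ void kernel(float* __restrict__ A, float* __restrict__ B) {
  // The 16-element warp-scope buffer is rewritten to
  // alloc_size / width = 16 / 16 = 1 element per thread, held in a register.
  float A_warp[1];
  // Each of the 16 threads caches the one element it owns.
  A_warp[0] = A[threadIdx.x];
  // B[i] = A[(i + 1) % 16]: read the register owned by lane (i + 1) % 16.
  // The width argument of 16 keeps the shuffle inside the 16-thread sub-warp.
  B[threadIdx.x] = __shfl(A_warp[0], ((int)threadIdx.x + 1) % 16, 16);
}

Because warp_size is also passed to tvm_warp_shuffle, a backend can check that the requested width (16 here) is legal, i.e. no larger than and a factor of the warp size, before emitting its shuffle intrinsic.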