From 4d0cfd984a0bdde599005cb4d36e7898d1095964 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 25 Oct 2021 22:58:53 -0400
Subject: [PATCH] [TIR] Move UnifyThreadBinding to earlier stage (#9365)

* Move unify thread binding to earlier stage

* Unify thread binding support AttrStmt
---
 src/driver/driver_api.cc                      |   2 +-
 src/tir/transforms/unify_thread_binding.cc    |  90 +++++--
 ...test_tir_transform_unify_thread_binding.py | 236 ++++++++++--------
 3 files changed, 205 insertions(+), 123 deletions(-)

diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 24cae798988e..34661f81c847 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -237,10 +237,10 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::LowerInitBlock());
   pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
   pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+  pass_list.push_back(tir::transform::UnifyThreadBinding());
   pass_list.push_back(tir::transform::CompactBufferAllocation());
   pass_list.push_back(tir::transform::LowerMatchBuffer());
   pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::UnifyThreadBinding());
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index 6a26103e6079..aa586846f5d4 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -26,11 +26,14 @@
 #include 
 #include 
 
+#include "../../support/utils.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using support::StartsWith;
+
 /*!
  * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar
  * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same
@@ -41,14 +44,28 @@ class ThreadBindingUnifier : public StmtExprMutator {
   static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); }
 
  private:
-  Stmt VisitStmt_(const AttrStmtNode* attr) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     // If this AttrStmt is not thread binding attribute, return as usual.
-    if (attr->attr_key != attr::thread_extent && attr->attr_key != attr::virtual_thread) {
-      return StmtMutator::VisitStmt_(attr);
+    if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) {
+      return StmtMutator::VisitStmt_(op);
+    }
+    IterVar old_iter_var = Downcast<IterVar>(op->node);
+    return UnifyThreadBindingImpl(op, old_iter_var->var, old_iter_var, old_iter_var->dom);
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    // If this For is not a thread binding, return as usual.
+    if (op->kind != ForKind::kThreadBinding) {
+      return StmtExprMutator::VisitStmt_(op);
     }
+    return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
+                                  Range::FromMinExtent(op->min, op->extent));
+  }
 
-    // Step 1. Fetch the old IterVar and the thread tag.
-    IterVar old_iter_var = Downcast<IterVar>(attr->node);
+  template <typename Node>
+  Stmt UnifyThreadBindingImpl(const Node* op, const Var& old_var, const IterVar& old_iter_var,
+                              const Range& dom) {
+    // Step 1. Fetch the thread tag.
     IterVar new_iter_var{nullptr};
     const String& thread_tag = old_iter_var->thread_tag;
 
@@ -56,9 +73,12 @@ class ThreadBindingUnifier : public StmtExprMutator {
     // thread block depth is 0 before the increasement, it means we are entering a new kernel, and
     // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have
     // thread axes with different extents.
-    if (std::string(thread_tag).substr(0, 9) == "blockIdx.") {
+    bool is_kernel_launch_scope = false;
+    int old_thread_block_depth = thread_block_depth_;
+    if (StartsWith(thread_tag, "blockIdx.") || !thread_block_depth_) {
       if (!thread_block_depth_) {
         thread_tag2iter_var_map_.clear();
+        is_kernel_launch_scope = true;
       }
       ++thread_block_depth_;
     }
@@ -69,31 +89,56 @@ class ThreadBindingUnifier : public StmtExprMutator {
     Map<String, IterVar>::iterator it = thread_tag2iter_var_map_.find(thread_tag);
     if (it != thread_tag2iter_var_map_.end()) {
       new_iter_var = (*it).second;
-      CHECK(ana.CanProveEqual(old_iter_var->dom->extent, (*it).second->dom->extent))
+      ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min));
+      CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent))
           << "ValueError: All loops that are bound to `" << thread_tag
           << "` should have the same extent. However, there are two loops with extent "
-          << (*it).second->dom->extent << " and " << old_iter_var->dom->extent
-          << ", which are not equal";
+          << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal";
     } else {
       ObjectPtr<IterVarNode> p_new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
       p_new_iter_var->var = Var(thread_tag);
+      p_new_iter_var->dom = dom;
       new_iter_var = IterVar(p_new_iter_var);
       thread_tag2iter_var_map_.Set(thread_tag, new_iter_var);
+      launch_threads_.push_back(new_iter_var);
     }
 
     // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the
     // new variable in further mutation. Thus, we store the mapping entry.
-    var_substitution_map_.Set(old_iter_var->var, new_iter_var->var);
-
-    // Step 5. Mutate recursively, update the AttrStmt with the new IterVar, and decrease the depth
-    // counter if the thread tag starts with "blockIdx".
-    AttrStmt new_attr = Downcast<AttrStmt>(StmtMutator::VisitStmt_(attr));
-    ObjectPtr<AttrStmtNode> p_new_attr = CopyOnWrite(new_attr.get());
-    p_new_attr->node = new_iter_var;
-    if (std::string(thread_tag).substr(0, 9) == "blockIdx.") {
-      --thread_block_depth_;
+    var_substitution_map_.Set(old_var, new_iter_var->var);
+
+    // Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth
+    // counter. Emit for-loops to launch threads if the current statement is the outermost thread
+    // binding of the kernel.
+    Stmt new_stmt = StmtMutator::VisitStmt_(op);
+    auto* new_node = new_stmt.as<Node>();
+    ICHECK(new_node);
+    thread_block_depth_ = old_thread_block_depth;
+    if (is_kernel_launch_scope) {
+      return EmitLaunchThreads(new_node->body);
+    } else {
+      return new_node->body;
     }
-    return Stmt(p_new_attr);
+  }
+
+  /*!
+   * \brief Emit loop nests representing all thread bindings of the kernel
+   * \param body The body of the innermost loop of the thread bindings.
+   * \return The loop nests of the thread bindings.
+   */
+  Stmt EmitLaunchThreads(const Stmt& body) {
+    Stmt result = body;
+    while (!launch_threads_.empty()) {
+      const IterVar& thread_binding = launch_threads_.back();
+      // Recreate the IterVar as we don't duplicate `dom` in both For and IterVar. This is
+      // necessary for unit tests.
+      result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent,
+                   ForKind::kThreadBinding, result,
+                   IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+                           thread_binding->thread_tag));
+      launch_threads_.pop_back();
+    }
+    return result;
   }
 
   PrimExpr VisitExpr_(const VarNode* var) final {
@@ -106,8 +151,13 @@ class ThreadBindingUnifier : public StmtExprMutator {
   /*!
    * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all
    * occurrences of the thread tag
-   * */
+   */
   Map<String, IterVar> thread_tag2iter_var_map_;
+  /*!
+   * \brief A list of IterVar corresponding to threads in the current kernel. This will be used to
+   * generate for-loops to launch threads.
+   */
+  Array<IterVar> launch_threads_;
   /*! \brief A mapping from old variables to new variables, which is used for substitution */
   Map<Var, Var> var_substitution_map_;
   /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index 1ce9b0cacd29..6880aabcd2f7 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import sys
+
 import tvm
 from tvm import te
 from tvm.script import tir as T
@@ -35,6 +37,42 @@ def _check_fail(original):
 
 @T.prim_func
 def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i in T.thread_binding(0, 128, "blockIdx.x"):
+        for j0_0 in T.thread_binding(0, 4, "threadIdx.x"):
+            for j0_1 in T.serial(0, 32):
+                with T.block(""):
+                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
+        for j1_0 in T.thread_binding(0, 4, "threadIdx.x"):
+            for j1_1 in T.serial(0, 32):
+                with T.block(""):
+                    C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0
+
+
+@T.prim_func
+def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+
+    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
+        for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"):
+            for j0_1 in T.serial(0, 32):
+                with T.block(""):
+                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = (
+                        A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0
+                    )
+            for j1_1 in T.serial(0, 32):
+                with T.block(""):
+                    C[blockIdx_x, threadIdx_x * 32 + j1_1] = (
+                        B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0
+                    )
+
+
+@T.prim_func
+def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
     j1_0 = T.env_thread("threadIdx.x")
     j0_0 = T.env_thread("threadIdx.x")
     i = T.env_thread("blockIdx.x")
@@ -42,158 +80,152 @@ def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
     T.launch_thread(i, 128)
-    with T.launch_thread(j0_0, 4):
-        for j0_1 in T.serial(0, 32):
-            T.store(
-                B.data,
-                i * 128 + j0_0 * 32 + j0_1,
-                T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0,
-                True,
-            )
+    T.launch_thread(j0_0, 4)
     T.launch_thread(j1_0, 4)
+
+    for j0_1 in T.serial(0, 32):
+        with T.block(""):
+            B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
     for j1_1 in T.serial(0, 32):
-        T.store(
-            C.data,
-            i * 128 + j1_0 * 32 + j1_1,
-            T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0,
-            True,
-        )
+        with T.block(""):
+            C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0
 
 
 @T.prim_func
-def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
-    thread_x = T.env_thread("threadIdx.x")
-    block_x = T.env_thread("blockIdx.x")
+def unified_element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
-    T.launch_thread(block_x, 128)
-    with T.launch_thread(thread_x, 4):
-        for j0_1 in T.serial(0, 32):
-            T.store(
-                B.data,
-                block_x * 128 + thread_x * 32 + j0_1,
-                T.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0,
-                True,
-            )
-    T.launch_thread(thread_x, 4)
-    for j1_1 in T.serial(0, 32):
-        T.store(
-            C.data,
-            block_x * 128 + thread_x * 32 + j1_1,
-            T.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0,
-            True,
-        )
+
+    for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"):
+        for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"):
+            for j0_1 in T.serial(0, 32):
+                with T.block(""):
+                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = (
+                        A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0
+                    )
+            for j1_1 in T.serial(0, 32):
+                with T.block(""):
+                    C[blockIdx_x, threadIdx_x * 32 + j1_1] = (
+                        B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0
+                    )
 
 
 @T.prim_func
 def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
-    i_0 = T.env_thread("vthread.x")
-    i_1 = T.env_thread("threadIdx.x")
-    j_0 = T.env_thread("vthread.x")
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
-    T.launch_thread(i_0, 2)
-    T.launch_thread(i_1, 64)
-    T.launch_thread(j_0, 2)
-    for j_1 in T.serial(0, 64):
-        T.store(
-            B.data,
-            i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1,
-            T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0,
-            True,
-        )
+    for i_0 in T.thread_binding(0, 2, "vthread.x"):
+        for i_1 in T.thread_binding(0, 64, "threadIdx.x"):
+            for j_0 in T.thread_binding(0, 2, "vthread.x"):
+                for j_1 in T.serial(0, 64):
+                    with T.block(""):
+                        B[i_0 * 64 + i_1, j_0 * 64 + j_1] = A[i_0 * 64 + i_1, j_0 * 64 + j_1] * 2.0
 
 
 @T.prim_func
 def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
-    vthread_x = T.env_thread("vthread.x")
-    thread_x = T.env_thread("threadIdx.x")
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
-    T.launch_thread(vthread_x, 2)
-    T.launch_thread(thread_x, 64)
-    T.launch_thread(vthread_x, 2)
-    for j_1 in T.serial(0, 64):
-        T.store(
-            B.data,
-            vthread_x * 8256 + thread_x * 128 + j_1,
-            T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0,
-            True,
-        )
+    for vthread_x in T.thread_binding(0, 2, "vthread.x"):
+        for threadIdx_x in T.thread_binding(0, 64, "threadIdx.x"):
+            for j_1 in T.serial(0, 64):
+                with T.block(""):
+                    B[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] = (
+                        A[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] * 2.0
+                    )
 
 
 @T.prim_func
 def element_wise_two_thread_x_in_same_kernel_not_equal(
     a: T.handle, b: T.handle, c: T.handle
 ) -> None:
-    i = T.env_thread("blockIdx.x")
-    j0 = T.env_thread("threadIdx.x")
-    j1 = T.env_thread("threadIdx.x")
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 64])
-    T.launch_thread(i, 128)
-    with T.launch_thread(j0, 128):
-        T.store(B.data, i * 64 + j0, T.load("float32", A.data, i * 128 + j0) * 2.0, True)
-    T.launch_thread(j1, 64)
-    T.store(C.data, i * 64 + j1, T.load("float32", A.data, i * 128 + j1) + 1.0, True)
+    for i in T.thread_binding(0, 128, "blockIdx.x"):
"blockIdx.x"): + for j0 in T.thread_binding(0, 128, "threadIdx.x"): + B[i, j0] = A[i, j0] * 2.0 + for j1 in T.thread_binding(0, 64, "threadIdx.x"): + C[i, j1] = A[i, j1] + 1.0 @T.prim_func def element_wise_kernels_with_different_size( a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - i0 = T.env_thread("blockIdx.x") - j0 = T.env_thread("threadIdx.x") - i1 = T.env_thread("blockIdx.x") - j1 = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [256, 256]) D = T.match_buffer(d, [256, 256]) - with T.launch_thread(i0, 128): - T.launch_thread(j0, 128) - T.store(B.data, i0 * 128 + j0, T.load("float32", A.data, i0 * 128 + j0) * 2.0, True) - T.launch_thread(i1, 256) - T.launch_thread(j1, 256) - T.store(D.data, i1 * 256 + j1, T.load("float32", C.data, i1 * 256 + j1) + 1.0, True) + for i0 in T.thread_binding(0, 128, "blockIdx.x"): + for j0 in T.thread_binding(0, 128, "threadIdx.x"): + B[i0, j0] = A[i0, j0] * 2.0 + for i1 in T.thread_binding(0, 256, "blockIdx.x"): + for j1 in T.thread_binding(0, 256, "threadIdx.x"): + D[i1, j1] = C[i1, j1] + 1.0 @T.prim_func def unified_element_wise_kernels_with_different_size( a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - block_x = T.env_thread("blockIdx.x") - thread_x = T.env_thread("threadIdx.x") - block_x_1 = T.env_thread("blockIdx.x") - thread_x_1 = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [256, 256]) D = T.match_buffer(d, [256, 256]) - with T.launch_thread(block_x, 128): - T.launch_thread(thread_x, 128) - T.store( - B.data, - block_x * 128 + thread_x, - T.load("float32", A.data, block_x * 128 + thread_x) * 2.0, - True, - ) - T.launch_thread(block_x_1, 256) - T.launch_thread(thread_x_1, 256) - T.store( - D.data, - block_x_1 * 256 + thread_x_1, - T.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, - True, - ) + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 128, "threadIdx.x"): + B[blockIdx_x, threadIdx_x] = A[blockIdx_x, threadIdx_x] * 2.0 + for blockIdx_x in T.thread_binding(0, 256, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 256, "threadIdx.x"): + D[blockIdx_x, threadIdx_x] = C[blockIdx_x, threadIdx_x] + 1.0 + + +@T.prim_func +def element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i in T.thread_binding(0, 128, "threadIdx.y"): + for j0_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 + for j1_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j1_1 in T.serial(0, 32): + with T.block(""): + C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 + + +@T.prim_func +def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for blockIdx_x in T.thread_binding(0, 128, "threadIdx.y"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) def test_thread_x(): 
     _check(element_wise_thread_x, unified_element_wise_thread_x)
 
 
+def test_env_thread_x():
+    _check(element_wise_env_thread_x, unified_element_wise_env_thread_x)
+
+
 def test_vthread_x():
     _check(element_wise_vthread_x, unified_element_wise_vthread_x)
 
@@ -208,6 +240,10 @@ def test_kernels_with_different_size():
     )
 
 
+def test_implicit_block():
+    _check(element_wise_implicit_block, unified_element_wise_implicit_block)
+
+
 def test_lower_te():
     a = te.placeholder((32, 2, 2))
     b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)
@@ -220,8 +256,4 @@ def test_lower_te():
 
 
 if __name__ == "__main__":
-    test_thread_x()
-    test_vthread_x()
-    test_two_thread_x_in_same_kernel_not_equal()
-    test_kernels_with_different_size()
-    test_lower_te()
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
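
For reference, a minimal standalone sketch of running the pass touched by this patch on a single PrimFunc, in the same style as the `_check` helper in the test file above. The `before` function, its shapes, and the final print are illustrative assumptions, not part of the patch:

# Sketch (assumption: a TVM build where tir.transform.UnifyThreadBinding is registered,
# as in the patched pass list in driver_api.cc). Mirrors the test file's `_check` helper.
import tvm
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    # Two separate loops bound to "threadIdx.x" with equal extents; the pass is expected
    # to rebind both of them to one unified thread IterVar.
    for i in T.thread_binding(0, 128, "blockIdx.x"):
        for j0 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i, j0] = A[i, j0] * 2.0
        for j1 in T.thread_binding(0, 128, "threadIdx.x"):
            B[i, j1] = B[i, j1] + 1.0


# Wrap the PrimFunc in an IRModule and apply the pass, as the unit tests do.
mod = tvm.IRModule.from_expr(before)
mod = tvm.tir.transform.UnifyThreadBinding()(mod)
print(mod)  # the two threadIdx.x loops now share a single thread binding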