diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 6123c613d0bd..6f703c9ec4e3 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -443,8 +443,6 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
         << "Cannot mix cross thread reduction with Tensorize";
     return ComputeType::kTensorize;
   }
-  CHECK(normal_red == 0 || thread_red == 0)
-      << "Cannot mix normal reduction with thread reduce";
   if (thread_red != 0) {
     return ComputeType::kCrossThreadReduction;
   } else {
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index 705d2317940c..1b3d87d57006 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -57,10 +57,63 @@ Stmt MakeCrossThreadReduction(
   for (PrimExpr v : conds) {
     cond = cond && v;
   }
+
+  std::vector<std::vector<Stmt>> common, normal_red;
+  for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    IterVarAttr attr;
+    auto it = stage->iter_var_attrs.find(iv);
+    if (it != stage->iter_var_attrs.end()) {
+      attr = (*it).second;
+    }
+    if (iv->iter_type == kCommReduce) {
+      if (attr.defined() && attr->bind_thread.defined()) {
+        common.emplace_back(nest[i + 1]);
+      } else {
+        normal_red.emplace_back(nest[i + 1]);
+      }
+    } else {
+      common.emplace_back(nest[i + 1]);
+    }
+  }
+
+  // If we load from and then store into the same res_handles in the thread_allreduce intrinsic,
+  // something goes wrong, so we use an extra variable here for normal reduction.
+  std::vector<Var> normal_res_handles;
+  std::vector<Stmt> normal_init, normal_update;
+  if (!normal_red.empty()) {
+    normal_res_handles.reserve(size);
+    normal_init.reserve(size);
+    normal_update.reserve(size);
+    const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
+    CHECK(combiner);
+    Array<PrimExpr> lhs;
+    for (size_t i = 0; i < size; ++i) {
+      DataType t = reduces[i]->dtype;
+      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
+      lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
+    }
+    Array<PrimExpr> init_value = combiner->identity_element;
+    Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
+    for (size_t i = 0; i < size; ++i) {
+      DataType t = reduces[i]->dtype;
+      normal_init.emplace_back(StoreNode::make(
+          normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+      normal_update.emplace_back(StoreNode::make(
+          normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+    }
+  }
+
   Array<PrimExpr> freduce_args;
   freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
   for (size_t i = 0; i < size; ++i) {
-    freduce_args.push_back(reduces[0]->source[i]);
+    if (!normal_red.empty()) {
+      DataType t = reduces[i]->dtype;
+      freduce_args.push_back(LoadNode::make(
+          t, normal_res_handles[i], 0, const_true(t.lanes())));
+    } else {
+      freduce_args.push_back(reduces[0]->source[i]);
+    }
   }
   freduce_args.push_back(cond);
   std::vector<Var> res_handles(size);
@@ -94,6 +147,15 @@ Stmt MakeCrossThreadReduction(
       tir::attr::reduce_scope,
       make_zero(DataType::Handle()),
       reduce_body);
+
+  if (!normal_red.empty()) {
+    Stmt init_body = SeqStmt::Flatten(normal_init);
+    Stmt update_body = SeqStmt::Flatten(normal_update);
+    update_body = MergeNest(normal_red, update_body);
+    reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
+    reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
+  }
+
   std::vector<Stmt> assigns(size);
   for (size_t idx = 0; idx < size; ++idx) {
     DataType t = reduces[idx]->dtype;
@@ -110,9 +172,15 @@ Stmt MakeCrossThreadReduction(
         res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
     body = AttrStmtNode::make(
         res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+    if (!normal_red.empty()) {
+      body = AllocateNode::make(
+          normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+      body = AttrStmtNode::make(
+          normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
+    }
   }
   body = Substitute(body, value_map);
-  return MergeNest(nest, body);
+  return MergeNest(common, body);
 }
 }  // namespace te
 }  // namespace tvm
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index e8c6cd1925a8..453bd4335a39 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -321,6 +321,33 @@ def check_cuda(dtype, m=32, n=32):
     check_cuda("float32")
     check_cuda("float16")
 
+def test_cuda_mix_threaded_and_normal_reduction():
+    def check_cuda(dtype, m=32, n=32):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
+        b = topi.sum(a)
+        with tvm.target.cuda():
+            sb = tvm.te.create_schedule(b.op)
+            i, _ = b.op.reduce_axis
+            sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
+            ctx = tvm.gpu(0)
+            func = tvm.build(sb, [a, b], 'cuda')
+            a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
+            b_np = np.sum(a_np)
+            a_nd = tvm.nd.array(a_np, ctx)
+            b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
+            func(a_nd, b_nd)
+            tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 def test_cuda_floordiv_with_vectorization():
     if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
         print("skip because cuda is not enabled..")
@@ -478,7 +505,8 @@ def run_test(dtype):
     test_rfactor_predicates()
     test_cuda_const_float_to_half()
     test_cuda_reduction()
+    test_cuda_mix_threaded_and_normal_reduction()
     test_cuda_floordiv_with_vectorization()
     test_vectorized_intrin1()
     test_vectorized_intrin2()
-    test_vectorized_popcount()
\ No newline at end of file
+    test_vectorized_popcount()
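
Note on the scheduling pattern this patch enables (an illustrative sketch only, not part of the diff; it assumes the 0.7-era tvm.te API already used by the new unit test, and the names A, B, ko, ki are chosen here for illustration). Before this change, DetectComputeType rejected any stage whose reduce axes were only partially bound to thread axes; with the CHECK removed and MakeCrossThreadReduction splitting the leaf iter vars into common (thread-bound and non-reduce) and normal_red (unbound reduce) nests, a schedule like the one below lowers to a per-thread serial accumulation into the normal_reduce_temp* locals followed by a tvm_thread_allreduce:

    import numpy as np
    import tvm
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A", dtype="float32")
    k = te.reduce_axis((0, n), name="k")
    B = te.compute((1,), lambda _: te.sum(A[k], axis=k), name="B")

    s = te.create_schedule(B.op)
    # Split the reduce axis and bind only the outer part to threadIdx.x;
    # the inner part stays an unbound ("normal") serial reduction per thread.
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
    s[B].bind(ko, te.thread_axis("threadIdx.x"))

    # Lowering this used to fail with "Cannot mix normal reduction with thread reduce".
    print(tvm.lower(s, [A, B], simple_mode=True))
    # On a CUDA-capable machine: func = tvm.build(s, [A, B], "cuda")

The new unit test exercises the same path without a split, by binding only one of the two reduce axes of a 2-D sum and leaving the other as a normal reduction.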