Skip to content

Commit

Permalink
[TE] Support mixing normal and cross-thread reduction (#5193)
Browse files Browse the repository at this point in the history
* Support mixing normal and cross-thread reduction

* minor improvements
  • Loading branch information
roastduck authored Apr 4, 2020
1 parent 75e936e commit b41f4e5
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 5 deletions.
2 changes: 0 additions & 2 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,6 @@ ComputeType DetectComputeType(const ComputeOpNode* self,
<< "Cannot mix cross thread reduction with Tensorize";
return ComputeType::kTensorize;
}
CHECK(normal_red == 0 || thread_red == 0)
<< "Cannot mix normal reduction with thread reduce";
if (thread_red != 0) {
return ComputeType::kCrossThreadReduction;
} else {
Expand Down
72 changes: 70 additions & 2 deletions src/te/operation/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,63 @@ Stmt MakeCrossThreadReduction(
for (PrimExpr v : conds) {
cond = cond && v;
}

std::vector<std::vector<Stmt>> common, normal_red;
for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) {
IterVar iv = stage->leaf_iter_vars[i];
IterVarAttr attr;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
attr = (*it).second;
}
if (iv->iter_type == kCommReduce) {
if (attr.defined() && attr->bind_thread.defined()) {
common.emplace_back(nest[i + 1]);
} else {
normal_red.emplace_back(nest[i + 1]);
}
} else {
common.emplace_back(nest[i + 1]);
}
}

// If we load from and then store into the same res_handles in the thread_allreduce intrinsic,
// something goes wrong, so we use an extra variable here for normal reduction.
std::vector<Var> normal_res_handles;
std::vector<Stmt> normal_init, normal_update;
if (!normal_red.empty()) {
normal_res_handles.reserve(size);
normal_init.reserve(size);
normal_update.resize(size);
const CommReducerNode* combiner = reduces[0]->combiner.as<CommReducerNode>();
CHECK(combiner);
Array<PrimExpr> lhs;
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes())));
}
Array<PrimExpr> init_value = combiner->identity_element;
Array<PrimExpr> update_value = (*combiner)(lhs, reduces[0]->source);
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
normal_init.emplace_back(StoreNode::make(
normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
normal_update.emplace_back(StoreNode::make(
normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
}
}

Array<PrimExpr> freduce_args;
freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
if (!normal_red.empty()) {
DataType t = reduces[i]->dtype;
freduce_args.push_back(LoadNode::make(
t, normal_res_handles[i], 0, const_true(t.lanes())));
} else {
freduce_args.push_back(reduces[0]->source[i]);
}
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
Expand Down Expand Up @@ -94,6 +147,15 @@ Stmt MakeCrossThreadReduction(
tir::attr::reduce_scope,
make_zero(DataType::Handle()),
reduce_body);

if (!normal_red.empty()) {
Stmt init_body = SeqStmt::Flatten(normal_init);
Stmt update_body = SeqStmt::Flatten(normal_update);
update_body = MergeNest(normal_red, update_body);
reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body);
reduce_body = MergeNest(MakeIfNest(conds), reduce_body);
}

std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
DataType t = reduces[idx]->dtype;
Expand All @@ -110,9 +172,15 @@ Stmt MakeCrossThreadReduction(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmtNode::make(
res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
if (!normal_red.empty()) {
body = AllocateNode::make(
normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmtNode::make(
normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
}
}
body = Substitute(body, value_map);
return MergeNest(nest, body);
return MergeNest(common, body);
}
} // namespace te
} // namespace tvm
30 changes: 29 additions & 1 deletion tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,33 @@ def check_cuda(dtype, m=32, n=32):
check_cuda("float32")
check_cuda("float16")

def test_cuda_mix_threaded_and_normal_reduction():
def check_cuda(dtype, m=32, n=32):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return

a = tvm.te.placeholder((m, n), name="a", dtype=dtype)
b = topi.sum(a)
with tvm.target.cuda():
sb = tvm.te.create_schedule(b.op)
i, _ = b.op.reduce_axis
sb[b].bind(i, tvm.te.thread_axis("threadIdx.x"))
ctx = tvm.gpu(0)
func = tvm.build(sb, [a, b], 'cuda')
a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
b_np = np.sum(a_np)
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

check_cuda("float32")
check_cuda("float16")

def test_cuda_floordiv_with_vectorization():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
Expand Down Expand Up @@ -528,7 +555,8 @@ def run_test(dtype):
test_rfactor_predicates()
test_cuda_const_float_to_half()
test_cuda_reduction()
test_cuda_mix_threaded_and_normal_reduction()
test_cuda_floordiv_with_vectorization()
test_vectorized_intrin1()
test_vectorized_intrin2()
test_vectorized_popcount()
test_vectorized_popcount()

0 comments on commit b41f4e5

Please sign in to comment.