Skip to content

Commit

Permalink
ROCm: Add warp shuffles and enable reductions (apache#5727)
Browse files Browse the repository at this point in the history
Thank you @masahi and @wpan11nv for the feedback
  • Loading branch information
t-vi authored and trevor-m committed Jun 18, 2020
1 parent 98147a7 commit ebfbfe2
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 139 deletions.
52 changes: 52 additions & 0 deletions src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <sstream>

Expand All @@ -40,8 +41,59 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
*rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern);
}

inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr e_call = targs[0];
using namespace tir;
const CallNode* call = e_call.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
PrimExpr var = call->args[1];
CHECK_EQ(var.dtype().bits(), 32);

// get own lane in self (__lane_id)
PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
PrimExpr zero = tir::make_zero(DataType::Int(32));
PrimExpr lo = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero},
CallNode::PureExtern);
PrimExpr self = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo},
CallNode::PureExtern);

// compute lane to get from
PrimExpr width = call->args[3];
PrimExpr index;
if (call->name == "tvm_warp_shuffle") {
PrimExpr src_lane = call->args[2];
index = src_lane + (self & ~(width - 1));
} else if (call->name == "tvm_warp_shuffle_up") {
PrimExpr delta = call->args[2];
index = self - delta;
index = SelectNode::make(index < (self & ~(width - 1)), self, index);
} else {
CHECK_EQ(call->name, "tvm_warp_shuffle_down");
PrimExpr delta = call->args[2];
index = self + delta;
index = SelectNode::make((self & (width - 1)) + delta >= width, self, index);
}
PrimExpr res = CallNode::make(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var},
CallNode::PureExtern);
*rv = res;
}

namespace llvm {

// dummy because we don't have the activemask
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_activemask")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
PrimExpr zero = tir::make_zero(DataType::Int(32));
*rv = zero;
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle").set_body(DispatchShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(DispatchShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML);
Expand Down
3 changes: 2 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ Target CreateTarget(const std::string& target_name, const std::vector<std::strin
// For now assume rocm schedule for opencl
if (target_name == "opencl") {
t->device_type = kDLOpenCL;
} else {
} else { // rocm
t->device_type = kDLROCM;
t->thread_warp_size = 64;
}
t->keys_array.push_back(target_name);
t->keys_array.push_back("gpu");
Expand Down
17 changes: 14 additions & 3 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
//
// Allocate reduction vars v[i], i = 0..size-1
//
// for offset from 16 to 1 by 2
// for offset from WARP_SIZE to 1 by 2
//
// a <- load(v[i])
// b <- shuffle_down(load(v[i], offset))
Expand Down Expand Up @@ -244,7 +244,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}

// Emit reductions within a warp.
for (int offset = 16; offset > 0; offset /= 2) {
for (int offset = warp_size_ / 2; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
Expand Down Expand Up @@ -478,9 +478,20 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// the warp size.
//
// TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
// Note: The ROCm backend will only have warp reductions for now.
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
bool is_warp_reduction(const std::vector<DataType>& types) const {
// Only cuda target supports warp reductions.
if (target_->target_name != "cuda") return false;
if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false;

// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->target_name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_vector()) return true;
return ty.bits() != 32;
}))) {
return false;
}

// Supported types:
// {u}int, {u}long, {u}long long, float, double, half/half2
Expand Down
12 changes: 9 additions & 3 deletions tests/python/integration/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def check_device(device, host="llvm"):
check_device("vulkan")
check_device("cuda")
check_device("opencl")
check_device("rocm")
test_prim(te.sum, np.sum)
test_prim(tvm.te.min, np.amin)
test_prim(tvm.te.max, np.amax)
Expand Down Expand Up @@ -179,7 +180,7 @@ def check_target(device, host="stackvm"):
check_target("cuda")
check_target("metal")
check_target("opencl")

check_target("rocm")

def test_rfactor_elemwise_threads():
n = 1025
Expand Down Expand Up @@ -230,6 +231,7 @@ def check_target(device, host="stackvm"):
check_target("cuda")
check_target("metal")
check_target("opencl")
check_target("rocm")

def test_argmax():
def fcombine(x, y):
Expand Down Expand Up @@ -337,6 +339,7 @@ def check_target(device):

check_target("cuda")
check_target("vulkan")
check_target("rocm")

def test_warp_reduction1():
nthx = 32
Expand Down Expand Up @@ -365,10 +368,10 @@ def check_target(device, m, n):
s[B].bind(xi, thread_y)
s[B].bind(xo, block_x)

print(tvm.lower(s, [A, B], simple_mode=True))
tvm.lower(s, [A, B], simple_mode=True)

# validation
func = tvm.build(s, [A, B], "cuda", name="warp_reduction")
func = tvm.build(s, [A, B], device, name="warp_reduction")
a_np = np.random.uniform(size=(m,n)).astype(A.dtype)
b_np = np.zeros((m,), dtype=A.dtype)
a = tvm.nd.array(a_np, ctx)
Expand All @@ -379,6 +382,8 @@ def check_target(device, m, n):

check_target("cuda", m=32, n=256)
check_target("cuda", m=10, n=20)
check_target("rocm", m=32, n=256)
check_target("rocm", m=10, n=20)
# This is a bug in normal reduction.
# check_target("cuda", m=10, n=37)

Expand Down Expand Up @@ -437,6 +442,7 @@ def check_target(device):
tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3)

check_target("cuda")
check_target("rocm")

if __name__ == "__main__":
test_rfactor_elemwise_threads()
Expand Down
Loading

0 comments on commit ebfbfe2

Please sign in to comment.