Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new warp shuffle intrinsics after CUDA Volta architecture #6505

Merged
merged 6 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/LowerWarpShuffles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,22 @@ class LowerWarpShuffles : public IRMutator {

internal_assert(may_use_warp_shuffle) << name << ", " << idx << ", " << lane << "\n";

// Reference: https://docs.nvidia.com/cuda/volta-tuning-guide/index.html
// We must add .sync after volta architecture.
string sync_suffix = "";
Target t = get_jit_target_from_environment();
abadams marked this conversation as resolved.
Show resolved Hide resolved
int cap = t.get_cuda_capability_lower_bound();
if (cap >= 70) {
sync_suffix = ".sync";
}

auto shfl_args = [&](const std::vector<Expr> &args) {
if (cap >= 70) {
return args;
}
return std::vector({args[1], args[2], args[3]});
};

string intrin_suffix;
if (shuffle_type.is_float()) {
intrin_suffix = ".f32";
Expand All @@ -578,12 +594,12 @@ class LowerWarpShuffles : public IRMutator {
lane = solve_expression(lane, this_lane_name).result;

Expr shuffled;

Expr membermask = (int)0xffffffff;
if (expr_match(this_lane + wild, lane, result)) {
// We know that 0 <= lane + wild < warp_size by how we
// constructed it, so we can just do a shuffle down.
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], 31}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], 31}), Call::PureExtern);
shuffled = down;
} else if (expr_match((this_lane + wild) % wild, lane, result) &&
is_const_power_of_two_integer(result[1], &bits) &&
Expand All @@ -593,10 +609,10 @@ class LowerWarpShuffles : public IRMutator {
// intermediate registers than using a general gather for
// this.
Expr mask = (1 << bits) - 1;
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], mask}, Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl.up" + intrin_suffix,
{base_val, (1 << bits) - result[0], 0}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], mask}), Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".up" + intrin_suffix,
shfl_args({membermask, base_val, (1 << bits) - result[0], 0}), Call::PureExtern);
Expr cond = (this_lane >= (1 << bits) - result[0]);
Expr equiv = select(cond, up, down);
shuffled = simplify(equiv, true, bounds);
Expand All @@ -609,8 +625,8 @@ class LowerWarpShuffles : public IRMutator {
// could hypothetically be used for boundary conditions.
Expr mask = simplify(((31 & ~(warp_size - 1)) << 8) | 31);
// The idx variant can do a general gather. Use it for all other cases.
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl.idx" + intrin_suffix,
{base_val, lane, mask}, Call::PureExtern);
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".idx" + intrin_suffix,
shfl_args({membermask, base_val, lane, mask}), Call::PureExtern);
}
// TODO: There are other forms, like butterfly and clamp, that
// don't need to use the general gather
Expand Down
53 changes: 49 additions & 4 deletions test/correctness/register_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();

int cap = t.get_cuda_capability_lower_bound();
if (cap < 50 || cap >= 80) {
printf("[SKIP] CUDA with capability between 5.0 and 7.5 required\n");
// TODO: Use the shfl.sync intrinsics for cuda 8.0 and above
// See issue #5630
if (cap < 50) {
printf("[SKIP] CUDA with capability greater than or equal to 5.0 required, cap:%d\n", cap);
return 0;
}

Expand Down Expand Up @@ -185,6 +183,53 @@ int main(int argc, char **argv) {
}
}

{
// Using warp shuffle to do the reduction.
Func a, b, c;
Var x, y, yo, yi, ylane, u;
RVar ro, ri;

a(x, y) = x + y;
a.compute_root();

RDom r(0, 1024);
b(y) = 0;
b(y) += a(r, y);
c(y) = b(y);

int warp = 8;
c
.split(y, yo, yi, 1 * warp)
.split(yi, yi, ylane, 1)
.gpu_blocks(yo)
.gpu_threads(yi, ylane);
Func intm = b.update()
.split(r, ri, ro, warp)
.reorder(ri, ro)
.rfactor(ro, u);
intm
.compute_at(c, yi)
.update()
.gpu_lanes(u);
intm
.gpu_lanes(u);

Buffer<int> out = c.realize({256});
for (int y = 0; y < out.width(); y++) {
int correct = 0;
for (int x = 0; x < 1024; x++) {
correct += x + y;
}
int actual = out(y);
if (correct != actual) {
printf("out(%d) = %d instead of %d\n",
y, actual, correct);
return -1;
}
}
}


{
// Same as above, but in half-warps
abadams marked this conversation as resolved.
Show resolved Hide resolved
Func a, b, c, d;
Expand Down