Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new warp shuffle intrinsics after CUDA Volta architecture #6505

Merged
merged 6 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ void lower_impl(const vector<Function> &output_funcs,

if (t.has_feature(Target::CUDA)) {
debug(1) << "Injecting warp shuffles...\n";
s = lower_warp_shuffles(s);
s = lower_warp_shuffles(s, t);
log("Lowering after injecting warp shuffles:", s);
}

Expand Down
50 changes: 37 additions & 13 deletions src/LowerWarpShuffles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ class LowerWarpShuffles : public IRMutator {
};
Scope<AllocInfo> allocation_info;
Scope<Interval> bounds;
int cuda_cap;

Stmt visit(const For *op) override {
ScopedBinding<Interval>
Expand Down Expand Up @@ -562,6 +563,20 @@ class LowerWarpShuffles : public IRMutator {

internal_assert(may_use_warp_shuffle) << name << ", " << idx << ", " << lane << "\n";

// Volta (compute capability 7.0) and newer require the .sync shuffle
// variants; the legacy shfl.* intrinsics are deprecated there:
// https://docs.nvidia.com/cuda/volta-tuning-guide/index.html
string sync_suffix = "";
if (cuda_cap >= 70) {
    sync_suffix = ".sync";
}

// The .sync intrinsics take a leading member-mask argument that the
// legacy intrinsics lack, so on pre-Volta targets drop args[0] (the
// mask) and pass only {value, src, mask-and-clamp}.
auto shfl_args = [&](const std::vector<Expr> &args) {
    if (cuda_cap >= 70) {
        return args;
    }
    return std::vector({args[1], args[2], args[3]});
};

string intrin_suffix;
if (shuffle_type.is_float()) {
intrin_suffix = ".f32";
Expand All @@ -578,12 +593,12 @@ class LowerWarpShuffles : public IRMutator {
lane = solve_expression(lane, this_lane_name).result;

Expr shuffled;

Expr membermask = (int)0xffffffff;
if (expr_match(this_lane + wild, lane, result)) {
// We know that 0 <= lane + wild < warp_size by how we
// constructed it, so we can just do a shuffle down.
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], 31}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], 31}), Call::PureExtern);
shuffled = down;
} else if (expr_match((this_lane + wild) % wild, lane, result) &&
is_const_power_of_two_integer(result[1], &bits) &&
Expand All @@ -593,10 +608,10 @@ class LowerWarpShuffles : public IRMutator {
// intermediate registers than using a general gather for
// this.
Expr mask = (1 << bits) - 1;
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], mask}, Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl.up" + intrin_suffix,
{base_val, (1 << bits) - result[0], 0}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], mask}), Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".up" + intrin_suffix,
shfl_args({membermask, base_val, (1 << bits) - result[0], 0}), Call::PureExtern);
Expr cond = (this_lane >= (1 << bits) - result[0]);
Expr equiv = select(cond, up, down);
shuffled = simplify(equiv, true, bounds);
Expand All @@ -609,8 +624,8 @@ class LowerWarpShuffles : public IRMutator {
// could hypothetically be used for boundary conditions.
Expr mask = simplify(((31 & ~(warp_size - 1)) << 8) | 31);
// The idx variant can do a general gather. Use it for all other cases.
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl.idx" + intrin_suffix,
{base_val, lane, mask}, Call::PureExtern);
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".idx" + intrin_suffix,
shfl_args({membermask, base_val, lane, mask}), Call::PureExtern);
}
// TODO: There are other forms, like butterfly and clamp, that
// don't need to use the general gather
Expand Down Expand Up @@ -651,7 +666,9 @@ class LowerWarpShuffles : public IRMutator {
}

public:
LowerWarpShuffles() = default;
LowerWarpShuffles(int cuda_cap)
: cuda_cap(cuda_cap) {
}
};

class HoistWarpShufflesFromSingleIfStmt : public IRMutator {
Expand Down Expand Up @@ -804,22 +821,29 @@ class LowerWarpShufflesInEachKernel : public IRMutator {
// Lower warp shuffles inside each CUDA kernel that contains a loop
// over GPU lanes; all other loops are visited unchanged.
Stmt visit(const For *op) override {
    if (op->device_api == DeviceAPI::CUDA && has_lane_loop(op)) {
        Stmt s = op;
        // cuda_cap selects between the legacy shfl.* intrinsics and
        // the shfl.sync.* variants required on sm_70 and newer.
        s = LowerWarpShuffles(cuda_cap).mutate(s);
        s = HoistWarpShuffles().mutate(s);
        return simplify(s);
    } else {
        return IRMutator::visit(op);
    }
}

int cuda_cap;

public:
LowerWarpShufflesInEachKernel(int cuda_cap)
: cuda_cap(cuda_cap) {
}
};

} // namespace

// Rewrite accesses to values stored outside the loop over GPU lanes
// to use nvidia's warp shuffle intrinsics. The Target is needed to
// choose the intrinsic family: sm_70+ requires the .sync variants.
Stmt lower_warp_shuffles(Stmt s, const Target &t) {
    // Hoist loop-invariant values first so that more loads become
    // shuffle candidates, then canonicalize the lane variable.
    s = hoist_loop_invariant_values(s);
    s = SubstituteInLaneVar().mutate(s);
    s = simplify(s);
    s = LowerWarpShufflesInEachKernel(t.get_cuda_capability_lower_bound()).mutate(s);
    return s;
}

Expand Down
2 changes: 1 addition & 1 deletion src/LowerWarpShuffles.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Internal {

/** Rewrite access to things stored outside the loop over GPU lanes to
* use nvidia's warp shuffle instructions. */
Stmt lower_warp_shuffles(Stmt s);
Stmt lower_warp_shuffles(Stmt s, const Target &t);

} // namespace Internal
} // namespace Halide
Expand Down
52 changes: 48 additions & 4 deletions test/correctness/register_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();

int cap = t.get_cuda_capability_lower_bound();
if (cap < 50 || cap >= 80) {
printf("[SKIP] CUDA with capability between 5.0 and 7.5 required\n");
// TODO: Use the shfl.sync intrinsics for cuda 8.0 and above
// See issue #5630
if (cap < 50) {
printf("[SKIP] CUDA with capability greater than or equal to 5.0 required, cap:%d\n", cap);
return 0;
}

Expand Down Expand Up @@ -495,6 +493,52 @@ int main(int argc, char **argv) {
}
}

{
    // Use warp shuffle to do the reduction.
    // Reduces a(x, y) over x with the inner part of the reduction
    // mapped onto GPU lanes via rfactor + gpu_lanes, exercising the
    // shuffle intrinsics on post-Volta targets.
    Func a, b, c;
    Var x, y, yo, yi, ylane, u;
    RVar ro, ri;

    // Input: a(x, y) = x + y, materialized before the reduction.
    a(x, y) = x + y;
    a.compute_root();

    // b(y) is a full sum of a(x, y) over x in [0, 1024).
    RDom r(0, 1024);
    b(y) = 0;
    b(y) += a(r, y);
    c(y) = b(y);

    int warp = 8;
    // Map y onto GPU blocks/threads with an 8-wide lane dimension.
    c
        .split(y, yo, yi, 1 * warp)
        .split(yi, yi, ylane, 1)
        .gpu_blocks(yo)
        .gpu_threads(yi, ylane);
    // rfactor the reduction: each of the `warp` lanes accumulates a
    // partial sum (intm), combined afterwards across lanes.
    Func intm = b.update()
                    .split(r, ri, ro, warp)
                    .reorder(ri, ro)
                    .rfactor(ro, u);
    intm
        .compute_at(c, yi)
        .update()
        .gpu_lanes(u);
    intm
        .gpu_lanes(u);

    // Verify against a serial reference sum of (x + y).
    Buffer<int> out = c.realize({256});
    for (int y = 0; y < out.width(); y++) {
        int correct = 0;
        for (int x = 0; x < 1024; x++) {
            correct += x + y;
        }
        int actual = out(y);
        if (correct != actual) {
            printf("out(%d) = %d instead of %d\n",
                   y, actual, correct);
            return -1;
        }
    }
}

{
// Test a case that caused combinatorial explosion
Var x;
Expand Down