Commit

Support new warp shuffle intrinsics after CUDA Volta architecture (#6505)

* Warp shuffle for Volta.

* Add a warp shuffle test.

* Remove TODO because we have HoistWarpShuffles.

* Fix test case position.

* Pass target to lower_warp_shuffles.

* Format.

Co-authored-by: jinyue.jy <jinyue.jy@alibaba-inc.com>
jinderek and jinyue.jy authored Dec 23, 2021
1 parent e7f655b commit 1d1f06a
Showing 4 changed files with 87 additions and 19 deletions.
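For context (not part of the commit): starting with Volta, CUDA's independent thread scheduling requires the synchronizing warp shuffle intrinsics, which take an explicit 32-bit member mask as their first argument; the pre-Volta intrinsics take no mask. Below is a minimal standalone C++ sketch of the dispatch this commit adds. The intrinsic names mirror the diff; the helper and its argument names are illustrative only.

// Sketch (illustration, not Halide source): pick the warp shuffle intrinsic
// name and argument list based on the CUDA compute capability.
#include <cstdio>
#include <string>
#include <vector>

struct ShflCall {
    std::string name;               // e.g. "llvm.nvvm.shfl.sync.down.f32"
    std::vector<std::string> args;  // symbolic argument names
};

ShflCall shfl_down_f32(int cuda_cap) {
    if (cuda_cap >= 70) {
        // Volta (7.0) and later: .sync variant with a leading member mask.
        return {"llvm.nvvm.shfl.sync.down.f32",
                {"membermask", "value", "delta", "mask_and_clamp"}};
    }
    // Pre-Volta: legacy variant without the member mask.
    return {"llvm.nvvm.shfl.down.f32", {"value", "delta", "mask_and_clamp"}};
}

int main() {
    for (int cap : {61, 70, 75}) {
        ShflCall c = shfl_down_f32(cap);
        printf("cap %d -> %s (%zu args)\n", cap, c.name.c_str(), c.args.size());
    }
    return 0;
}

The commit applies the same switch to the .up and .idx shuffle variants in src/LowerWarpShuffles.cpp below.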
src/Lower.cpp (1 addition, 1 deletion)
@@ -375,7 +375,7 @@ void lower_impl(const vector<Function> &output_funcs,

if (t.has_feature(Target::CUDA)) {
debug(1) << "Injecting warp shuffles...\n";
s = lower_warp_shuffles(s);
s = lower_warp_shuffles(s, t);
log("Lowering after injecting warp shuffles:", s);
}

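A usage-level sketch (my illustration, not part of the commit) of where the capability number comes from: the Target carries a CUDA capability feature, and get_cuda_capability_lower_bound() turns it into the integer that LowerWarpShuffles.cpp compares against 70 below. The target-string spelling cuda_capability_70 in the comment is an assumption; check the Target documentation for your Halide version.

// Sketch: inspect the capability that lower_warp_shuffles() will receive.
#include "Halide.h"
#include <cstdio>
using namespace Halide;

int main() {
    // e.g. HL_JIT_TARGET=host-cuda-cuda_capability_70
    Target t = get_jit_target_from_environment();

    if (t.has_feature(Target::CUDA)) {
        // 70 or higher selects the llvm.nvvm.shfl.sync.* intrinsics;
        // lower capabilities keep the legacy llvm.nvvm.shfl.* forms.
        printf("CUDA capability lower bound: %d\n",
               t.get_cuda_capability_lower_bound());
    } else {
        printf("Target has no CUDA feature\n");
    }
    return 0;
}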
src/LowerWarpShuffles.cpp (37 additions, 13 deletions)
@@ -368,6 +368,7 @@ class LowerWarpShuffles : public IRMutator {
};
Scope<AllocInfo> allocation_info;
Scope<Interval> bounds;
int cuda_cap;

Stmt visit(const For *op) override {
ScopedBinding<Interval>
@@ -562,6 +563,20 @@

internal_assert(may_use_warp_shuffle) << name << ", " << idx << ", " << lane << "\n";

// We must use the .sync shuffle variants on Volta (compute capability 7.0) and newer architectures:
// https://docs.nvidia.com/cuda/volta-tuning-guide/index.html
string sync_suffix = "";
if (cuda_cap >= 70) {
sync_suffix = ".sync";
}

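// The legacy pre-Volta intrinsics take no member mask, so drop the leading
// membermask argument when targeting a compute capability below 7.0.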
auto shfl_args = [&](const std::vector<Expr> &args) {
if (cuda_cap >= 70) {
return args;
}
return std::vector({args[1], args[2], args[3]});
};

string intrin_suffix;
if (shuffle_type.is_float()) {
intrin_suffix = ".f32";
@@ -578,12 +593,12 @@
lane = solve_expression(lane, this_lane_name).result;

Expr shuffled;

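// Full member mask: all 32 lanes of the warp take part in the shuffle.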
Expr membermask = (int)0xffffffff;
if (expr_match(this_lane + wild, lane, result)) {
// We know that 0 <= lane + wild < warp_size by how we
// constructed it, so we can just do a shuffle down.
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], 31}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], 31}), Call::PureExtern);
shuffled = down;
} else if (expr_match((this_lane + wild) % wild, lane, result) &&
is_const_power_of_two_integer(result[1], &bits) &&
@@ -593,10 +608,10 @@
// intermediate registers than using a general gather for
// this.
Expr mask = (1 << bits) - 1;
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl.down" + intrin_suffix,
{base_val, result[0], mask}, Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl.up" + intrin_suffix,
{base_val, (1 << bits) - result[0], 0}, Call::PureExtern);
Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix,
shfl_args({membermask, base_val, result[0], mask}), Call::PureExtern);
Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".up" + intrin_suffix,
shfl_args({membermask, base_val, (1 << bits) - result[0], 0}), Call::PureExtern);
Expr cond = (this_lane >= (1 << bits) - result[0]);
Expr equiv = select(cond, up, down);
shuffled = simplify(equiv, true, bounds);
@@ -609,8 +624,8 @@
// could hypothetically be used for boundary conditions.
Expr mask = simplify(((31 & ~(warp_size - 1)) << 8) | 31);
// The idx variant can do a general gather. Use it for all other cases.
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl.idx" + intrin_suffix,
{base_val, lane, mask}, Call::PureExtern);
shuffled = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".idx" + intrin_suffix,
shfl_args({membermask, base_val, lane, mask}), Call::PureExtern);
}
// TODO: There are other forms, like butterfly and clamp, that
// don't need to use the general gather
@@ -651,7 +666,9 @@
}

public:
LowerWarpShuffles() = default;
LowerWarpShuffles(int cuda_cap)
: cuda_cap(cuda_cap) {
}
};

class HoistWarpShufflesFromSingleIfStmt : public IRMutator {
@@ -804,22 +821,29 @@ class LowerWarpShufflesInEachKernel : public IRMutator {
Stmt visit(const For *op) override {
if (op->device_api == DeviceAPI::CUDA && has_lane_loop(op)) {
Stmt s = op;
s = LowerWarpShuffles().mutate(s);
s = LowerWarpShuffles(cuda_cap).mutate(s);
s = HoistWarpShuffles().mutate(s);
return simplify(s);
} else {
return IRMutator::visit(op);
}
}

int cuda_cap;

public:
LowerWarpShufflesInEachKernel(int cuda_cap)
: cuda_cap(cuda_cap) {
}
};

} // namespace

Stmt lower_warp_shuffles(Stmt s) {
Stmt lower_warp_shuffles(Stmt s, const Target &t) {
s = hoist_loop_invariant_values(s);
s = SubstituteInLaneVar().mutate(s);
s = simplify(s);
s = LowerWarpShufflesInEachKernel().mutate(s);
s = LowerWarpShufflesInEachKernel(t.get_cuda_capability_lower_bound()).mutate(s);
return s;
};

src/LowerWarpShuffles.h (1 addition, 1 deletion)
@@ -13,7 +13,7 @@ namespace Internal {

/** Rewrite access to things stored outside the loop over GPU lanes to
* use nvidia's warp shuffle instructions. */
Stmt lower_warp_shuffles(Stmt s);
Stmt lower_warp_shuffles(Stmt s, const Target &t);

} // namespace Internal
} // namespace Halide
test/correctness/register_shuffle.cpp (48 additions, 4 deletions)
@@ -6,10 +6,8 @@ int main(int argc, char **argv) {
Target t = get_jit_target_from_environment();

int cap = t.get_cuda_capability_lower_bound();
if (cap < 50 || cap >= 80) {
printf("[SKIP] CUDA with capability between 5.0 and 7.5 required\n");
// TODO: Use the shfl.sync intrinsics for cuda 8.0 and above
// See issue #5630
if (cap < 50) {
printf("[SKIP] CUDA capability of at least 5.0 required (got %d)\n", cap);
return 0;
}

@@ -495,6 +493,52 @@ int main(int argc, char **argv) {
}
}

{
// Use warp shuffle to do the reduction.
Func a, b, c;
Var x, y, yo, yi, ylane, u;
RVar ro, ri;

a(x, y) = x + y;
a.compute_root();

RDom r(0, 1024);
b(y) = 0;
b(y) += a(r, y);
c(y) = b(y);

int warp = 8;
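// Eight lanes each accumulate a partial sum over their slice of r (via
// rfactor below); combining those partials across lanes is what lowers
// to warp shuffles rather than shared memory.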
c
.split(y, yo, yi, 1 * warp)
.split(yi, yi, ylane, 1)
.gpu_blocks(yo)
.gpu_threads(yi, ylane);
Func intm = b.update()
.split(r, ri, ro, warp)
.reorder(ri, ro)
.rfactor(ro, u);
intm
.compute_at(c, yi)
.update()
.gpu_lanes(u);
intm
.gpu_lanes(u);

Buffer<int> out = c.realize({256});
for (int y = 0; y < out.width(); y++) {
int correct = 0;
for (int x = 0; x < 1024; x++) {
correct += x + y;
}
int actual = out(y);
if (correct != actual) {
printf("out(%d) = %d instead of %d\n",
y, actual, correct);
return -1;
}
}
}

{
// Test a case that caused combinatorial explosion
Var x;