Skip to content

Commit

Permalink
Handle loads of broadcasts in FlattenNestedRamps (#8139)
Browse files Browse the repository at this point in the history
With sufficiently perverse schedules, it's possible to end up with a
load of a broadcast index (rather than a broadcast of a scalar load).
This made FlattenNestedRamps divide by zero. Unfortunately this happened
in a complex production pipeline, so I'm not entirely sure how to
reproduce it. For that pipeline, this change fixes it and produces
correct output.
  • Loading branch information
abadams authored Mar 8, 2024
1 parent 8cc4f02 commit 009fe7a
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions src/FlattenNestedRamps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ class FlattenRamps : public IRMutator {

// If they are, we'll have a full vector of const_indices
if ((int)const_indices.size() == lanes) {

// Compute the stride for the underlying strided load
int stride = 0;
for (int c : const_indices) {
stride = (int)gcd(stride, c);
}
for (int &c : const_indices) {
c /= stride;
int stride = 0, extent = 1;
if (max_constant_offset > 0) {
for (int c : const_indices) {
stride = (int)gcd(stride, c);
}
for (int &c : const_indices) {
c /= stride;
}
// Compute the number of elements loaded
extent = (int)((max_constant_offset / stride) + 1);
}

// Compute the number of elements loaded
int extent = (int)((max_constant_offset / stride) + 1);

// If we're gathering from a very large range, it
// might be better to just do the gather rather than
// doing a big dense load and then shuffling. We
Expand All @@ -105,12 +105,22 @@ class FlattenRamps : public IRMutator {
// in the schedule somehow.
const int max_unused_lane_factor = 4;
if (extent < max_unused_lane_factor * lanes) {
Expr dense_index = Ramp::make(min_lane, make_const(min_lane.type(), stride), extent);
Expr dense_load =
Load::make(op->type.with_lanes(extent), op->name, dense_index,
op->image, op->param,
const_true(extent), ModulusRemainder{});
return Shuffle::make({dense_load}, const_indices);
if (max_constant_offset == 0) {
// It's a load of a broadcast. Convert it to a broadcast of a load
Expr load = Load::make(op->type.element_of(), op->name, min_lane,
op->image, op->param,
const_true(), ModulusRemainder{});
return Broadcast::make(load, lanes);
} else {
// Turn it into a dense load and a shuffle
Expr dense_index =
Ramp::make(min_lane, make_const(min_lane.type(), stride), extent);
Expr dense_load =
Load::make(op->type.with_lanes(extent), op->name, dense_index,
op->image, op->param,
const_true(extent), ModulusRemainder{});
return Shuffle::make({dense_load}, const_indices);
}
}
}
}
Expand Down

0 comments on commit 009fe7a

Please sign in to comment.