Skip to content

Commit

Permalink
Fix bounds_of_nested_lanes (#8039)
Browse files Browse the repository at this point in the history
* Fix bounds_of_nested_lanes

bounds_of_nested_lanes assumed that one layer of nested vectorization
could be removed at a time. When faced with the expression:

min(ramp(x8(a), x8(b), 5), x40(27))

It panicked, because on the left hand side it reduced the bounds to
x8(a) ... x8(a) + x8(b) * 4, and on the right hand side it reduced the
bounds to 27. It then attempted to take a min of mismatched types.

In general we can't assume that binary operators on nested vectors have
the same nesting structure on both sides, so I just rewrote it to reduce
directly to a scalar.

Fixes #8038
  • Loading branch information
abadams authored and steven-johnson committed Feb 1, 2024
1 parent 6d29ad5 commit be6d6c6
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 64 deletions.
140 changes: 76 additions & 64 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,121 +29,146 @@ Expr get_lane(const Expr &e, int l) {
return Shuffle::make_slice(e, l, 0, 1);
}

/** Find the exact max and min lanes of a vector expression. Not
* conservative like bounds_of_expr, but uses similar rules for some
* common node types where it can be exact. If e is a nested vector,
* the result will be the bounds of the vectors in each lane. */
Interval bounds_of_nested_lanes(const Expr &e) {
/** A helper like .as<Broadcast>(), but unwraps arbitrarily many layers of
* nested broadcasts. Guaranteed to return either a broadcast of a scalar or
* nullptr. */
const Broadcast *as_scalar_broadcast(const Expr &e) {
const Broadcast *b = e.as<Broadcast>();
if (b && b->value.type().is_scalar()) {
return b;
} else if (b) {
return as_scalar_broadcast(b->value);
} else {
return nullptr;
}
};

/** Find the exact scalar max and min lanes of a vector expression. Not
* conservative like bounds_of_expr, but uses similar rules for some common node
* types where it can be exact. Always returns a scalar, even in the case of
* nested vectorization. */
Interval bounds_of_lanes(const Expr &e) {
if (e.type().is_scalar()) {
return {e, e};
}

if (const Add *add = e.as<Add>()) {
if (const Broadcast *b = add->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(add->a);
if (const Broadcast *b = as_scalar_broadcast(add->b)) {
Interval ia = bounds_of_lanes(add->a);
return {ia.min + b->value, ia.max + b->value};
} else if (const Broadcast *b = add->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(add->b);
} else if (const Broadcast *b = as_scalar_broadcast(add->a)) {
Interval ia = bounds_of_lanes(add->b);
return {b->value + ia.min, b->value + ia.max};
}
} else if (const Sub *sub = e.as<Sub>()) {
if (const Broadcast *b = sub->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(sub->a);
if (const Broadcast *b = as_scalar_broadcast(sub->b)) {
Interval ia = bounds_of_lanes(sub->a);
return {ia.min - b->value, ia.max - b->value};
} else if (const Broadcast *b = sub->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(sub->b);
return {b->value - ia.max, b->value - ia.max};
} else if (const Broadcast *b = as_scalar_broadcast(sub->a)) {
Interval ia = bounds_of_lanes(sub->b);
return {b->value - ia.max, b->value - ia.min};
}
} else if (const Mul *mul = e.as<Mul>()) {
if (const Broadcast *b = mul->b.as<Broadcast>()) {
if (const Broadcast *b = as_scalar_broadcast(mul->b)) {
if (is_positive_const(b->value)) {
Interval ia = bounds_of_nested_lanes(mul->a);
Interval ia = bounds_of_lanes(mul->a);
return {ia.min * b->value, ia.max * b->value};
} else if (is_negative_const(b->value)) {
Interval ia = bounds_of_nested_lanes(mul->a);
Interval ia = bounds_of_lanes(mul->a);
return {ia.max * b->value, ia.min * b->value};
}
} else if (const Broadcast *b = mul->a.as<Broadcast>()) {
} else if (const Broadcast *b = as_scalar_broadcast(mul->a)) {
if (is_positive_const(b->value)) {
Interval ia = bounds_of_nested_lanes(mul->b);
Interval ia = bounds_of_lanes(mul->b);
return {b->value * ia.min, b->value * ia.max};
} else if (is_negative_const(b->value)) {
Interval ia = bounds_of_nested_lanes(mul->b);
Interval ia = bounds_of_lanes(mul->b);
return {b->value * ia.max, b->value * ia.min};
}
}
} else if (const Div *div = e.as<Div>()) {
if (const Broadcast *b = div->b.as<Broadcast>()) {
if (const Broadcast *b = as_scalar_broadcast(div->b)) {
if (is_positive_const(b->value)) {
Interval ia = bounds_of_nested_lanes(div->a);
Interval ia = bounds_of_lanes(div->a);
return {ia.min / b->value, ia.max / b->value};
} else if (is_negative_const(b->value)) {
Interval ia = bounds_of_nested_lanes(div->a);
Interval ia = bounds_of_lanes(div->a);
return {ia.max / b->value, ia.min / b->value};
}
}
} else if (const And *and_ = e.as<And>()) {
if (const Broadcast *b = and_->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(and_->a);
if (const Broadcast *b = as_scalar_broadcast(and_->b)) {
Interval ia = bounds_of_lanes(and_->a);
return {ia.min && b->value, ia.max && b->value};
} else if (const Broadcast *b = and_->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(and_->b);
} else if (const Broadcast *b = as_scalar_broadcast(and_->a)) {
Interval ia = bounds_of_lanes(and_->b);
return {ia.min && b->value, ia.max && b->value};
}
} else if (const Or *or_ = e.as<Or>()) {
if (const Broadcast *b = or_->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(or_->a);
if (const Broadcast *b = as_scalar_broadcast(or_->b)) {
Interval ia = bounds_of_lanes(or_->a);
return {ia.min && b->value, ia.max && b->value};
} else if (const Broadcast *b = or_->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(or_->b);
} else if (const Broadcast *b = as_scalar_broadcast(or_->a)) {
Interval ia = bounds_of_lanes(or_->b);
return {ia.min && b->value, ia.max && b->value};
}
} else if (const Min *min = e.as<Min>()) {
if (const Broadcast *b = min->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(min->a);
if (const Broadcast *b = as_scalar_broadcast(min->b)) {
Interval ia = bounds_of_lanes(min->a);
// ia and b->value have both had one nesting layer of vectorization
// peeled off, but that doesn't make them the same type.
return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
} else if (const Broadcast *b = min->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(min->b);
} else if (const Broadcast *b = as_scalar_broadcast(min->a)) {
Interval ia = bounds_of_lanes(min->b);
return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
}
} else if (const Max *max = e.as<Max>()) {
if (const Broadcast *b = max->b.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(max->a);
if (const Broadcast *b = as_scalar_broadcast(max->b)) {
Interval ia = bounds_of_lanes(max->a);
return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
} else if (const Broadcast *b = max->a.as<Broadcast>()) {
Interval ia = bounds_of_nested_lanes(max->b);
} else if (const Broadcast *b = as_scalar_broadcast(max->a)) {
Interval ia = bounds_of_lanes(max->b);
return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
}
} else if (const Not *not_ = e.as<Not>()) {
Interval ia = bounds_of_nested_lanes(not_->a);
Interval ia = bounds_of_lanes(not_->a);
return {!ia.max, !ia.min};
} else if (const Ramp *r = e.as<Ramp>()) {
Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
if (is_positive_const(r->stride)) {
return {r->base, r->base + last_lane_idx * r->stride};
} else if (is_negative_const(r->stride)) {
return {r->base + last_lane_idx * r->stride, r->base};
Interval ib = bounds_of_lanes(r->base);
const Broadcast *b = as_scalar_broadcast(r->stride);
Expr stride = b ? b->value : r->stride;
if (stride.type().is_scalar()) {
if (is_positive_const(stride)) {
return {ib.min, ib.max + last_lane_idx * stride};
} else if (is_negative_const(stride)) {
return {ib.min + last_lane_idx * stride, ib.max};
}
}
} else if (const LE *le = e.as<LE>()) {
// The least true this can be is if we maximize the LHS and minimize the RHS.
// The most true this can be is if we minimize the LHS and maximize the RHS.
// This is only exact if one of the two sides is a Broadcast.
Interval ia = bounds_of_nested_lanes(le->a);
Interval ib = bounds_of_nested_lanes(le->b);
Interval ia = bounds_of_lanes(le->a);
Interval ib = bounds_of_lanes(le->b);
if (ia.is_single_point() || ib.is_single_point()) {
return {ia.max <= ib.min, ia.min <= ib.max};
}
} else if (const LT *lt = e.as<LT>()) {
// The least true this can be is if we maximize the LHS and minimize the RHS.
// The most true this can be is if we minimize the LHS and maximize the RHS.
// This is only exact if one of the two sides is a Broadcast.
Interval ia = bounds_of_nested_lanes(lt->a);
Interval ib = bounds_of_nested_lanes(lt->b);
Interval ia = bounds_of_lanes(lt->a);
Interval ib = bounds_of_lanes(lt->b);
if (ia.is_single_point() || ib.is_single_point()) {
return {ia.max < ib.min, ia.min < ib.max};
}

} else if (const Broadcast *b = e.as<Broadcast>()) {
} else if (const Broadcast *b = as_scalar_broadcast(e)) {
return {b->value, b->value};
} else if (const Let *let = e.as<Let>()) {
Interval ia = bounds_of_nested_lanes(let->value);
Interval ib = bounds_of_nested_lanes(let->body);
Interval ia = bounds_of_lanes(let->value);
Interval ib = bounds_of_lanes(let->body);
if (expr_uses_var(ib.min, let->name)) {
ib.min = Let::make(let->name, let->value, ib.min);
}
Expand All @@ -166,19 +191,6 @@ Interval bounds_of_nested_lanes(const Expr &e) {
}
};

/** Similar to bounds_of_nested_lanes, but it recursively reduces
* the bounds of nested vectors to scalars. */
Interval bounds_of_lanes(const Expr &e) {
Interval bounds = bounds_of_nested_lanes(e);
if (!bounds.min.type().is_scalar()) {
bounds.min = bounds_of_lanes(bounds.min).min;
}
if (!bounds.max.type().is_scalar()) {
bounds.max = bounds_of_lanes(bounds.max).max;
}
return bounds;
}

// A ramp with the lanes repeated inner_repetitions times, and then
// the whole vector repeated outer_repetitions times.
// E.g: <0 0 2 2 4 4 6 6 0 0 2 2 4 4 6 6>.
Expand Down
19 changes: 19 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8038
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y"), yi("yi"), yo("yo"), xi("xi"), xo("xo"), yofxi("yofxi"), yofxio("yofxio"), yofxii("yofxii"), yofxiifyi("yofxiifyi"), yofxioo("yofxioo"), yofxioi("yofxioi");
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5, "rdom_r");
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);
local_sum.split(y, yi, yo, 2, TailStrategy::GuardWithIf).split(x, xi, xo, 5, TailStrategy::Predicate).fuse(yo, xi, yofxi).split(yofxi, yofxio, yofxii, 8, TailStrategy::ShiftInwards).fuse(yofxii, yi, yofxiifyi).split(yofxio, yofxioo, yofxioi, 5, TailStrategy::ShiftInwards).vectorize(yofxiifyi).vectorize(yofxioi);
local_sum.update(0).unscheduled();
blurry.split(x, xo, xi, 5, TailStrategy::Auto);
Pipeline p({blurry});
auto buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

printf("Success!\n");
return 0;
}

0 comments on commit be6d6c6

Please sign in to comment.