rounding shift rights should use rounding halving add #6494

Merged 2 commits on Dec 13, 2021
151 changes: 123 additions & 28 deletions src/FindIntrinsics.cpp
@@ -174,6 +174,7 @@ class FindIntrinsics : public IRMutator {
IRMatcher::Wild<1> y;
IRMatcher::Wild<2> z;
IRMatcher::WildConst<0> c0;
IRMatcher::WildConst<1> c1;

Expr visit(const Add *op) override {
if (!find_intrinsics_for_type(op->type)) {
@@ -383,43 +384,129 @@ class FindIntrinsics : public IRMutator {
auto is_x_same_int = op->type.is_int() && is_int(x, bits);
auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
// clang-format off
if (rewrite(max(min(widening_add(x, y), upper), lower), saturating_add(x, y), is_x_same_int_or_uint) ||
rewrite(max(min(widening_sub(x, y), upper), lower), saturating_sub(x, y), is_x_same_int_or_uint) ||
rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper), saturating_add(x, y), is_x_same_uint) ||
rewrite(min(widening_add(x, y), upper), saturating_add(x, y), op->type.is_uint() && is_x_same_uint) ||
rewrite(max(widening_sub(x, y), lower), saturating_sub(x, y), op->type.is_uint() && is_x_same_uint) ||

rewrite(shift_right(widening_add(x, y), 1), halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(shift_right(widening_sub(x, y), 1), halving_sub(x, y), is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, y), 1), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(halving_add(widening_add(x, 1), y), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(halving_add(widening_sub(x, y), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(widening_add(x, y), 1), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(widening_sub(x, y), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||

rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower), mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_int_or_uint && is_uint(z)) ||
rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower), rounding_mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_int_or_uint && is_uint(z)) ||
rewrite(min(shift_right(widening_mul(x, y), z), upper), mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_uint && is_uint(z)) ||
rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper), rounding_mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_uint && is_uint(z)) ||
if (
// Saturating patterns
Contributor: Side comment, the reformatting here makes this change a lot harder to review.

Contributor: +1, may I suggest restoring the original format and putting the clang-format off comment back in place? If we want to let these lines be 'naturally' reformatted, then IMHO we should do that in a standalone PR.

Member Author: I will in the future, but here I just manually reformatted a few surrounding lines because they were hard to modify. What I'd done happened to agree with clang-format, so I also removed those comments.

Member Author (abadams, Dec 13, 2021): In general, rewrite rules are exempt from clang-format, but these particular ones use named intrinsics, so they get very long indeed.

rewrite(max(min(widening_add(x, y), upper), lower),
saturating_add(x, y),
is_x_same_int_or_uint) ||

rewrite(max(min(widening_sub(x, y), upper), lower),
saturating_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper),
saturating_add(x, y),
is_x_same_uint) ||

rewrite(min(widening_add(x, y), upper),
saturating_add(x, y),
op->type.is_uint() && is_x_same_uint) ||

rewrite(max(widening_sub(x, y), lower),
saturating_sub(x, y),
op->type.is_uint() && is_x_same_uint) ||

// Averaging patterns
//
// We have a slight preference for rounding_halving_add over
// using halving_add when unsigned, because x86 supports it.

rewrite(shift_right(widening_add(x, c0), 1),
rounding_halving_add(x, c0 - 1),
c0 > 0 && is_x_same_uint) ||
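
// (Illustrative example of the rule above: with c0 == 7,
// (widen(x) + 7) >> 1 becomes rounding_halving_add(x, 6), and
// rounding_halving_add(x, 6) == (x + 6 + 1) >> 1 == (x + 7) >> 1.)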

rewrite(shift_right(widening_add(x, y), 1),
halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(shift_right(widening_add(x, c0), c1),
rounding_shift_right(x, cast(op->type, c1)),
c0 == shift_left(1, c1 - 1) && is_x_same_int_or_uint) ||

rewrite(shift_right(widening_add(x, c0), c1),
shift_right(rounding_halving_add(x, cast(op->type, fold(c0 - 1))), cast(op->type, fold(c1 - 1))),
c0 > 0 && c1 > 0 && is_x_same_uint) ||
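
// (Illustrative example of the rule above, matching the new check in
// test/correctness/intrinsics.cpp: (widen(x) + 15) >> 4 becomes
// rounding_halving_add(x, 14) >> 3, since
// rounding_halving_add(x, 14) == (x + 15) >> 1.)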

rewrite(shift_right(widening_add(x, y), c0),
shift_right(halving_add(x, y), cast(op->type, fold(c0 - 1))),
c0 > 0 && is_x_same_int_or_uint) ||

rewrite(shift_right(widening_sub(x, y), 1),
halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, y), 1),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, 1), y),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_sub(x, y), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(widening_add(x, y), 1),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(widening_sub(x, y), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

// Multiply-keep-high-bits patterns.

rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower),
mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_int_or_uint && is_uint(z)) ||

rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower),
rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_int_or_uint && is_uint(z)) ||

rewrite(min(shift_right(widening_mul(x, y), z), upper),
mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_uint && is_uint(z)) ||

rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper),
rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_uint && is_uint(z)) ||

// We don't need saturation for the full upper half of a multiply.
// For signed integers, this is almost true, except for when x and y
// are both the most negative value. For these, we only need saturation
// at the upper bound.
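// (For example, with int8: (-128) * (-128) == 16384, and 16384 >> 7 == 128,
// which does not fit in an int8, so the min() against `upper` is still needed
// when shifting by bits - 1; shifting by at least `bits` always fits.)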
rewrite(min(shift_right(widening_mul(x, y), c0), upper), mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && c0 >= bits - 1) ||
rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper), rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && c0 >= bits - 1) ||
rewrite(shift_right(widening_mul(x, y), c0), mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int_or_uint && c0 >= bits) ||
rewrite(rounding_shift_right(widening_mul(x, y), c0), rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int_or_uint && c0 >= bits) ||
rewrite(min(shift_right(widening_mul(x, y), c0), upper),
mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int && c0 >= bits - 1) ||

rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper),
rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int && c0 >= bits - 1) ||

rewrite(shift_right(widening_mul(x, y), c0),
mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int_or_uint && c0 >= bits) ||

rewrite(rounding_shift_right(widening_mul(x, y), c0),
rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int_or_uint && c0 >= bits) ||

// We can ignore the sign of the widening subtract for halving subtracts.
rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1), halving_sub(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(cast(op_type_wide, widening_sub(x, y)), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||
// Halving subtract patterns
rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

false) {
internal_assert(rewrite.result.type() == op->type)
<< "Rewrite changed type: " << Expr(op) << " -> " << rewrite.result << "\n";
return mutate(rewrite.result);
}
// clang-format on

// When the argument is a widened rounding shift, we might not need the widening.
// When there is saturation, we can only avoid the widening if we know the shift is
@@ -763,6 +850,14 @@ Expr lower_rounding_shift_left(const Expr &a, const Expr &b) {
}

Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
if (is_positive_const(b)) {
// We can handle the rounding with an averaging instruction. We prefer
// the rounding average instruction (we could use either), because the
// non-rounding one is missing on x86.
Expr shift = simplify(b - 1);
Expr round = simplify(cast(a.type(), (1 << shift) - 1));
return rounding_halving_add(a, round) >> shift;
}
// Shift right, then add one to the result if bits were dropped
// (because b > 0) and the most significant dropped bit was a one.
Expr b_positive = select(b > 0, make_one(a.type()), make_zero(a.type()));
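The identity behind the new lower_rounding_shift_right path above (and behind the new averaging rewrites) is that, for a positive constant shift b, rounding_shift_right(a, b) == rounding_halving_add(a, (1 << (b - 1)) - 1) >> (b - 1). Below is a minimal standalone sketch checking this exhaustively for uint8_t; it is plain C++, not Halide code, and the ref_* helpers are reference reimplementations assumed to match the intrinsic semantics for unsigned 8-bit values.

#include <cassert>
#include <cstdint>

// rounding_shift_right(a, b) for b > 0: (a + 2^(b-1)) >> b, widened to avoid overflow.
static uint8_t ref_rounding_shift_right(uint8_t a, int b) {
    return static_cast<uint8_t>((static_cast<uint16_t>(a) + (1u << (b - 1))) >> b);
}

// rounding_halving_add(a, b): (a + b + 1) >> 1, widened to avoid overflow.
static uint8_t ref_rounding_halving_add(uint8_t a, uint8_t b) {
    return static_cast<uint8_t>((static_cast<uint16_t>(a) + b + 1) >> 1);
}

int main() {
    for (int a = 0; a < 256; a++) {
        for (int b = 1; b <= 7; b++) {
            uint8_t round = static_cast<uint8_t>((1u << (b - 1)) - 1);
            uint8_t lowered = static_cast<uint8_t>(
                ref_rounding_halving_add(static_cast<uint8_t>(a), round) >> (b - 1));
            assert(lowered == ref_rounding_shift_right(static_cast<uint8_t>(a), b));
        }
    }
    return 0;
}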
44 changes: 43 additions & 1 deletion src/IRMatch.h
@@ -1483,7 +1483,49 @@ struct Intrin {
return Expr();
}

constexpr static bool foldable = false;
constexpr static bool foldable = true;

HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
halide_scalar_value_t arg1;
// Assuming the args have the same type as the intrinsic is incorrect in
// general. But for the intrinsics we can fold (just shifts), the LHS
// has the same type as the intrinsic, and we can always treat the RHS
// as a signed int, because we're using 64 bits for it.
std::get<0>(args).make_folded_const(val, ty, state);
halide_type_t signed_ty = ty;
signed_ty.code = halide_type_int;
// We can just directly get the second arg here, because we only want to
// instantiate this method for shifts, which have two args.
std::get<1>(args).make_folded_const(arg1, signed_ty, state);

if (intrin == Call::shift_left) {
if (arg1.u.i64 < 0) {
if (ty.code == halide_type_int) {
// Arithmetic shift
val.u.i64 >>= -arg1.u.i64;
} else {
// Logical shift
val.u.u64 >>= -arg1.u.i64;
}
} else {
val.u.u64 <<= arg1.u.i64;
}
} else if (intrin == Call::shift_right) {
if (arg1.u.i64 > 0) {
if (ty.code == halide_type_int) {
// Arithmetic shift
val.u.i64 >>= arg1.u.i64;
} else {
// Logical shift
val.u.u64 >>= arg1.u.i64;
}
} else {
val.u.u64 <<= -arg1.u.i64;
}
} else {
internal_error << "Folding not implemented for intrinsic: " << intrin;
}
}

HALIDE_ALWAYS_INLINE
Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
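The constant folding added to IRMatch.h above only handles the two shift intrinsics, with the convention that a negative shift amount shifts in the opposite direction, and that signed values use an arithmetic right shift while unsigned values use a logical one. Here is a simplified plain-C++ model of that convention; the fold_shift_left_* helper names are hypothetical, and the code operates on bare 64-bit values rather than the matcher's halide_scalar_value_t.

#include <cstdint>
#include <iostream>

// shift_left with a negative amount means shift right; signed values use an
// arithmetic right shift, unsigned values a logical one.
static int64_t fold_shift_left_i64(int64_t v, int64_t amount) {
    // Right-shifting a negative signed value is arithmetic on the platforms
    // Halide targets (and is guaranteed to be arithmetic from C++20 onward).
    return amount < 0 ? (v >> -amount) : (v << amount);
}

static uint64_t fold_shift_left_u64(uint64_t v, int64_t amount) {
    return amount < 0 ? (v >> -amount) : (v << amount);
}

int main() {
    std::cout << fold_shift_left_i64(-16, -2) << "\n";  // -4 (arithmetic shift keeps the sign)
    std::cout << fold_shift_left_u64(16, -2) << "\n";   // 4  (logical shift)
    std::cout << fold_shift_left_u64(3, 4) << "\n";     // 48
    return 0;
}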
8 changes: 8 additions & 0 deletions test/correctness/intrinsics.cpp
@@ -128,6 +128,7 @@ int main(int argc, char **argv) {
Expr u8y = make_leaf(UInt(8, 4), "u8y");
Expr u8z = make_leaf(UInt(8, 4), "u8w");
Expr u8w = make_leaf(UInt(8, 4), "u8z");
Expr u16x = make_leaf(UInt(16, 4), "u16x");
Expr u32x = make_leaf(UInt(32, 4), "u32x");
Expr u32y = make_leaf(UInt(32, 4), "u32y");
Expr i32x = make_leaf(Int(32, 4), "i32x");
@@ -244,6 +245,13 @@ int main(int argc, char **argv) {
check((i8x + i8(32)) / 64, (i8x + i8(32)) >> 6); // Not a rounding_shift_right due to overflow.
check((i32x + 16) / 32, rounding_shift_right(i32x, 5));

// rounding_shift_right of a widening add can be strength-reduced
check(narrow((u16(u8x) + 15) >> 4), rounding_halving_add(u8x, u8(14)) >> u8(3));
check(narrow((u32(u16x) + 15) >> 4), rounding_halving_add(u16x, u16(14)) >> u16(3));

// But not if the constant can't fit in the narrower type
check(narrow((u16(u8x) + 500) >> 4), narrow((u16(u8x) + 500) >> 4));

check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4)));
check(u16(min((u64(u32x) + 8) / 16, 65535)), u16(min(rounding_shift_right(u32x, 4), 65535)));

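As a concrete spot check of the new expectations above (with x = 200, picked arbitrarily here): (200 + 15) >> 4 == 215 >> 4 == 13, while rounding_halving_add(200, 14) == (200 + 14 + 1) >> 1 == 107 and 107 >> 3 == 13, so the strength-reduced form agrees. With a constant like 500 that does not fit in the narrower u8 type, the rewrite is not applied, which the corresponding check asserts by expecting the expression to be left unchanged.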