Skip to content

Commit

Permalink
rounding shift rights should use rounding halving add (#6494)
Browse files Browse the repository at this point in the history
* rounding shift rights should use rounding halving add

On x86 currently we lower cast<uint8_t>((cast<uint16_t>(x) + 8) / 16)
to:

cast<uint8_t>(shift_right(widening_add(x, 8), 4))

This compiles to 8 instructions on x86: Widen each half of the input
vector, add 8 to each half-vector, shift each half-vector, then narrow
each half-vector.

First, this should have been a rounding_shift_right. Some patterns were
missing in FindIntrinsics.

Second, rounding_shift_right had suboptimal codegen in the case where
the second arg is a positive const. On archs without a rounding shift
right instruction you can further rewrite this to:

shift_right(rounding_halving_add(x, 7), 3)

which is just two instructions on x86.
  • Loading branch information
abadams committed Dec 13, 2021
1 parent 11448b2 commit e23b6f0
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 54 deletions.
151 changes: 123 additions & 28 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class FindIntrinsics : public IRMutator {
IRMatcher::Wild<1> y;
IRMatcher::Wild<2> z;
IRMatcher::WildConst<0> c0;
IRMatcher::WildConst<1> c1;

Expr visit(const Add *op) override {
if (!find_intrinsics_for_type(op->type)) {
Expand Down Expand Up @@ -383,43 +384,129 @@ class FindIntrinsics : public IRMutator {
auto is_x_same_int = op->type.is_int() && is_int(x, bits);
auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
// clang-format off
if (rewrite(max(min(widening_add(x, y), upper), lower), saturating_add(x, y), is_x_same_int_or_uint) ||
rewrite(max(min(widening_sub(x, y), upper), lower), saturating_sub(x, y), is_x_same_int_or_uint) ||
rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper), saturating_add(x, y), is_x_same_uint) ||
rewrite(min(widening_add(x, y), upper), saturating_add(x, y), op->type.is_uint() && is_x_same_uint) ||
rewrite(max(widening_sub(x, y), lower), saturating_sub(x, y), op->type.is_uint() && is_x_same_uint) ||

rewrite(shift_right(widening_add(x, y), 1), halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(shift_right(widening_sub(x, y), 1), halving_sub(x, y), is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, y), 1), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(halving_add(widening_add(x, 1), y), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(halving_add(widening_sub(x, y), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(widening_add(x, y), 1), rounding_halving_add(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(widening_sub(x, y), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||

rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower), mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_int_or_uint && is_uint(z)) ||
rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower), rounding_mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_int_or_uint && is_uint(z)) ||
rewrite(min(shift_right(widening_mul(x, y), z), upper), mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_uint && is_uint(z)) ||
rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper), rounding_mul_shift_right(x, y, cast(unsigned_type, z)), is_x_same_uint && is_uint(z)) ||
if (
// Saturating patterns
rewrite(max(min(widening_add(x, y), upper), lower),
saturating_add(x, y),
is_x_same_int_or_uint) ||

rewrite(max(min(widening_sub(x, y), upper), lower),
saturating_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper),
saturating_add(x, y),
is_x_same_uint) ||

rewrite(min(widening_add(x, y), upper),
saturating_add(x, y),
op->type.is_uint() && is_x_same_uint) ||

rewrite(max(widening_sub(x, y), lower),
saturating_sub(x, y),
op->type.is_uint() && is_x_same_uint) ||

// Averaging patterns
//
// We have a slight preference for rounding_halving_add over
// using halving_add when unsigned, because x86 supports it.

rewrite(shift_right(widening_add(x, c0), 1),
rounding_halving_add(x, c0 - 1),
c0 > 0 && is_x_same_uint) ||

rewrite(shift_right(widening_add(x, y), 1),
halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(shift_right(widening_add(x, c0), c1),
rounding_shift_right(x, cast(op->type, c1)),
c0 == shift_left(1, c1 - 1) && is_x_same_int_or_uint) ||

rewrite(shift_right(widening_add(x, c0), c1),
shift_right(rounding_halving_add(x, cast(op->type, fold(c0 - 1))), cast(op->type, fold(c1 - 1))),
c0 > 0 && c1 > 0 && is_x_same_uint) ||

rewrite(shift_right(widening_add(x, y), c0),
shift_right(halving_add(x, y), cast(op->type, fold(c0 - 1))),
c0 > 0 && is_x_same_int_or_uint) ||

rewrite(shift_right(widening_sub(x, y), 1),
halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, y), 1),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_add(x, 1), y),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(halving_add(widening_sub(x, y), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(widening_add(x, y), 1),
rounding_halving_add(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(widening_sub(x, y), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

// Multiply-keep-high-bits patterns.

rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower),
mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_int_or_uint && is_uint(z)) ||

rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower),
rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_int_or_uint && is_uint(z)) ||

rewrite(min(shift_right(widening_mul(x, y), z), upper),
mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_uint && is_uint(z)) ||

rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper),
rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
is_x_same_uint && is_uint(z)) ||

// We don't need saturation for the full upper half of a multiply.
// For signed integers, this is almost true, except for when x and y
// are both the most negative value. For these, we only need saturation
// at the upper bound.
rewrite(min(shift_right(widening_mul(x, y), c0), upper), mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && c0 >= bits - 1) ||
rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper), rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int && c0 >= bits - 1) ||
rewrite(shift_right(widening_mul(x, y), c0), mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int_or_uint && c0 >= bits) ||
rewrite(rounding_shift_right(widening_mul(x, y), c0), rounding_mul_shift_right(x, y, cast(unsigned_type, c0)), is_x_same_int_or_uint && c0 >= bits) ||
rewrite(min(shift_right(widening_mul(x, y), c0), upper),
mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int && c0 >= bits - 1) ||

rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper),
rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int && c0 >= bits - 1) ||

rewrite(shift_right(widening_mul(x, y), c0),
mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int_or_uint && c0 >= bits) ||

rewrite(rounding_shift_right(widening_mul(x, y), c0),
rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
is_x_same_int_or_uint && c0 >= bits) ||

// We can ignore the sign of the widening subtract for halving subtracts.
rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1), halving_sub(x, y), is_x_same_int_or_uint) ||
rewrite(rounding_shift_right(cast(op_type_wide, widening_sub(x, y)), 1), rounding_halving_sub(x, y), is_x_same_int_or_uint) ||
// Halving subtract patterns
rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
halving_sub(x, y),
is_x_same_int_or_uint) ||

rewrite(rounding_shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
rounding_halving_sub(x, y),
is_x_same_int_or_uint) ||

false) {
internal_assert(rewrite.result.type() == op->type)
<< "Rewrite changed type: " << Expr(op) << " -> " << rewrite.result << "\n";
return mutate(rewrite.result);
}
// clang-format on

// When the argument is a widened rounding shift, we might not need the widening.
// When there is saturation, we can only avoid the widening if we know the shift is
Expand Down Expand Up @@ -763,6 +850,14 @@ Expr lower_rounding_shift_left(const Expr &a, const Expr &b) {
}

Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
if (is_positive_const(b)) {
// We can handle the rounding with an averaging instruction. We prefer
// the rounding average instruction (we could use either), because the
// non-rounding one is missing on x86.
Expr shift = simplify(b - 1);
Expr round = simplify(cast(a.type(), (1 << shift) - 1));
return rounding_halving_add(a, round) >> shift;
}
// Shift right, then add one to the result if bits were dropped
// (because b > 0) and the most significant dropped bit was a one.
Expr b_positive = select(b > 0, make_one(a.type()), make_zero(a.type()));
Expand Down
44 changes: 43 additions & 1 deletion src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,49 @@ struct Intrin {
return Expr();
}

constexpr static bool foldable = false;
constexpr static bool foldable = true;

// Constant-fold this intrinsic call into a single scalar. Only the shift
// intrinsics (Call::shift_left / Call::shift_right) are supported; any
// other intrinsic reaching this path hits the internal_error below.
// On return, 'val' holds the folded value and 'ty' its type (inherited
// from the first argument). The shift operates on the full 64-bit union
// member without masking to the width of 'ty'; NOTE(review): presumably
// the caller re-narrows/normalizes the result to 'ty' -- confirm.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
    halide_scalar_value_t arg1;
    // Assuming the args have the same type as the intrinsic is incorrect in
    // general. But for the intrinsics we can fold (just shifts), the LHS
    // has the same type as the intrinsic, and we can always treat the RHS
    // as a signed int, because we're using 64 bits for it.
    std::get<0>(args).make_folded_const(val, ty, state);
    halide_type_t signed_ty = ty;
    signed_ty.code = halide_type_int;
    // We can just directly get the second arg here, because we only want to
    // instantiate this method for shifts, which have two args.
    std::get<1>(args).make_folded_const(arg1, signed_ty, state);

    if (intrin == Call::shift_left) {
        if (arg1.u.i64 < 0) {
            // A left shift by a negative amount is a right shift, whose
            // flavor depends on the signedness of the folded type.
            if (ty.code == halide_type_int) {
                // Arithmetic shift
                val.u.i64 >>= -arg1.u.i64;
            } else {
                // Logical shift
                val.u.u64 >>= -arg1.u.i64;
            }
        } else {
            // Left shift produces the same bit pattern for signed and
            // unsigned, so shift the unsigned member (avoids signed
            // overflow concerns on the i64 member).
            val.u.u64 <<= arg1.u.i64;
        }
    } else if (intrin == Call::shift_right) {
        if (arg1.u.i64 > 0) {
            if (ty.code == halide_type_int) {
                // Arithmetic shift
                val.u.i64 >>= arg1.u.i64;
            } else {
                // Logical shift
                val.u.u64 >>= arg1.u.i64;
            }
        } else {
            // A right shift by a non-positive amount is a left shift.
            val.u.u64 <<= -arg1.u.i64;
        }
    } else {
        internal_error << "Folding not implemented for intrinsic: " << intrin;
    }
}

HALIDE_ALWAYS_INLINE
Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Expand Down
8 changes: 8 additions & 0 deletions test/correctness/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ int main(int argc, char **argv) {
Expr u8y = make_leaf(UInt(8, 4), "u8y");
Expr u8z = make_leaf(UInt(8, 4), "u8w");
Expr u8w = make_leaf(UInt(8, 4), "u8z");
Expr u16x = make_leaf(UInt(16, 4), "u16x");
Expr u32x = make_leaf(UInt(32, 4), "u32x");
Expr u32y = make_leaf(UInt(32, 4), "u32y");
Expr i32x = make_leaf(Int(32, 4), "i32x");
Expand Down Expand Up @@ -244,6 +245,13 @@ int main(int argc, char **argv) {
check((i8x + i8(32)) / 64, (i8x + i8(32)) >> 6); // Not a rounding_shift_right due to overflow.
check((i32x + 16) / 32, rounding_shift_right(i32x, 5));

// rounding_right_shift of a widening add can be strength-reduced
check(narrow((u16(u8x) + 15) >> 4), rounding_halving_add(u8x, u8(14)) >> u8(3));
check(narrow((u32(u16x) + 15) >> 4), rounding_halving_add(u16x, u16(14)) >> u16(3));

// But not if the constant can't fit in the narrower type
check(narrow((u16(u8x) + 500) >> 4), narrow((u16(u8x) + 500) >> 4));

check((u64(u32x) + 8) / 16, u64(rounding_shift_right(u32x, 4)));
check(u16(min((u64(u32x) + 8) / 16, 65535)), u16(min(rounding_shift_right(u32x, 4), 65535)));

Expand Down
Loading

0 comments on commit e23b6f0

Please sign in to comment.