Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a fast integer divide that rounds to zero #6455

Merged
merged 6 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions src/CodeGen_Internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ bool can_allocation_fit_on_stack(int64_t size) {
return (size <= (int64_t)Runtime::Internal::Constants::maximum_stack_allocation_bytes);
}

Expr lower_int_uint_div(const Expr &a, const Expr &b) {
Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) {
// Detect if it's a small int division
internal_assert(a.type() == b.type());
const int64_t *const_int_divisor = as_const_int(b);
Expand All @@ -261,7 +261,16 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
int shift_amount;
if (is_const_power_of_two_integer(b, &shift_amount) &&
(t.is_int() || t.is_uint())) {
return a >> make_const(UInt(a.type().bits()), shift_amount);
if (round_to_zero) {
Expr result = a;
// Normally a right-shift isn't right for division rounding to
// zero. It does the wrong thing for negative values. Add a fudge so
// that a right-shift becomes correct.
result += (result >> (t.bits() - 1)) & (b - 1);
return result >> shift_amount;
} else {
return a >> make_const(UInt(a.type().bits()), shift_amount);
}
} else if (const_int_divisor &&
t.is_int() &&
(t.bits() == 8 || t.bits() == 16 || t.bits() == 32) &&
Expand All @@ -271,33 +280,55 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
int64_t multiplier;
int shift;
if (t.bits() == 32) {
multiplier = IntegerDivision::table_s32[*const_int_divisor][2];
shift = IntegerDivision::table_s32[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz32[*const_int_divisor][2];
shift = IntegerDivision::table_srz32[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s32[*const_int_divisor][2];
shift = IntegerDivision::table_s32[*const_int_divisor][3];
}
} else if (t.bits() == 16) {
multiplier = IntegerDivision::table_s16[*const_int_divisor][2];
shift = IntegerDivision::table_s16[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz16[*const_int_divisor][2];
shift = IntegerDivision::table_srz16[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s16[*const_int_divisor][2];
shift = IntegerDivision::table_s16[*const_int_divisor][3];
}
} else {
// 8 bit
multiplier = IntegerDivision::table_s8[*const_int_divisor][2];
shift = IntegerDivision::table_s8[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz8[*const_int_divisor][2];
shift = IntegerDivision::table_srz8[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s8[*const_int_divisor][2];
shift = IntegerDivision::table_s8[*const_int_divisor][3];
}
}
Expr num = a;

// Make an all-ones mask if the numerator is negative
Type num_as_uint_t = num.type().with_code(Type::UInt);
Expr sign = cast(num_as_uint_t, num >> make_const(UInt(t.bits()), t.bits() - 1));

// Flip the numerator bits if the mask is high.
num = cast(num_as_uint_t, num);
num = num ^ sign;
if (!round_to_zero) {
// Flip the numerator bits if the mask is high.
num = cast(num_as_uint_t, num);
num = num ^ sign;
}

// Multiply and keep the high half of the
// result, and then apply the shift.
Expr mult = make_const(num.type(), multiplier);
num = mul_shift_right(num, mult, shift + num.type().bits());

// Maybe flip the bits back again.
num = cast(a.type(), num ^ sign);
if (round_to_zero) {
// Add one if the numerator was negative
num -= sign;
} else {
// Maybe flip the bits back again.
num = cast(a.type(), num ^ sign);
}

return num;
} else if (const_uint_divisor &&
Expand Down Expand Up @@ -352,6 +383,9 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
}

return val;
} else if (round_to_zero) {
// Return the input division unchanged.
return Call::make(a.type(), Call::div_round_to_zero, {a, b}, Call::PureIntrinsic);
} else {
return lower_euclidean_div(a, b);
}
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ std::pair<Expr, Expr> long_div_mod_round_to_zero(const Expr &a, const Expr &b,
* Can introduce mulhi_shr and sorted_avg intrinsics as well as those from the
* lower_euclidean_ operation -- div_round_to_zero or mod_round_to_zero. */
///@{
Expr lower_int_uint_div(const Expr &a, const Expr &b);
Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero = false);
Expr lower_int_uint_mod(const Expr &a, const Expr &b);
///@}

Expand Down
6 changes: 6 additions & 0 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2667,6 +2667,12 @@ void CodeGen_LLVM::visit(const Call *op) {
Let::make(b_name, op->args[1],
Select::make(a_var < b_var, b_var - a_var, a_var - b_var))));
} else if (op->is_intrinsic(Call::div_round_to_zero)) {
// See if we can rewrite it to something faster (e.g. a shift)
Expr e = lower_int_uint_div(op->args[0], op->args[1], /** round to zero */ true);
if (!e.as<Call>()) {
codegen(e);
return;
}
internal_assert(op->args.size() == 2);
Value *a = codegen(op->args[0]);
Value *b = codegen(op->args[1]);
Expand Down
96 changes: 93 additions & 3 deletions src/FastIntegerDivide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ Buffer<uint8_t> integer_divide_table_s8() {
return im;
}

/** Returns the 256-entry buffer of 8-bit multipliers used for signed
 * round-to-zero division, indexed by denominator. The buffer is built
 * the first time this function is called and reused afterwards. */
Buffer<uint8_t> integer_divide_table_srz8() {
    static auto table = []() {
        Buffer<uint8_t> multipliers(256);
        for (uint32_t denom = 0; denom < 256; denom++) {
            // Only the multiplier column is stored in the buffer; the shift
            // column must therefore match what shift_for_denominator
            // computes. Entries 0 and 1 are exempt from this check.
            if (denom > 1) {
                internal_assert(table_runtime_srz8[denom][3] == shift_for_denominator(denom));
            }
            multipliers(denom) = table_runtime_srz8[denom][2];
        }
        return multipliers;
    }();
    return table;
}

Buffer<uint16_t> integer_divide_table_u16() {
static auto im = []() {
Buffer<uint16_t> im(256);
Expand Down Expand Up @@ -76,6 +90,20 @@ Buffer<uint16_t> integer_divide_table_s16() {
return im;
}

/** Returns the 256-entry buffer of 16-bit multipliers used for signed
 * round-to-zero division, indexed by denominator. The buffer is built
 * the first time this function is called and reused afterwards. */
Buffer<uint16_t> integer_divide_table_srz16() {
    static auto table = []() {
        Buffer<uint16_t> multipliers(256);
        for (uint32_t denom = 0; denom < 256; denom++) {
            // Only the multiplier column is stored in the buffer; the shift
            // column must therefore match what shift_for_denominator
            // computes. Entries 0 and 1 are exempt from this check.
            if (denom > 1) {
                internal_assert(table_runtime_srz16[denom][3] == shift_for_denominator(denom));
            }
            multipliers(denom) = table_runtime_srz16[denom][2];
        }
        return multipliers;
    }();
    return table;
}

Buffer<uint32_t> integer_divide_table_u32() {
static auto im = []() {
Buffer<uint32_t> im(256);
Expand Down Expand Up @@ -104,9 +132,21 @@ Buffer<uint32_t> integer_divide_table_s32() {
return im;
}

} // namespace
/** Returns the 256-entry buffer of 32-bit multipliers used for signed
 * round-to-zero division, indexed by denominator. The buffer is built
 * the first time this function is called and reused afterwards. */
Buffer<uint32_t> integer_divide_table_srz32() {
    static auto table = []() {
        Buffer<uint32_t> multipliers(256);
        for (uint32_t denom = 0; denom < 256; denom++) {
            // Only the multiplier column is stored in the buffer; the shift
            // column must therefore match what shift_for_denominator
            // computes. Entries 0 and 1 are exempt from this check.
            if (denom > 1) {
                internal_assert(table_runtime_srz32[denom][3] == shift_for_denominator(denom));
            }
            multipliers(denom) = table_runtime_srz32[denom][2];
        }
        return multipliers;
    }();
    return table;
}

Expr fast_integer_divide(Expr numerator, Expr denominator) {
Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_zero) {
if (is_const(denominator)) {
// There's code elsewhere for this case.
return numerator / cast<uint8_t>(denominator);
Expand Down Expand Up @@ -160,7 +200,7 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {
// Do a final shift
result = result >> cast(result.type(), shift);

} else {
} else if (!round_to_zero) {

Expr mul, shift = shift_for_denominator(denominator);
switch (t.bits()) {
Expand Down Expand Up @@ -205,6 +245,46 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {

// Maybe flip the bits again
result = xsign ^ result;
} else {
// Signed round to zero
Expr mul, shift = shift_for_denominator(denominator);
switch (t.bits()) {
case 8: {
Buffer<uint8_t> table = integer_divide_table_srz8();
mul = table(denominator);
break;
}
case 16: {
Buffer<uint16_t> table = integer_divide_table_srz16();
mul = table(denominator);
break;
}
default: // 32
{
Buffer<uint32_t> table = integer_divide_table_srz32();
mul = table(denominator);
break;
}
}

// Extract sign bit
// Expr xsign = (t.bits() < 32) ? (numerator / (1 << (t.bits()-1))) : (numerator >> (t.bits()-1));
Expr xsign = select(numerator > 0, cast(t, 0), cast(t, -1));

// Multiply-keep-high-half
result = (cast(wide, mul) * numerator);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should use widening_mul intrinsics, because uses of this are after find_intrinsics. Maybe this whole sequence should be mul_shift_right.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this code is only called directly by users, so it's before find_intrinsics. The compiler doesn't ever call this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add this as a comment for future readers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think we should change it to intrinsics anyway. But since the code is just moved and pre-existing, maybe that should be a separate PR.

if (t.bits() < 32) {
result = result / (1 << t.bits());
} else {
result = result >> Internal::make_const(result.type(), t.bits());
}
result = cast(t, result);

// Do the final shift
result = result >> cast(result.type(), shift);

// Add one if the numerator was negative
result -= xsign;
}

// The tables don't work for denominator == 1
Expand All @@ -215,6 +295,16 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {
return result;
}

} // namespace

// Public entry point for the round-to-zero variant: forwards to the shared
// implementation with the round_to_zero flag set.
Expr fast_integer_divide_round_to_zero(Expr numerator, Expr denominator) {
    const bool round_to_zero = true;
    return fast_integer_divide_impl(numerator, denominator, round_to_zero);
}

// Public entry point for the default variant: forwards to the shared
// implementation with the round_to_zero flag cleared.
Expr fast_integer_divide(Expr numerator, Expr denominator) {
    const bool round_to_zero = false;
    return fast_integer_divide_impl(numerator, denominator, round_to_zero);
}

Expr fast_integer_modulo(Expr numerator, Expr denominator) {
Expr ratio = fast_integer_divide(numerator, denominator);
return std::move(numerator) - ratio * std::move(denominator);
Expand Down
4 changes: 4 additions & 0 deletions src/FastIntegerDivide.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ namespace Halide {
*/
Expr fast_integer_divide(Expr numerator, Expr denominator);

/** A variant of the above which rounds towards zero instead of rounding towards
* negative infinity. */
Expr fast_integer_divide_round_to_zero(Expr numerator, Expr denominator);

/** Use the fast integer division tables to implement a modulo
* operation via the Euclidean identity: a%b = a - (a/b)*b
*/
Expand Down
Loading