Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some bugs in div_round_to_zero #7008

Merged
merged 2 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/CodeGen_Internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,27 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) {
Type num_as_uint_t = num.type().with_code(Type::UInt);
Expr sign = cast(num_as_uint_t, num >> make_const(UInt(t.bits()), t.bits() - 1));

if (!round_to_zero) {
// If the numerator is negative, we want to either flip the bits (when
// rounding to negative infinity), or negate the numerator (when
// rounding to zero).
if (round_to_zero) {
num = abs(num);
} else {
// Flip the numerator bits if the mask is high.
num = cast(num_as_uint_t, num);
num = num ^ sign;
}

// Multiply and keep the high half of the
// result, and then apply the shift.
internal_assert(num.type().can_represent(multiplier));
Expr mult = make_const(num.type(), multiplier);
num = mul_shift_right(num, mult, shift + num.type().bits());

// Maybe flip the bits back or negate again.
num = cast(a.type(), num ^ sign);
if (round_to_zero) {
// Add one if the numerator was negative
num -= sign;
} else {
// Maybe flip the bits back again.
num = cast(a.type(), num ^ sign);
}

return num;
Expand Down
19 changes: 13 additions & 6 deletions src/FastIntegerDivide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,19 @@ Buffer<uint32_t> integer_divide_table_srz32() {
}

Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_zero) {
if (is_const(denominator)) {
// There's code elsewhere for this case.
return numerator / cast<uint8_t>(denominator);
}
user_assert(denominator.type() == UInt(8))
denominator = lossless_cast(UInt(8), denominator);
user_assert(denominator.defined())
<< "Fast integer divide requires a UInt(8) denominator\n";

if (is_const(denominator) && numerator.type().can_represent(denominator.type())) {
if (round_to_zero) {
return div_round_to_zero(numerator, cast(numerator.type(), denominator));
} else {
// There's code elsewhere for this case.
return numerator / cast(numerator.type(), denominator);
}
}

Type t = numerator.type();
user_assert(t.is_uint() || t.is_int())
<< "Fast integer divide requires an integer numerator\n";
Expand Down Expand Up @@ -269,7 +276,7 @@ Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_ze

// Extract sign bit
// Expr xsign = (t.bits() < 32) ? (numerator / (1 << (t.bits()-1))) : (numerator >> (t.bits()-1));
Expr xsign = select(numerator > 0, cast(t, 0), cast(t, -1));
Expr xsign = select(numerator >= 0, cast(t, 0), cast(t, -1));

// Multiply-keep-high-half
result = (cast(wide, mul) * numerator);
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ tests(GROUPS correctness
device_slice.cpp
dilate3x3.cpp
div_by_zero.cpp
div_round_to_zero.cpp
dynamic_allocation_in_gpu_kernel.cpp
dynamic_reduction_bounds.cpp
early_out.cpp
Expand Down
99 changes: 99 additions & 0 deletions test/correctness/div_round_to_zero.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include "Halide.h"

using namespace Halide;

template<typename T>
void test() {

{
// Test div_round_to_zero
Func f;
Var x, y;

Expr d = cast<T>(y - 128);
Expr n = cast<T>(x - 128);
d = select(d == 0 || (d == -1 && n == d.type().min()),
cast<T>(1),
d);
f(x, y) = div_round_to_zero(n, d);

f.vectorize(x, 8);

Buffer<T> result = f.realize({256, 256});

for (int d = -128; d < 128; d++) {
if (d == 0) {
continue;
}
for (int n = -128; n < 128; n++) {
if (d == -1 && n == std::numeric_limits<T>::min()) {
continue;
}
int correct = d == 0 ? n : (T)(n / d);
int r = result(n + 128, d + 128);
if (r != correct) {
printf("result(%d, %d) = %d instead of %d\n", n, d, r, correct);
exit(-1);
}
}
}
}

{
// Test the fast version
Func f;
Var x, y;

f(x, y) = fast_integer_divide_round_to_zero(cast<T>(x - 128), cast<uint8_t>(y + 1));

f.vectorize(x, 8);

Buffer<T> result_fast = f.realize({256, 255});

for (int d = 1; d < 256; d++) {
for (int n = -128; n < 128; n++) {
int correct = (T)(n / d);
int r = result_fast(n + 128, d - 1);
if (r != correct) {
printf("result_fast(%d, %d) = %d instead of %d\n", n, d, r, correct);
exit(-1);
}
}
}
}

{
// Try some constant denominators
for (int d : {-128, -54, -3, -1, 1, 2, 25, 32, 127}) {
if (d == 0) {
continue;
}

Func f;
Var x;

f(x) = div_round_to_zero(cast<T>(x - 128), cast<T>(d));

f.vectorize(x, 8);

Buffer<T> result_const = f.realize({256});

for (int n = -128; n < 128; n++) {
int correct = (T)(n / d);
int r = result_const(n + 128);
if (r != correct) {
printf("result_const(%d, %d) = %d instead of %d\n", n, d, r, correct);
exit(-1);
}
}
}
}
}

int main(int argc, char **argv) {
test<int8_t>();
test<int16_t>();
test<int32_t>();
printf("Success!\n");
return 0;
}