Round-to-zero signed integer division: lowering and fast-path changes, plus a
correctness test.

* src/CodeGen_Internal.cpp, lower_int_uint_div: when rounding to zero the
  numerator is replaced by abs(num) before the multiply/shift; otherwise its
  bits are XORed with the sign mask as before. The result is then XORed with
  the sign mask and cast back to a.type() in both modes, and the
  round-to-zero mode additionally subtracts the sign mask. An internal_assert
  checks that num.type() can represent the multiplier.
* src/FastIntegerDivide.cpp, fast_integer_divide_impl: the denominator is
  passed through lossless_cast(UInt(8), ...) and must come back defined; a
  constant denominator whose type the numerator type can represent is
  forwarded to div_round_to_zero or to operator/ depending on round_to_zero;
  the sign mask is now selected with numerator >= 0 rather than numerator > 0.
* test/correctness: registers and adds div_round_to_zero.cpp, which compares
  div_round_to_zero and fast_integer_divide_round_to_zero against C++ integer
  division.

diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp
index b5b043d5c981..ab31e2ebfb99 100644
--- a/src/CodeGen_Internal.cpp
+++ b/src/CodeGen_Internal.cpp
@@ -174,7 +174,12 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) {
         Type num_as_uint_t = num.type().with_code(Type::UInt);
         Expr sign = cast(num_as_uint_t, num >> make_const(UInt(t.bits()), t.bits() - 1));
 
-        if (!round_to_zero) {
+        // If the numerator is negative, we want to either flip the bits (when
+        // rounding to negative infinity), or negate the numerator (when
+        // rounding to zero).
+        if (round_to_zero) {
+            num = abs(num);
+        } else {
             // Flip the numerator bits if the mask is high.
             num = cast(num_as_uint_t, num);
             num = num ^ sign;
@@ -182,15 +187,14 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) {
 
         // Multiply and keep the high half of the
         // result, and then apply the shift.
+        internal_assert(num.type().can_represent(multiplier));
         Expr mult = make_const(num.type(), multiplier);
         num = mul_shift_right(num, mult, shift + num.type().bits());
 
+        // Maybe flip the bits back or negate again.
+        num = cast(a.type(), num ^ sign);
         if (round_to_zero) {
-            // Add one if the numerator was negative
             num -= sign;
-        } else {
-            // Maybe flip the bits back again.
-            num = cast(a.type(), num ^ sign);
         }
 
         return num;
diff --git a/src/FastIntegerDivide.cpp b/src/FastIntegerDivide.cpp
index 72da3d367546..cab28aedb5d8 100644
--- a/src/FastIntegerDivide.cpp
+++ b/src/FastIntegerDivide.cpp
@@ -147,12 +147,19 @@ Buffer<uint8_t> integer_divide_table_srz32() {
 }
 
 Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_zero) {
-    if (is_const(denominator)) {
-        // There's code elsewhere for this case.
-        return numerator / cast<uint8_t>(denominator);
-    }
-    user_assert(denominator.type() == UInt(8))
+    denominator = lossless_cast(UInt(8), denominator);
+    user_assert(denominator.defined())
         << "Fast integer divide requires a UInt(8) denominator\n";
+
+    if (is_const(denominator) && numerator.type().can_represent(denominator.type())) {
+        if (round_to_zero) {
+            return div_round_to_zero(numerator, cast(numerator.type(), denominator));
+        } else {
+            // There's code elsewhere for this case.
+            return numerator / cast(numerator.type(), denominator);
+        }
+    }
+
     Type t = numerator.type();
     user_assert(t.is_uint() || t.is_int())
         << "Fast integer divide requires an integer numerator\n";
@@ -269,7 +276,7 @@ Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_ze
 
         // Extract sign bit
         // Expr xsign = (t.bits() < 32) ? (numerator / (1 << (t.bits()-1))) : (numerator >> (t.bits()-1));
-        Expr xsign = select(numerator > 0, cast(t, 0), cast(t, -1));
+        Expr xsign = select(numerator >= 0, cast(t, 0), cast(t, -1));
 
         // Multiply-keep-high-half
         result = (cast(wide, mul) * numerator);
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index 10a108553231..b8256ca09204 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -79,6 +79,7 @@ tests(GROUPS correctness
       device_slice.cpp
       dilate3x3.cpp
       div_by_zero.cpp
+      div_round_to_zero.cpp
       dynamic_allocation_in_gpu_kernel.cpp
       dynamic_reduction_bounds.cpp
       early_out.cpp
diff --git a/test/correctness/div_round_to_zero.cpp b/test/correctness/div_round_to_zero.cpp
new file mode 100644
index 000000000000..7dee9a8d1123
--- /dev/null
+++ b/test/correctness/div_round_to_zero.cpp
@@ -0,0 +1,99 @@
+#include "Halide.h"
+
+using namespace Halide;
+
+template<typename T>
+void test() {
+
+    {
+        // Test div_round_to_zero
+        Func f;
+        Var x, y;
+
+        Expr d = cast<T>(y - 128);
+        Expr n = cast<T>(x - 128);
+        d = select(d == 0 || (d == -1 && n == d.type().min()),
+                   cast<T>(1),
+                   d);
+        f(x, y) = div_round_to_zero(n, d);
+
+        f.vectorize(x, 8);
+
+        Buffer<T> result = f.realize({256, 256});
+
+        for (int d = -128; d < 128; d++) {
+            if (d == 0) {
+                continue;
+            }
+            for (int n = -128; n < 128; n++) {
+                if (d == -1 && n == std::numeric_limits<T>::min()) {
+                    continue;
+                }
+                int correct = d == 0 ? n : (T)(n / d);
+                int r = result(n + 128, d + 128);
+                if (r != correct) {
+                    printf("result(%d, %d) = %d instead of %d\n", n, d, r, correct);
+                    exit(-1);
+                }
+            }
+        }
+    }
+
+    {
+        // Test the fast version
+        Func f;
+        Var x, y;
+
+        f(x, y) = fast_integer_divide_round_to_zero(cast<T>(x - 128), cast<uint8_t>(y + 1));
+
+        f.vectorize(x, 8);
+
+        Buffer<T> result_fast = f.realize({256, 255});
+
+        for (int d = 1; d < 256; d++) {
+            for (int n = -128; n < 128; n++) {
+                int correct = (T)(n / d);
+                int r = result_fast(n + 128, d - 1);
+                if (r != correct) {
+                    printf("result_fast(%d, %d) = %d instead of %d\n", n, d, r, correct);
+                    exit(-1);
+                }
+            }
+        }
+    }
+
+    {
+        // Try some constant denominators
+        for (int d : {-128, -54, -3, -1, 1, 2, 25, 32, 127}) {
+            if (d == 0) {
+                continue;
+            }
+
+            Func f;
+            Var x;
+
+            f(x) = div_round_to_zero(cast<T>(x - 128), cast<T>(d));
+
+            f.vectorize(x, 8);
+
+            Buffer<T> result_const = f.realize({256});
+
+            for (int n = -128; n < 128; n++) {
+                int correct = (T)(n / d);
+                int r = result_const(n + 128);
+                if (r != correct) {
+                    printf("result_const(%d, %d) = %d instead of %d\n", n, d, r, correct);
+                    exit(-1);
+                }
+            }
+        }
+    }
+}
+
+int main(int argc, char **argv) {
+    test<int8_t>();
+    test<int16_t>();
+    test<int32_t>();
+    printf("Success!\n");
+    return 0;
+}