Skip to content

Commit

Permalink
Merge branch 'master' into factor_parallel_codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Dec 8, 2021
2 parents 2923246 + 7199e7d commit 5764d12
Show file tree
Hide file tree
Showing 11 changed files with 2,249 additions and 507 deletions.
6 changes: 3 additions & 3 deletions apps/linear_algebra/benchmarks/macros.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#include "halide_benchmark.h"

#ifdef ENABLE_FTZ_DAZ
#if defined(__i386__) || defined(__x86_64__)
#if (defined(__i386__) || defined(__x86_64__)) && defined(__SSE__)
#include <pmmintrin.h>
#include <xmmintrin.h>
#endif // defined(__i386__) || defined(__x86_64__)
#endif // (defined(__i386__) || defined(__x86_64__)) && defined(__SSE__)
#endif

inline void set_math_flags() {
#ifdef ENABLE_FTZ_DAZ

#if defined(__i386__) || defined(__x86_64__)
#if (defined(__i386__) || defined(__x86_64__)) && defined(__SSE__)
// Flush denormals to zero (the FTZ flag).
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
// Interpret denormal inputs as zero (the DAZ flag).
Expand Down
60 changes: 47 additions & 13 deletions src/CodeGen_Internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ bool can_allocation_fit_on_stack(int64_t size) {
return (size <= (int64_t)Runtime::Internal::Constants::maximum_stack_allocation_bytes);
}

Expr lower_int_uint_div(const Expr &a, const Expr &b) {
Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) {
// Detect if it's a small int division
internal_assert(a.type() == b.type());
const int64_t *const_int_divisor = as_const_int(b);
Expand All @@ -166,7 +166,16 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
int shift_amount;
if (is_const_power_of_two_integer(b, &shift_amount) &&
(t.is_int() || t.is_uint())) {
return a >> make_const(UInt(a.type().bits()), shift_amount);
if (round_to_zero) {
Expr result = a;
// Normally a right-shift isn't right for division rounding to
// zero. It does the wrong thing for negative values. Add a fudge so
// that a right-shift becomes correct.
result += (result >> (t.bits() - 1)) & (b - 1);
return result >> shift_amount;
} else {
return a >> make_const(UInt(a.type().bits()), shift_amount);
}
} else if (const_int_divisor &&
t.is_int() &&
(t.bits() == 8 || t.bits() == 16 || t.bits() == 32) &&
Expand All @@ -176,33 +185,55 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
int64_t multiplier;
int shift;
if (t.bits() == 32) {
multiplier = IntegerDivision::table_s32[*const_int_divisor][2];
shift = IntegerDivision::table_s32[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz32[*const_int_divisor][2];
shift = IntegerDivision::table_srz32[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s32[*const_int_divisor][2];
shift = IntegerDivision::table_s32[*const_int_divisor][3];
}
} else if (t.bits() == 16) {
multiplier = IntegerDivision::table_s16[*const_int_divisor][2];
shift = IntegerDivision::table_s16[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz16[*const_int_divisor][2];
shift = IntegerDivision::table_srz16[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s16[*const_int_divisor][2];
shift = IntegerDivision::table_s16[*const_int_divisor][3];
}
} else {
// 8 bit
multiplier = IntegerDivision::table_s8[*const_int_divisor][2];
shift = IntegerDivision::table_s8[*const_int_divisor][3];
if (round_to_zero) {
multiplier = IntegerDivision::table_srz8[*const_int_divisor][2];
shift = IntegerDivision::table_srz8[*const_int_divisor][3];
} else {
multiplier = IntegerDivision::table_s8[*const_int_divisor][2];
shift = IntegerDivision::table_s8[*const_int_divisor][3];
}
}
Expr num = a;

// Make an all-ones mask if the numerator is negative
Type num_as_uint_t = num.type().with_code(Type::UInt);
Expr sign = cast(num_as_uint_t, num >> make_const(UInt(t.bits()), t.bits() - 1));

// Flip the numerator bits if the mask is high.
num = cast(num_as_uint_t, num);
num = num ^ sign;
if (!round_to_zero) {
// Flip the numerator bits if the mask is high.
num = cast(num_as_uint_t, num);
num = num ^ sign;
}

// Multiply and keep the high half of the
// result, and then apply the shift.
Expr mult = make_const(num.type(), multiplier);
num = mul_shift_right(num, mult, shift + num.type().bits());

// Maybe flip the bits back again.
num = cast(a.type(), num ^ sign);
if (round_to_zero) {
// Add one if the numerator was negative
num -= sign;
} else {
// Maybe flip the bits back again.
num = cast(a.type(), num ^ sign);
}

return num;
} else if (const_uint_divisor &&
Expand Down Expand Up @@ -257,6 +288,9 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b) {
}

return val;
} else if (round_to_zero) {
// Return the input division unchanged.
return Call::make(a.type(), Call::div_round_to_zero, {a, b}, Call::PureIntrinsic);
} else {
return lower_euclidean_div(a, b);
}
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::pair<Expr, Expr> long_div_mod_round_to_zero(const Expr &a, const Expr &b,
* Can introduce mulhi_shr and sorted_avg intrinsics as well as those from the
* lower_euclidean_ operation -- div_round_to_zero or mod_round_to_zero. */
///@{
Expr lower_int_uint_div(const Expr &a, const Expr &b);
Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero = false);
Expr lower_int_uint_mod(const Expr &a, const Expr &b);
///@}

Expand Down
6 changes: 6 additions & 0 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2671,6 +2671,12 @@ void CodeGen_LLVM::visit(const Call *op) {
Let::make(b_name, op->args[1],
Select::make(a_var < b_var, b_var - a_var, a_var - b_var))));
} else if (op->is_intrinsic(Call::div_round_to_zero)) {
// See if we can rewrite it to something faster (e.g. a shift)
Expr e = lower_int_uint_div(op->args[0], op->args[1], /** round to zero */ true);
if (!e.as<Call>()) {
codegen(e);
return;
}
internal_assert(op->args.size() == 2);
Value *a = codegen(op->args[0]);
Value *b = codegen(op->args[1]);
Expand Down
100 changes: 95 additions & 5 deletions src/FastIntegerDivide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ Buffer<uint8_t> integer_divide_table_s8() {
return im;
}

// Returns the 256-entry table of 8-bit multipliers used for signed
// round-to-zero division. Entry i is table_runtime_srz8[i][2]. The buffer is
// filled the first time this function is called and reused afterwards.
Buffer<uint8_t> integer_divide_table_srz8() {
    static auto cached = []() {
        Buffer<uint8_t> multipliers(256);
        for (uint32_t i = 0; i < 256; i++) {
            multipliers(i) = table_runtime_srz8[i][2];
            // Indices 0 and 1 are exempt from the shift consistency check.
            if (i < 2) {
                continue;
            }
            // The stored shift must agree with the shift computed for this index.
            internal_assert(table_runtime_srz8[i][3] == shift_for_denominator(i));
        }
        return multipliers;
    }();
    return cached;
}

Buffer<uint16_t> integer_divide_table_u16() {
static auto im = []() {
Buffer<uint16_t> im(256);
Expand Down Expand Up @@ -76,6 +90,20 @@ Buffer<uint16_t> integer_divide_table_s16() {
return im;
}

// Returns the 256-entry table of 16-bit multipliers used for signed
// round-to-zero division. Entry i is table_runtime_srz16[i][2]. The buffer is
// filled the first time this function is called and reused afterwards.
Buffer<uint16_t> integer_divide_table_srz16() {
    static auto cached = []() {
        Buffer<uint16_t> multipliers(256);
        for (uint32_t i = 0; i < 256; i++) {
            multipliers(i) = table_runtime_srz16[i][2];
            // Indices 0 and 1 are exempt from the shift consistency check.
            if (i < 2) {
                continue;
            }
            // The stored shift must agree with the shift computed for this index.
            internal_assert(table_runtime_srz16[i][3] == shift_for_denominator(i));
        }
        return multipliers;
    }();
    return cached;
}

Buffer<uint32_t> integer_divide_table_u32() {
static auto im = []() {
Buffer<uint32_t> im(256);
Expand Down Expand Up @@ -104,9 +132,21 @@ Buffer<uint32_t> integer_divide_table_s32() {
return im;
}

} // namespace
// Returns the 256-entry table of 32-bit multipliers used for signed
// round-to-zero division. Entry i is table_runtime_srz32[i][2]. The buffer is
// filled the first time this function is called and reused afterwards.
Buffer<uint32_t> integer_divide_table_srz32() {
    static auto cached = []() {
        Buffer<uint32_t> multipliers(256);
        for (uint32_t i = 0; i < 256; i++) {
            multipliers(i) = table_runtime_srz32[i][2];
            // Indices 0 and 1 are exempt from the shift consistency check.
            if (i < 2) {
                continue;
            }
            // The stored shift must agree with the shift computed for this index.
            internal_assert(table_runtime_srz32[i][3] == shift_for_denominator(i));
        }
        return multipliers;
    }();
    return cached;
}

Expr fast_integer_divide(Expr numerator, Expr denominator) {
Expr fast_integer_divide_impl(Expr numerator, Expr denominator, bool round_to_zero) {
if (is_const(denominator)) {
// There's code elsewhere for this case.
return numerator / cast<uint8_t>(denominator);
Expand Down Expand Up @@ -160,7 +200,7 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {
// Do a final shift
result = result >> cast(result.type(), shift);

} else {
} else if (!round_to_zero) {

Expr mul, shift = shift_for_denominator(denominator);
switch (t.bits()) {
Expand Down Expand Up @@ -205,6 +245,46 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {

// Maybe flip the bits again
result = xsign ^ result;
} else {
// Signed round to zero
Expr mul, shift = shift_for_denominator(denominator);
switch (t.bits()) {
case 8: {
Buffer<uint8_t> table = integer_divide_table_srz8();
mul = table(denominator);
break;
}
case 16: {
Buffer<uint16_t> table = integer_divide_table_srz16();
mul = table(denominator);
break;
}
default: // 32
{
Buffer<uint32_t> table = integer_divide_table_srz32();
mul = table(denominator);
break;
}
}

// Extract sign bit
// Expr xsign = (t.bits() < 32) ? (numerator / (1 << (t.bits()-1))) : (numerator >> (t.bits()-1));
Expr xsign = select(numerator > 0, cast(t, 0), cast(t, -1));

// Multiply-keep-high-half
result = (cast(wide, mul) * numerator);
if (t.bits() < 32) {
result = result / (1 << t.bits());
} else {
result = result >> Internal::make_const(result.type(), t.bits());
}
result = cast(t, result);

// Do the final shift
result = result >> cast(result.type(), shift);

// Add one if the numerator was negative
result -= xsign;
}

// The tables don't work for denominator == 1
Expand All @@ -215,9 +295,19 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {
return result;
}

Expr fast_integer_modulo(Expr numerator, Expr denominator) {
} // namespace

// Public entry point for table-based division that rounds toward zero.
// Forwards to the shared implementation with the round-to-zero flag set.
Expr fast_integer_divide_round_to_zero(const Expr &numerator, const Expr &denominator) {
    const bool round_to_zero = true;
    return fast_integer_divide_impl(numerator, denominator, round_to_zero);
}

// Public entry point for table-based division with the default rounding.
// Forwards to the shared implementation with the round-to-zero flag cleared.
Expr fast_integer_divide(const Expr &numerator, const Expr &denominator) {
    const bool round_to_zero = false;
    return fast_integer_divide_impl(numerator, denominator, round_to_zero);
}

// Computes numerator % denominator from the fast division tables using the
// identity a % b = a - (a / b) * b, with the quotient taken from
// fast_integer_divide.
Expr fast_integer_modulo(const Expr &numerator, const Expr &denominator) {
    Expr ratio = fast_integer_divide(numerator, denominator);
    // Both operands are const references, so there is nothing to move from;
    // use them directly in the single return.
    return numerator - ratio * denominator;
}

} // namespace Halide
8 changes: 6 additions & 2 deletions src/FastIntegerDivide.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ namespace Halide {
* 256. I.e. it interprets the uint8 divisor as a number from 1 to 256
* inclusive.
*/
Expr fast_integer_divide(Expr numerator, Expr denominator);
Expr fast_integer_divide(const Expr &numerator, const Expr &denominator);

/** A variant of the above which rounds towards zero instead of rounding towards
* negative infinity. */
Expr fast_integer_divide_round_to_zero(const Expr &numerator, const Expr &denominator);

/** Use the fast integer division tables to implement a modulo
* operation via the Euclidean identity: a%b = a - (a/b)*b
*/
Expr fast_integer_modulo(Expr numerator, Expr denominator);
Expr fast_integer_modulo(const Expr &numerator, const Expr &denominator);

} // namespace Halide

Expand Down
Loading

0 comments on commit 5764d12

Please sign in to comment.