diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index b3d2282ffb56..573d4381a52b 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -132,7 +132,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { // If we know the sign of this shift, change it to an unsigned shift. if (b_info.min_defined && b_info.min >= 0) { b = mutate(cast(b.type().with_code(halide_type_uint), b), nullptr); - } else if (b_info.max_defined && b_info.max <= 0) { + } else if (b.type().is_int() && b_info.max_defined && b_info.max <= 0) { result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); b = mutate(cast(b.type().with_code(halide_type_uint), -b), nullptr); } @@ -165,12 +165,14 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } } - // Rewrite shifts with negated RHSes as shifts of the other direction. - if (const Sub *sub = b.as()) { - if (is_const_zero(sub->a)) { - result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); - b = sub->b; - return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + // Rewrite shifts with signed negated RHSes as shifts of the other direction. + if (b.type().is_int()) { + if (const Sub *sub = b.as()) { + if (is_const_zero(sub->a)) { + result_op = Call::get_intrinsic_name(op->is_intrinsic(Call::shift_right) ? Call::shift_left : Call::shift_right); + b = sub->b; + return mutate(Call::make(op->type, result_op, {a, b}, Call::PureIntrinsic), bounds); + } } } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index a7fab3aa76d2..0a4c8835577d 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -285,6 +285,7 @@ tests(GROUPS correctness set_custom_trace.cpp shadowed_bound.cpp shared_self_references.cpp + shift_by_unsigned_negated.cpp shifted_image.cpp side_effects.cpp simd_op_check.cpp diff --git a/test/correctness/shift_by_unsigned_negated.cpp b/test/correctness/shift_by_unsigned_negated.cpp new file mode 100644 index 000000000000..a8ca8dca8e06 --- /dev/null +++ b/test/correctness/shift_by_unsigned_negated.cpp @@ -0,0 +1,46 @@ +#include "Halide.h" + +using namespace Halide; + +template +bool test(Func f, T f_expected, int width) { + Buffer actual = f.realize({width}); + for (int i = 0; i < actual.width(); i++) { + if (actual(i) != f_expected(i)) { + printf("r(%d) = %d, f_expected(%d) = %d\n", + i, actual(i), i, f_expected(i)); + return false; + } + } + return true; +} + +int main(int argc, char **argv) { + Buffer step(31); + for (int i = 0; i < step.width(); i++) { + step(i) = -i; + } + + bool success = true; + Var x; + + { + Func f; + f(x) = Expr(-1U) << -step(x); + auto f_expected = [&](int x) { + return -1U << x; + }; + success &= test(f, f_expected, step.width()); + } + { + Func f; + f(x) = Expr(-1U) >> -step(x); + auto f_expected = [&](int x) { + return -1U >> x; + }; + success &= test(f, f_expected, step.width()); + } + + if (success) printf("Success!\n"); + return success ? 0 : -1; +}