Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an issue where the Halide compiler hits an internal error for bool types in widening intrinsics. #8099

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 154 additions & 119 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace Halide::ConciseCasts;

namespace {

// This routine provides a guard on the return type of intrinsics such that only
// these types will ever be considered in the visiting that happens here.
bool find_intrinsics_for_type(const Type &t) {
// Currently, we only try to find and replace intrinsics for vector types that aren't bools.
return t.is_vector() && t.bits() >= 8;
Expand All @@ -28,17 +30,36 @@ Expr narrow(Expr a) {
return Cast::make(result_type, std::move(a));
}

// Check a type to make sure it can be narrowed. find_intrinsics_for_type
// attempts to prevent this code from narrowing in cases that do not work, but
// it is incomplete for two reasons:
//
// - Arguments can be narrowed and that guard is only on return type, which
// are different for widening operations.
//
// - find_intrinsics_for_type does not cull out float16, and it probably
// should not as while it's ok to skip matching bool things, float16 things
// are useful.
bool can_narrow(const Type &t) {
return (t.is_float() && t.bits() >= 32) ||
t.bits() >= 8;
}

Expr lossless_narrow(const Expr &x) {
return lossless_cast(x.type().narrow(), x);
return can_narrow(x.type()) ? lossless_cast(x.type().narrow(), x) : Expr();
}

// Remove a widening cast even if it changes the sign of the result.
Expr strip_widening_cast(const Expr &x) {
Expr narrow = lossless_narrow(x);
if (narrow.defined()) {
return narrow;
if (can_narrow(x.type())) {
Expr narrow = lossless_narrow(x);
if (narrow.defined()) {
return narrow;
}
return lossless_cast(x.type().narrow().with_code(halide_type_uint), x);
} else {
return Expr();
}
return lossless_cast(x.type().narrow().with_code(halide_type_uint), x);
}

Expr saturating_narrow(const Expr &a) {
Expand Down Expand Up @@ -217,16 +238,18 @@ class FindIntrinsics : public IRMutator {

// Try widening both from the same signedness as the result, and from uint.
for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);
if (can_narrow(op->type)) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);

if (narrow_a.defined() && narrow_b.defined()) {
Expr result = widening_add(narrow_a, narrow_b);
if (result.type() != op->type) {
result = Cast::make(op->type, result);
if (narrow_a.defined() && narrow_b.defined()) {
Expr result = widening_add(narrow_a, narrow_b);
if (result.type() != op->type) {
result = Cast::make(op->type, result);
}
return mutate(result);
}
return mutate(result);
}
}

Expand All @@ -235,41 +258,43 @@ class FindIntrinsics : public IRMutator {
// Yes, this duplicates code, but we want to check the op->type.code() first,
// and the opposite as well.
for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
Type narrow = op->type.narrow().with_code(code);
// Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);

// This case should have been handled by the above check for widening_add.
internal_assert(!(narrow_a.defined() && narrow_b.defined()))
<< "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";

if (narrow_a.defined()) {
Expr result;
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_add(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_add(b, narrow_a);
}
internal_assert(result.type() == op->type);
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_add(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_add(a, narrow_b);
if (can_narrow(op->type)) {
Type narrow = op->type.narrow().with_code(code);
// Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);

// This case should have been handled by the above check for widening_add.
internal_assert(!(narrow_a.defined() && narrow_b.defined()))
<< "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";

if (narrow_a.defined()) {
Expr result;
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_add(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_add(b, narrow_a);
}
internal_assert(result.type() == op->type);
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_add(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_add(a, narrow_b);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
}
}
Expand All @@ -294,22 +319,24 @@ class FindIntrinsics : public IRMutator {

// Try widening both from the same type as the result, and from uint.
for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);
if (can_narrow(op->type)) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);

if (narrow_a.defined() && narrow_b.defined()) {
Expr negative_narrow_b = lossless_negate(narrow_b);
Expr result;
if (negative_narrow_b.defined()) {
result = widening_add(narrow_a, negative_narrow_b);
} else {
result = widening_sub(narrow_a, narrow_b);
}
if (result.type() != op->type) {
result = Cast::make(op->type, result);
if (narrow_a.defined() && narrow_b.defined()) {
Expr negative_narrow_b = lossless_negate(narrow_b);
Expr result;
if (negative_narrow_b.defined()) {
result = widening_add(narrow_a, negative_narrow_b);
} else {
result = widening_sub(narrow_a, narrow_b);
}
if (result.type() != op->type) {
result = Cast::make(op->type, result);
}
return mutate(result);
}
return mutate(result);
}
}

Expand All @@ -324,22 +351,24 @@ class FindIntrinsics : public IRMutator {
// Yes, this duplicates code, but we want to check the op->type.code() first,
// and the opposite as well.
for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_b = lossless_cast(narrow, b);

if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_sub(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_sub(a, narrow_b);
if (can_narrow(op->type)) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_b = lossless_cast(narrow, b);

if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_sub(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_sub(a, narrow_b);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
}
}
Expand Down Expand Up @@ -401,40 +430,42 @@ class FindIntrinsics : public IRMutator {
// Yes, this duplicates code, but we want to check the op->type.code() first,
// and the opposite as well.
for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);

// This case should have been handled by the above check for widening_mul.
internal_assert(!(narrow_a.defined() && narrow_b.defined()))
<< "find_intrinsics failed to find a widening_mul: " << a << " + " << b << "\n";

if (narrow_a.defined()) {
Expr result;
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_mul(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_mul(b, narrow_a);
}
internal_assert(result.type() == op->type);
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_mul(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_mul(a, narrow_b);
if (can_narrow(op->type)) {
Type narrow = op->type.narrow().with_code(code);
Expr narrow_a = lossless_cast(narrow, a);
Expr narrow_b = lossless_cast(narrow, b);

// This case should have been handled by the above check for widening_mul.
internal_assert(!(narrow_a.defined() && narrow_b.defined()))
<< "find_intrinsics failed to find a widening_mul: " << a << " + " << b << "\n";

if (narrow_a.defined()) {
Expr result;
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_mul(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_mul(b, narrow_a);
}
internal_assert(result.type() == op->type);
return mutate(result);
} else if (narrow_b.defined()) {
Expr result;
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_mul(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = cast(op->type, result);
} else {
result = widen_right_mul(a, narrow_b);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
internal_assert(result.type() == op->type);
return mutate(result);
}
}
}
Expand Down Expand Up @@ -853,21 +884,25 @@ class FindIntrinsics : public IRMutator {
} else if (op->is_intrinsic(Call::widening_add) && (op->type.bits() >= 16)) {
internal_assert(op->args.size() == 2);
for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
Type narrow_t = op->type.narrow().narrow().with_code(t);
Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
if (narrow_a.defined() && narrow_b.defined()) {
return mutate(Cast::make(op->type, widening_add(narrow_a, narrow_b)));
if (can_narrow(op->type)) {
Type narrow_t = op->type.narrow().narrow().with_code(t);
Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
if (narrow_a.defined() && narrow_b.defined()) {
return mutate(Cast::make(op->type, widening_add(narrow_a, narrow_b)));
}
}
}
} else if (op->is_intrinsic(Call::widening_sub) && (op->type.bits() >= 16)) {
internal_assert(op->args.size() == 2);
for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
Type narrow_t = op->type.narrow().narrow().with_code(t);
Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
if (narrow_a.defined() && narrow_b.defined()) {
return mutate(Cast::make(op->type, widening_sub(narrow_a, narrow_b)));
if (can_narrow(op->type)) {
Type narrow_t = op->type.narrow().narrow().with_code(t);
Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
if (narrow_a.defined() && narrow_b.defined()) {
return mutate(Cast::make(op->type, widening_sub(narrow_a, narrow_b)));
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions test/correctness/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ Expr make_leaf(Type t, const char *name) {
}

int main(int argc, char **argv) {
Expr i1x = make_leaf(Int(1, 4), "i1x");
Expr i1y = make_leaf(Int(1, 4), "i1y");
Expr i8x = make_leaf(Int(8, 4), "i8x");
Expr i8y = make_leaf(Int(8, 4), "i8y");
Expr i8z = make_leaf(Int(8, 4), "i8w");
Expand Down Expand Up @@ -150,15 +152,18 @@ int main(int argc, char **argv) {
// check(u32(u8x) * 256, u32(widening_shift_left(u8x, u8(8))));

// Check widening arithmetic
check(i8(i1x) + i1y, widening_add(i1x, i1y));
check(i16(i8x) + i8y, widening_add(i8x, i8y));
check(u16(u8x) + u8y, widening_add(u8x, u8y));
check(i16(u8x) + u8y, i16(widening_add(u8x, u8y)));
check(f32(f16x) + f32(f16y), widening_add(f16x, f16y));

check(i8(i1x) - i1y, widening_sub(i1x, i1y));
check(i16(i8x) - i8y, widening_sub(i8x, i8y));
check(i16(u8x) - u8y, widening_sub(u8x, u8y));
check(f32(f16x) - f32(f16y), widening_sub(f16x, f16y));

check(i8(i1x) * i1y, widening_mul(i1x, i1y));
check(i16(i8x) * i8y, widening_mul(i8x, i8y));
check(u16(u8x) * u8y, widening_mul(u8x, u8y));
check(i32(i8x) * i8y, i32(widening_mul(i8x, i8y)));
Expand Down
Loading