Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let lerp lowering incorporate a final cast. #6480

Merged
merged 4 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2314,7 +2314,7 @@ void CodeGen_C::visit(const Call *op) {
}
} else if (op->is_intrinsic(Call::lerp)) {
internal_assert(op->args.size() == 3);
Expr e = lower_lerp(op->args[0], op->args[1], op->args[2], target);
Expr e = lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target);
rhs << print_expr(e);
} else if (op->is_intrinsic(Call::absd)) {
internal_assert(op->args.size() == 2);
Expand Down
18 changes: 16 additions & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,20 @@ void CodeGen_LLVM::visit(const Cast *op) {
return;
}

if (const Call *c = Call::as_intrinsic(op->value, {Call::lerp})) {
// We want to codegen a cast of a lerp as a single thing, because it can
// be done more intelligently than a lerp followed by a cast.
Type t = upgrade_type_for_arithmetic(c->type);
Type wt = upgrade_type_for_arithmetic(c->args[2].type());
Expr e = lower_lerp(op->type,
cast(t, c->args[0]),
cast(t, c->args[1]),
cast(wt, c->args[2]),
target);
codegen(e);
return;
}

value = codegen(op->value);
llvm::Type *llvm_dst = llvm_type_of(dst);

Expand Down Expand Up @@ -2698,11 +2712,11 @@ void CodeGen_LLVM::visit(const Call *op) {
// TODO: This might be surprising behavior?
Type t = upgrade_type_for_arithmetic(op->type);
Type wt = upgrade_type_for_arithmetic(op->args[2].type());
Expr e = lower_lerp(cast(t, op->args[0]),
Expr e = lower_lerp(op->type,
cast(t, op->args[0]),
cast(t, op->args[1]),
cast(wt, op->args[2]),
target);
e = cast(op->type, e);
codegen(e);
} else if (op->is_intrinsic(Call::popcount)) {
internal_assert(op->args.size() == 1);
Expand Down
2 changes: 1 addition & 1 deletion src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ class OptimizePatterns : public IRMutator {
// We need to lower lerps now to optimize the arithmetic
// that they generate.
internal_assert(op->args.size() == 3);
return mutate(lower_lerp(op->args[0], op->args[1], op->args[2], target));
return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target));
} else if ((op->is_intrinsic(Call::div_round_to_zero) ||
op->is_intrinsic(Call::mod_round_to_zero)) &&
!op->type.is_float() && op->type.is_vector()) {
Expand Down
14 changes: 12 additions & 2 deletions src/Lerp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
namespace Halide {
namespace Internal {

Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target) {
Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target) {

Expr result;

Expand Down Expand Up @@ -153,7 +153,6 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t
} else {
result = rounding_shift_right(rounding_shift_right(prod_sum, bits) + prod_sum, bits);
}
result = Cast::make(UInt(bits, computation_type.lanes()), result);
break;
}
case 64:
Expand All @@ -165,13 +164,24 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &t
default:
break;
}

if (weight.type().is_float()) {
// Insert an explicit cast to the computation type, even if
// we're going to widen, because out-of-range floats can produce
// out-of-range outputs.
result = Cast::make(computation_type, result);
}
}

if (!is_const_zero(bias_value)) {
result = Cast::make(result_type, result + bias_value);
}
}

if (result.type() != final_type) {
result = Cast::make(final_type, result);
}

return simplify(common_subexpression_elimination(result));
}

Expand Down
8 changes: 5 additions & 3 deletions src/Lerp.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ struct Target;

namespace Internal {

/** Build Halide IR that computes a lerp. Use by codegen targets that
* don't have a native lerp. */
Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight, const Target &target);
/** Build Halide IR that computes a lerp. Used by codegen targets that don't
 * have a native lerp. The lerp is done in the type of the zero value. The
 * final_type is a cast that should occur after the lerp: if the lowered
 * result does not already have that type, it is cast to final_type before
 * being returned. It's included because in some cases you can incorporate a
 * final cast into the lerp math, which can be done more intelligently than a
 * lerp followed by a separate cast. Callers that want no extra cast pass the
 * type of the lerp itself as final_type. */
Expr lower_lerp(Type final_type, Expr zero_val, Expr one_val, const Expr &weight, const Target &target);

} // namespace Internal
} // namespace Halide
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ tests(GROUPS correctness
vectorized_initialization.cpp
vectorized_load_from_vectorized_allocation.cpp
vectorized_reduction_bug.cpp
widening_lerp.cpp
widening_reduction.cpp
)

Expand Down
64 changes: 64 additions & 0 deletions test/correctness/widening_lerp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "Halide.h"

using namespace Halide;

std::mt19937 rng(0);

int main(int argc, char **argv) {

int fuzz_seed = argc > 1 ? atoi(argv[1]) : time(nullptr);
rng.seed(fuzz_seed);
printf("Lerp test seed: %d\n", fuzz_seed);

// Lerp lowering incorporates a cast. This test checks that a widening lerp
// is equal to the widened version of the lerp.
for (Type t1 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) {
if (rng() & 1) continue;
for (Type t2 : {UInt(8), UInt(16), UInt(32), Float(32)}) {
if (rng() & 1) continue;
for (Type t3 : {UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32), Float(32)}) {
if (rng() & 1) continue;
Func f;
Var x;
f(x) = cast(t1, random_uint((int)rng()));

Expr weight = cast(t2, f(x + 16));
if (t2.is_float()) {
weight /= 256.f;
weight = clamp(weight, 0.f, 1.f);
}

Expr zero_val = f(x);
Expr one_val = f(x + 8);
Expr lerped = lerp(zero_val, one_val, weight);

Func cast_and_lerp, lerp_alone, cast_of_lerp;
cast_and_lerp(x) = cast(t3, lerped);
lerp_alone(x) = lerped;
cast_of_lerp(x) = cast(t3, lerp_alone(x));

RDom r(0, 32 * 1024);
Func check;
check() = maximum(abs(cast<double>(cast_and_lerp(r)) -
cast<double>(cast_of_lerp(r))));

f.compute_root().vectorize(x, 8, TailStrategy::RoundUp);
lerp_alone.compute_root().vectorize(x, 8, TailStrategy::RoundUp);
cast_and_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp);
cast_of_lerp.compute_root().vectorize(x, 8, TailStrategy::RoundUp);

double err = evaluate<double>(check());

if (err > 1e-5) {
printf("Difference of lerp + cast and lerp alone is %f,"
" which exceeds threshold for seed %d\n",
err, fuzz_seed);
return -1;
}
}
}
}

printf("Success!\n");
return 0;
}