diff --git a/src/Lerp.cpp b/src/Lerp.cpp index 9cba96e65844..e145780853a9 100644 --- a/src/Lerp.cpp +++ b/src/Lerp.cpp @@ -134,17 +134,11 @@ Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight) { case 8: case 16: case 32: { - Expr zero_expand = Cast::make(UInt(2 * bits, computation_type.lanes()), - zero_val); - Expr one_expand = Cast::make(UInt(2 * bits, one_val.type().lanes()), - one_val); - - Expr rounding = Cast::make(UInt(2 * bits), 1) << Cast::make(UInt(2 * bits), (bits - 1)); - Expr divisor = Cast::make(UInt(2 * bits), 1) << Cast::make(UInt(2 * bits), bits); - - Expr prod_sum = zero_expand * inverse_typed_weight + - one_expand * typed_weight + rounding; - Expr divided = ((prod_sum / divisor) + prod_sum) / divisor; + Expr shift = Cast::make(UInt(2 * bits), bits); + Expr prod_sum = widening_mul(zero_val, inverse_typed_weight) + widening_mul(one_val, typed_weight); + // Computes x / (2 ** N - 1) as (x / 2 ** N + x) / 2 ** N. + // TODO: on x86 it's actually one instruction cheaper to do the division directly. + Expr divided = rounding_shift_right(rounding_shift_right(prod_sum, shift) + prod_sum, shift); result = Cast::make(UInt(bits, computation_type.lanes()), divided); break;