diff --git a/src/CodeGen_WebAssembly.cpp b/src/CodeGen_WebAssembly.cpp
index 6f37f1447df1..53329ed52172 100644
--- a/src/CodeGen_WebAssembly.cpp
+++ b/src/CodeGen_WebAssembly.cpp
@@ -3,6 +3,7 @@
 #include "CodeGen_Posix.h"
 #include "ConciseCasts.h"
+#include "ConstantBounds.h"
 #include "IRMatch.h"
 #include "IROperator.h"
 #include "LLVM_Headers.h"
@@ -206,6 +207,12 @@ void CodeGen_WebAssembly::visit(const Call *op) {
         {"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128},
         {"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128},
     };
+    static const Pattern reinterpret_patterns[] = {
+        {"saturating_narrow", i8_sat(wild_u16x_), Target::WasmSimd128},
+        {"saturating_narrow", u8_sat(wild_u16x_), Target::WasmSimd128},
+        {"saturating_narrow", i16_sat(wild_u32x_), Target::WasmSimd128},
+        {"saturating_narrow", u16_sat(wild_u32x_), Target::WasmSimd128},
+    };
     static const vector<pair<Expr, Expr>> cast_rewrites = {
         // Some double-narrowing saturating casts can be better expressed as
         // combinations of single-narrowing saturating casts.
@@ -235,6 +242,36 @@ void CodeGen_WebAssembly::visit(const Call *op) {
                 return;
             }
         }
+
+        // Search for saturating casts where the inner value can be
+        // reinterpreted as signed, so that we can use the existing
+        // saturating_narrow instructions.
+        // TODO: should use lossless_cast once it is fixed.
+        for (const Pattern &p : reinterpret_patterns) {
+            if (!target.has_feature(p.required_feature)) {
+                continue;
+            }
+            if (expr_match(p.pattern, op, matches)) {
+                const Expr &expr = matches[0];
+                const Type &t = expr.type();
+                // TODO(8212): might want to keep track of scope of bounds information.
+                const ConstantInterval ibounds = constant_integer_bounds(expr);
+                const Type reint_type = t.with_code(halide_type_int);
+                // If the signed type can represent the bounds of the unsigned
+                // expression, we can safely reinterpret it as signed.
+                if (reint_type.can_represent(ibounds)) {
+                    // Can safely reinterpret to signed integer.
+                    matches[0] = cast(reint_type, matches[0]);
+
+                    value = call_overloaded_intrin(op->type, p.intrin, matches);
+                    if (value) {
+                        return;
+                    }
+                }
+                // At most one reinterpret pattern matches a given input, so stop matching.
+                break;
+            }
+        }
     }
 
     if (op->is_intrinsic(Call::round)) {
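The rewrite above is gated on inferred value bounds rather than on types alone. Distilled into a standalone helper for exposition (the helper name is hypothetical, but constant_integer_bounds(), Type::with_code(), and Type::can_represent() are exactly the calls the hunk uses):

    // Sketch: true iff an unsigned vector expression can be reinterpreted
    // as the same-width signed type without changing any lane's value.
    bool can_reinterpret_as_signed(const Expr &e) {
        const Type &t = e.type();  // e.g. UInt(16, lanes)
        // Conservative [min, max] for the expression's runtime value.
        const ConstantInterval bounds = constant_integer_bounds(e);
        // Safe iff every possible value also fits in the signed type,
        // i.e. the inferred maximum stays below 2^(bits-1).
        return t.with_code(halide_type_int).can_represent(bounds);
    }

The x86 and Hexagon hunks below apply the same gate.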
{"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))}, @@ -647,7 +648,7 @@ void CodeGen_X86::visit(const Call *op) { }; // clang-format off - static Pattern patterns[] = { + static const Pattern patterns[] = { {"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)}, {"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)}, {"saturating_narrow", i16_sat(wild_i32x_)}, @@ -667,6 +668,41 @@ void CodeGen_X86::visit(const Call *op) { } } + // clang-format off + static const Pattern reinterpret_patterns[] = { + {"saturating_narrow", i16_sat(wild_u32x_)}, + {"saturating_narrow", u16_sat(wild_u32x_)}, + {"saturating_narrow", i8_sat(wild_u16x_)}, + {"saturating_narrow", u8_sat(wild_u16x_)}, + }; + // clang-format on + + // Search for saturating casts where the inner value can be + // reinterpreted to signed, so that we can use existing + // saturating_narrow instructions. + // TODO: should use lossless_cast once it is fixed. + for (const auto &pattern : reinterpret_patterns) { + if (expr_match(pattern.pattern, op, matches)) { + const Expr &expr = matches[0]; + const Type &t = expr.type(); + // TODO(8212): might want to keep track of scope of bounds information. + const ConstantInterval ibounds = constant_integer_bounds(expr); + const Type reint_type = t.with_code(halide_type_int); + // If the signed type can represent the maximum value unsigned value, + // we can safely reinterpret this unsigned expression as signed. + if (reint_type.can_represent(ibounds)) { + // Can safely reinterpret to signed integer. + matches[0] = cast(reint_type, matches[0]); + value = call_overloaded_intrin(op->type, pattern.intrin, matches); + if (value) { + return; + } + } + // No reinterpret patterns match the same input, so stop matching. + break; + } + } + static const vector> cast_rewrites = { // Some double-narrowing saturating casts can be better expressed as // combinations of single-narrowing saturating casts. diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index f11fa3348399..a123738fe298 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -3,6 +3,7 @@ #include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" +#include "ConstantBounds.h" #include "DistributeShifts.h" #include "ExprUsesVar.h" #include "FindIntrinsics.h" @@ -189,8 +190,10 @@ struct Pattern { // re-interleave the result. ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0, - v65orLater = 1 << 10, // Pattern should be matched only for v65 target or later - v66orLater = 1 << 11, // Pattern should be matched only for v66 target or later + SafeReinterpretOp0 = 1 << 10, // Pattern should be matched only if the first arg can be safely reinterpreted. + + v65orLater = 1 << 11, // Pattern should be matched only for v65 target or later + v66orLater = 1 << 12, // Pattern should be matched only for v66 target or later }; string intrin; // Name of the intrinsic @@ -260,6 +263,27 @@ bool process_match_flags(vector &matches, int flags) { internal_assert(matches.size() >= 3); std::swap(matches[1], matches[2]); } + if (flags & Pattern::SafeReinterpretOp0) { + // Use bounds inference to check if the first operand can + // be safely reinterpreted. + // TODO: should use lossless_cast once it is fixed. + const Expr &expr = matches[0]; + const Type &t = expr.type(); + if (t.is_int()) { + // TODO(8212): might want to keep track of scope of bounds information. 
diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp
index f11fa3348399..a123738fe298 100644
--- a/src/HexagonOptimize.cpp
+++ b/src/HexagonOptimize.cpp
@@ -3,6 +3,7 @@
 #include "CSE.h"
 #include "CodeGen_Internal.h"
 #include "ConciseCasts.h"
+#include "ConstantBounds.h"
 #include "DistributeShifts.h"
 #include "ExprUsesVar.h"
 #include "FindIntrinsics.h"
@@ -189,8 +190,10 @@ struct Pattern {
        // re-interleave the result.
        ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0,

-       v65orLater = 1 << 10,  // Pattern should be matched only for v65 target or later
-       v66orLater = 1 << 11,  // Pattern should be matched only for v66 target or later
+       SafeReinterpretOp0 = 1 << 10,  // Pattern should be matched only if the first arg can be safely reinterpreted.
+
+       v65orLater = 1 << 11,  // Pattern should be matched only for v65 target or later
+       v66orLater = 1 << 12,  // Pattern should be matched only for v66 target or later
    };

    string intrin;  // Name of the intrinsic
@@ -260,6 +263,27 @@ bool process_match_flags(vector<Expr> &matches, int flags) {
        internal_assert(matches.size() >= 3);
        std::swap(matches[1], matches[2]);
    }
+   if (flags & Pattern::SafeReinterpretOp0) {
+       // Use bounds inference to check whether the first operand can
+       // be safely reinterpreted.
+       // TODO: should use lossless_cast once it is fixed.
+       const Expr &expr = matches[0];
+       const Type &t = expr.type();
+       if (t.is_int()) {
+           // TODO(8212): might want to keep track of scope of bounds information.
+           const ConstantInterval ibounds = constant_integer_bounds(expr);
+           const Type reint_type = UInt(t.bits());
+           // A signed integer can be reinterpreted as unsigned if it is known to be non-negative.
+           return reint_type.can_represent(ibounds);
+       } else {
+           internal_assert(t.is_uint());
+           // TODO(8212): might want to keep track of scope of bounds information.
+           const ConstantInterval ibounds = constant_integer_bounds(expr);
+           const Type reint_type = Int(t.bits());
+           // An unsigned integer can be reinterpreted as signed if it cannot exceed the signed maximum.
+           return reint_type.can_represent(ibounds);
+       }
+   }
    return true;
 }
@@ -915,10 +939,18 @@ class OptimizePatterns : public IRMutator {

            // Saturating narrowing casts. These may interleave later with trunc_sat.
            {"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)},
-           {"halide.hexagon.pack_satub.vuh", u8_sat(wild_u16x)},
            {"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)},
            {"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)},
            {"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)},
+           // The same patterns as above, but with the argument safely
+           // reinterpreted as signed.
+           {"halide.hexagon.pack_satub.vh", u8_sat(wild_u16x), Pattern::SafeReinterpretOp0},
+           {"halide.hexagon.pack_satuh.vw", u16_sat(wild_u32x), Pattern::SafeReinterpretOp0},
+           {"halide.hexagon.pack_satb.vh", i8_sat(wild_u16x), Pattern::SafeReinterpretOp0},
+           {"halide.hexagon.pack_sath.vw", i16_sat(wild_u32x), Pattern::SafeReinterpretOp0},
+           // Saturating uint casts that are slightly more expensive than the
+           // reinterpret patterns above: these perform vpack(min(UMAX, x)).
+           {"halide.hexagon.pack_satub.vuh", u8_sat(wild_u16x)},
            {"halide.hexagon.pack_satuh.vuw", u16_sat(wild_u32x)},

            // We don't have a vpack equivalent to this one, so we match it directly.
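To make the gate concrete, a small worked example of the bounds reasoning (illustrative only; it assumes the in-tree internal header, and that an unconstrained u16 parameter gets its full type range [0, 65535] from bounds inference):

    #include <cassert>

    #include "Halide.h"
    #include "ConstantBounds.h"  // internal header: constant_integer_bounds()

    using namespace Halide;
    using namespace Halide::Internal;

    int main() {
        Param<uint16_t> u16_1("u16_1");
        // (u16_1 >> 1) is bounded to [0, 32767], which Int(16) can hold,
        // so the SafeReinterpretOp0 patterns fire and pack_satub.vh is used.
        assert(Int(16).can_represent(constant_integer_bounds(u16_1 >> 1)));
        // The unshifted parameter spans [0, 65535], which does not fit in
        // int16, so the more expensive pack_satub.vuh pattern is kept.
        assert(!Int(16).can_represent(constant_integer_bounds(Expr(u16_1))));
        return 0;
    }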
+ check("i8x16.narrow_i16x8_s", 16 * w, i8_sat(u16_1 >> 1)); + check("i8x16.narrow_i16x8_u", 16 * w, u8_sat(u16_1 >> 1)); + check("i16x8.narrow_i32x4_s", 8 * w, i16_sat(u32_1 >> 1)); + check("i16x8.narrow_i32x4_u", 8 * w, u16_sat(u32_1 >> 1)); // Integer to integer widening check("i16x8.extend_low_i8x16_s", 16 * w, i16(i8_1));