diff --git a/python_bindings/src/halide/halide_/PyCallable.cpp b/python_bindings/src/halide/halide_/PyCallable.cpp index ab25626988d2..1ad7eab58ef6 100644 --- a/python_bindings/src/halide/halide_/PyCallable.cpp +++ b/python_bindings/src/halide/halide_/PyCallable.cpp @@ -109,13 +109,11 @@ class PyCallable { } else { argv[slot] = &scalar_storage[slot]; - // clang-format off - - #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE, FIELD) \ - case halide_type_t(CODE, BITS).as_u32(): \ - scalar_storage[slot].u.FIELD = cast_to(value); \ - cci[slot] = Callable::make_scalar_qcci(halide_type_t(CODE, BITS)); \ - break; +#define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE, FIELD) \ + case halide_type_t(CODE, BITS).as_u32(): \ + scalar_storage[slot].u.FIELD = cast_to(value); \ + cci[slot] = Callable::make_scalar_qcci(halide_type_t(CODE, BITS)); \ + break; switch (((halide_type_t)c_arg.type).element_of().as_u32()) { HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float, f32) @@ -134,9 +132,7 @@ class PyCallable { _halide_user_assert(0) << "Unsupported type in Callable argument list: " << c_arg.type << "\n"; } - #undef HALIDE_HANDLE_TYPE_DISPATCH - - // clang-format on +#undef HALIDE_HANDLE_TYPE_DISPATCH } }; diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index a27f1486164e..ad190765fa49 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -363,7 +363,6 @@ CodeGen_ARM::CodeGen_ARM(const Target &target) negations.emplace_back("saturating_negate", -max(wild_i8x_, -127)); negations.emplace_back("saturating_negate", -max(wild_i16x_, -32767)); negations.emplace_back("saturating_negate", -max(wild_i32x_, -(0x7fffffff))); - // clang-format on } constexpr int max_intrinsic_args = 4; @@ -393,7 +392,6 @@ struct ArmIntrinsic { }; }; -// clang-format off const ArmIntrinsic intrinsic_defs[] = { // TODO(https://github.com/halide/Halide/issues/8093): // Some of the Arm intrinsic have the same name between Neon and SVE2 but with different behavior. For example, @@ -406,7 +404,7 @@ const ArmIntrinsic intrinsic_defs[] = { {"vabs", "abs", UInt(32, 2), "abs", {Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, {"llvm.fabs", "llvm.fabs", Float(16, 4), "abs", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, {"llvm.fabs", "llvm.fabs", Float(32, 2), "abs", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, - {"llvm.fabs", "llvm.fabs", Float(64, 2), "abs", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs", "llvm.fabs", Float(64, 2), "abs", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, {"llvm.fabs.f16", "llvm.fabs.f16", Float(16), "abs", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, {"llvm.fabs.f32", "llvm.fabs.f32", Float(32), "abs", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, {"llvm.fabs.f64", "llvm.fabs.f64", Float(64), "abs", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, @@ -870,7 +868,6 @@ const std::map float16_transcendental_remapping = { {"tan_f16", "tan_f32"}, {"tanh_f16", "tanh_f32"}, }; -// clang-format on llvm::Type *CodeGen_ARM::llvm_type_with_constraint(const Type &t, bool scalars_are_vectors, VectorTypeConstraint constraint) { @@ -2170,7 +2167,7 @@ bool CodeGen_ARM::codegen_dot_product_vector_reduce(const VectorReduce *op, cons Target::Feature required_feature; std::vector extra_operands; }; - // clang-format off + static const Pattern patterns[] = { {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_i8x_)), "dot_product", Target::ARMDotProd}, {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::ARMDotProd}, @@ -2193,7 +2190,6 @@ bool CodeGen_ARM::codegen_dot_product_vector_reduce(const VectorReduce *op, cons {VectorReduce::Add, 4, i64(wild_u16x_), "dot_product", Target::SVE2, {1}}, {VectorReduce::Add, 4, u64(wild_u16x_), "dot_product", Target::SVE2, {1}}, }; - // clang-format on int factor = op->value.type().lanes() / op->type.lanes(); vector matches; diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 8515b020ea64..7bf29bd15aea 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -581,7 +581,6 @@ halide_type_t u8v2 = u8v1.with_lanes(u8v1.lanes * 2); halide_type_t u16v2 = u16v1.with_lanes(u16v1.lanes * 2); halide_type_t u32v2 = u32v1.with_lanes(u32v1.lanes * 2); -// clang-format off #define INTRINSIC_128B(id) llvm::Intrinsic::hexagon_V6_##id##_128B const HvxIntrinsic intrinsic_wrappers[] = { // Zero/sign extension: @@ -689,7 +688,7 @@ const HvxIntrinsic intrinsic_wrappers[] = { {INTRINSIC_128B(vavghrnd), i16v1, "avg_rnd.vh.vh", {i16v1, i16v1}}, {INTRINSIC_128B(vavgwrnd), i32v1, "avg_rnd.vw.vw", {i32v1, i32v1}}, - // This one is weird: i8_sat((u8 - u8)/2). It both saturates and averages. + // This one is weird: i8_sat((u8 - u8)/2). It both saturates and averages. {INTRINSIC_128B(vnavgub), i8v1, "navg.vub.vub", {u8v1, u8v1}}, {INTRINSIC_128B(vnavgb), i8v1, "navg.vb.vb", {i8v1, i8v1}, HvxIntrinsic::v65OrLater}, {INTRINSIC_128B(vnavgh), i16v1, "navg.vh.vh", {i16v1, i16v1}}, @@ -841,7 +840,6 @@ const HvxIntrinsic intrinsic_wrappers[] = { {INTRINSIC_128B(vnormamth), u16v1, "cls.vh", {u16v1}}, {INTRINSIC_128B(vnormamtw), u32v1, "cls.vw", {u32v1}}, }; -// clang-format on // TODO: Many variants of the above functions are missing. They // need to be implemented in the runtime module, or via diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 9aef62a49188..3204002652c9 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -465,7 +465,6 @@ void CodeGen_PTX_Dev::codegen_vector_reduce(const VectorReduce *op, const Expr & // TODO: Support rewriting to arbitrary calls in IRMatch and use that instead // of expr_match here. That would probably allow avoiding the redundant swapping // operands logic. - // clang-format off static const Pattern patterns[] = { {VectorReduce::Add, 4, i32(widening_mul(wild_i8x, wild_i8x)), "dp4a"}, {VectorReduce::Add, 4, i32(widening_mul(wild_i8x, wild_u8x)), "dp4a"}, @@ -480,7 +479,6 @@ void CodeGen_PTX_Dev::codegen_vector_reduce(const VectorReduce *op, const Expr & {VectorReduce::Add, 4, widening_mul(wild_i16x, wild_u16x), "dp2a", Pattern::SwapOps | Pattern::NarrowOp1}, {VectorReduce::Add, 4, widening_mul(wild_u16x, wild_u16x), "dp2a", Pattern::SwapOps | Pattern::NarrowOp1}, }; - // clang-format on const int input_lanes = op->value.type().lanes(); const int factor = input_lanes / op->type.lanes(); diff --git a/src/CodeGen_PowerPC.cpp b/src/CodeGen_PowerPC.cpp index c71ae3aed705..0ee3669ced54 100644 --- a/src/CodeGen_PowerPC.cpp +++ b/src/CodeGen_PowerPC.cpp @@ -51,7 +51,6 @@ struct PowerPCIntrinsic { Target::Feature feature = Target::FeatureEnd; }; -// clang-format off const PowerPCIntrinsic intrinsic_defs[] = { {"llvm.ppc.altivec.vminsb", Int(8, 16), "min", {Int(8, 16), Int(8, 16)}}, {"llvm.ppc.altivec.vminub", UInt(8, 16), "min", {UInt(8, 16), UInt(8, 16)}}, @@ -96,7 +95,6 @@ const PowerPCIntrinsic intrinsic_defs[] = { {"llvm.ppc.altivec.vavgsw", Int(32, 4), "rounding_halving_add", {Int(32, 4), Int(32, 4)}}, {"llvm.ppc.altivec.vavguw", UInt(32, 4), "rounding_halving_add", {UInt(32, 4), UInt(32, 4)}}, }; -// clang-format on void CodeGen_PowerPC::init_module() { CodeGen_Posix::init_module(); diff --git a/src/CodeGen_WebAssembly.cpp b/src/CodeGen_WebAssembly.cpp index 7477b7965766..f173a88e1211 100644 --- a/src/CodeGen_WebAssembly.cpp +++ b/src/CodeGen_WebAssembly.cpp @@ -58,7 +58,6 @@ struct WasmIntrinsic { Target::Feature feature = Target::FeatureEnd; }; -// clang-format off const WasmIntrinsic intrinsic_defs[] = { {"llvm.sadd.sat.v8i16", Int(16, 8), "saturating_add", {Int(16, 8), Int(16, 8)}, Target::WasmSimd128}, {"llvm.uadd.sat.v8i16", UInt(16, 8), "saturating_add", {UInt(16, 8), UInt(16, 8)}, Target::WasmSimd128}, @@ -111,7 +110,6 @@ const WasmIntrinsic intrinsic_defs[] = { {"llvm.nearbyint.f32", Float(32), "nearbyint", {Float(32)}}, {"llvm.nearbyint.f64", Float(64), "nearbyint", {Float(64)}}, }; -// clang-format on void CodeGen_WebAssembly::init_module() { CodeGen_Posix::init_module(); @@ -144,7 +142,6 @@ void CodeGen_WebAssembly::visit(const Cast *op) { Target::Feature required_feature; }; - // clang-format off static const Pattern patterns[] = { {"int_to_double", f64(wild_i32x_), Target::WasmSimd128}, {"int_to_double", f64(wild_u32x_), Target::WasmSimd128}, @@ -155,7 +152,6 @@ void CodeGen_WebAssembly::visit(const Cast *op) { {"widen_integer", i64(wild_i32x_), Target::WasmSimd128}, {"widen_integer", u64(wild_u32x_), Target::WasmSimd128}, }; - // clang-format on if (op->type.is_vector()) { std::vector matches; @@ -193,7 +189,6 @@ void CodeGen_WebAssembly::visit(const Call *op) { Target::Feature required_feature; }; - // clang-format off static const Pattern patterns[] = { {"q15mulr_sat_s", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15), Target::WasmSimd128}, {"saturating_narrow", i8_sat(wild_i16x_), Target::WasmSimd128}, @@ -213,7 +208,6 @@ void CodeGen_WebAssembly::visit(const Call *op) { {u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))}, {i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))}, }; - // clang-format on if (op->type.is_vector()) { std::vector matches; @@ -287,7 +281,7 @@ void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Ex const char *intrin; Target::Feature required_feature; }; - // clang-format off + static const Pattern patterns[] = { {VectorReduce::Add, 2, i16(wild_i8x_), "pairwise_widening_add", Target::WasmSimd128}, {VectorReduce::Add, 2, u16(wild_u8x_), "pairwise_widening_add", Target::WasmSimd128}, @@ -299,7 +293,6 @@ void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Ex {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", Target::WasmSimd128}, }; - // clang-format on // Other values will be added soon, so this switch isn't actually pointless using ValuePtr = llvm::Value *; diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 7854a7160f1b..0e63af410cce 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -133,7 +133,6 @@ struct x86Intrinsic { }; }; -// clang-format off const x86Intrinsic intrinsic_defs[] = { // AVX2/SSSE3 LLVM intrinsics for pabs fail in JIT. The integer wrappers // just call `llvm.abs` (which requires a second argument). @@ -295,17 +294,16 @@ const x86Intrinsic intrinsic_defs[] = { {"tileloadd64_i8", Int(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tileloadd64_i8", UInt(8, 1024), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tileloadd64_bf16", BFloat(16, 512), "tile_load", {Int(16), Int(16), Handle(), Int(64), Int(64)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, - {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, + {"tdpbssd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbsud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), Int(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbusd", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), Int(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbuud", Int(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Int(32, 256), UInt(8, 1024), UInt(8, 1024)}, Target::AVX512_SapphireRapids}, {"tdpbf16ps", Float(32, 256), "tile_matmul", {Int(16), Int(16), Int(16), Float(32, 256), BFloat(16, 512), BFloat(16, 512)}, Target::AVX512_SapphireRapids}, - {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, + {"tilezero_i32", Int(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, {"tilezero_f32", Float(32, 256), "tile_zero", {Int(16), Int(16)}, Target::AVX512_SapphireRapids}, {"tilestored64_i32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Int(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, {"tilestored64_f32", Int(32), "tile_store", {Int(16), Int(16), Handle(), Int(64), Int(64), Float(32, 256)}, Target::AVX512_SapphireRapids, x86Intrinsic::AccessesMemory}, }; -// clang-format on void CodeGen_X86::init_module() { CodeGen_Posix::init_module(); @@ -549,7 +547,6 @@ void CodeGen_X86::visit(const Cast *op) { Expr pattern; }; - // clang-format off static Pattern patterns[] = { // This isn't rounding_mul_shift_right(i16, i16, 15) because it doesn't // saturate the result. @@ -557,7 +554,6 @@ void CodeGen_X86::visit(const Cast *op) { {"f32_to_bf16", bf16(wild_f32x_)}, }; - // clang-format on vector matches; for (const Pattern &p : patterns) { @@ -783,7 +779,7 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init SingleArg = 1 << 2, }; }; - // clang-format off + // These patterns are roughly sorted "best to worst", in case there are two // patterns that match the expression. static const Pattern patterns[] = { @@ -819,7 +815,6 @@ void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init {VectorReduce::Add, 8, u64(absd(wild_u8x_, wild_u8x_)), "sum_of_absolute_differences", {}}, }; - // clang-format on std::vector matches; for (const Pattern &p : patterns) { @@ -1151,11 +1146,11 @@ int CodeGen_X86::vector_lanes_for_slice(const Type &t) const { // type if we can. int vec_bits = t.lanes() * t.bits(); int natural_vec_bits = target.natural_vector_size(t) * t.bits(); - // clang-format off + int slice_bits = ((vec_bits > 256 && natural_vec_bits > 256) ? 512 : (vec_bits > 128 && natural_vec_bits > 128) ? 256 : 128); - // clang-format on + return slice_bits / t.bits(); } diff --git a/src/Debug.h b/src/Debug.h index c0b6c25afcb4..3f82b14b55ba 100644 --- a/src/Debug.h +++ b/src/Debug.h @@ -48,11 +48,10 @@ bool debug_is_active_impl(int verbosity, const char *file, const char *function, * is determined by the value of the environment variable * HL_DEBUG_CODEGEN */ -// clang-format off + #define debug(n) \ /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \ if (debug_is_active((n))) std::cerr -// clang-format on /** Allow easily printing the contents of containers, or std::vector-like containers, * in debug output. Used like so: diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 9166c16f4b58..2873c5eb4ca9 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -734,7 +734,7 @@ class FindIntrinsics : public IRMutator { // We can't do everything we want here with rewrite rules alone. So, we rewrite them // to rounding_shifts with the widening still in place, and narrow it after the rewrite // succeeds. - // clang-format off + if (rewrite(max(min(rounding_shift_right(x, y), upper), lower), rounding_shift_right(x, y), is_x_wide_int_or_uint) || rewrite(rounding_shift_right(x, y), rounding_shift_right(x, y), is_x_wide_int_or_uint) || rewrite(rounding_shift_left(x, y), rounding_shift_left(x, y), is_x_wide_int_or_uint) || @@ -759,7 +759,6 @@ class FindIntrinsics : public IRMutator { } } } - // clang-format on } if (value.same_as(op->value)) { @@ -892,7 +891,7 @@ class FindIntrinsics : public IRMutator { } if (no_overflow(op->type)) { - // clang-format off + if (rewrite(halving_add(x + y, 1), rounding_halving_add(x, y)) || rewrite(halving_add(x, y + 1), rounding_halving_add(x, y)) || rewrite(halving_add(x + 1, y), rounding_halving_add(x, y)) || @@ -903,7 +902,6 @@ class FindIntrinsics : public IRMutator { false) { return mutate(rewrite.result); } - // clang-format on } // Move widening casts inside widening arithmetic outside the arithmetic, diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index b7e6056dae09..5129c5aaa5c8 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -1184,7 +1184,7 @@ class VectorReducePatterns : public IRMutator { // Map of instruction signatures static const vector sigs = ([&]() HALIDE_NEVER_INLINE { return vector{ - // clang-format off + // --------- vrmpy --------- // Sliding window {4, 32, widening_mul(wild_u8x, wild_u8x), Signature::SlidingWindow | Signature::ScalarB}, @@ -1239,7 +1239,6 @@ class VectorReducePatterns : public IRMutator { {2, 16, wild_u8x}, {2, 32, wild_i16x}, }; - // clang-format on })(); std::vector matches; diff --git a/src/Param.h b/src/Param.h index 90548f7a7545..508980d8ce65 100644 --- a/src/Param.h +++ b/src/Param.h @@ -177,18 +177,17 @@ class Param { << "The value " << val << " cannot be losslessly converted to type " << type(); param.set_scalar(val); } else { - // clang-format off // Specialized version for when T = void (thus the type is only known at runtime, // not compiletime). Note that this actually works fine for all Params; we specialize // it just to reduce code size for the common case of T != void. - #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ - case halide_type_t(CODE, BITS).as_u32(): \ - user_assert(Internal::IsRoundtrippable::value(val)) \ - << "The value " << val << " cannot be losslessly converted to type " << type; \ - param.set_scalar(Internal::StaticCast::value(val)); \ - break; +#define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ + case halide_type_t(CODE, BITS).as_u32(): \ + user_assert(Internal::IsRoundtrippable::value(val)) \ + << "The value " << val << " cannot be losslessly converted to type " << type; \ + param.set_scalar(Internal::StaticCast::value(val)); \ + break; const Type type = param.type(); switch (((halide_type_t)type).element_of().as_u32()) { @@ -208,9 +207,7 @@ class Param { internal_error << "Unsupported type in Param::set<" << type << ">\n"; } - #undef HALIDE_HANDLE_TYPE_DISPATCH - - // clang-format on +#undef HALIDE_HANDLE_TYPE_DISPATCH } } @@ -257,18 +254,17 @@ class Param { << "The value " << val << " cannot be losslessly converted to type " << type(); param.set_estimate(Expr(val)); } else { - // clang-format off // Specialized version for when T = void (thus the type is only known at runtime, // not compiletime). Note that this actually works fine for all Params; we specialize // it just to reduce code size for the common case of T != void. - #define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ - case halide_type_t(CODE, BITS).as_u32(): \ - user_assert(Internal::IsRoundtrippable::value(val)) \ - << "The value " << val << " cannot be losslessly converted to type " << type; \ - param.set_estimate(Expr(Internal::StaticCast::value(val))); \ - break; +#define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE) \ + case halide_type_t(CODE, BITS).as_u32(): \ + user_assert(Internal::IsRoundtrippable::value(val)) \ + << "The value " << val << " cannot be losslessly converted to type " << type; \ + param.set_estimate(Expr(Internal::StaticCast::value(val))); \ + break; const Type type = param.type(); switch (((halide_type_t)type).element_of().as_u32()) { @@ -288,9 +284,7 @@ class Param { internal_error << "Unsupported type in Param::set<" << type << ">\n"; } - #undef HALIDE_HANDLE_TYPE_DISPATCH - - // clang-format on +#undef HALIDE_HANDLE_TYPE_DISPATCH } } diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp index 0d7b2c1a9a46..6158cc9cd48c 100644 --- a/src/Simplify_Add.cpp +++ b/src/Simplify_Add.cpp @@ -30,23 +30,22 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { return rewrite.result; } - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(c0 + c1, fold(c0 + c1)) || rewrite(x + x, x * 2) || rewrite(ramp(x, y, c0) + ramp(z, w, c0), ramp(x + z, y + w, c0)) || rewrite(ramp(x, y, c0) + broadcast(z, c0), ramp(x + z, y, c0)) || - rewrite(broadcast(x, c0) + broadcast(y, c1), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite(broadcast(y, c1) + broadcast(x, c0), broadcast(x + broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || - - rewrite((x + broadcast(y, c0)) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((x + broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((broadcast(y, c0) + x) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((broadcast(z, c1) + x) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((x - broadcast(y, c0)) + broadcast(z, c1), x + broadcast(broadcast(z, fold(c1/c0)) - y, c0), c1 % c0 == 0) || - rewrite((x - broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y - broadcast(z, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite((broadcast(y, c0) - x) + broadcast(z, c1), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || - rewrite((broadcast(z, c1) - x) + broadcast(y, c0), broadcast(y + broadcast(z, fold(c1/c0)), c0) - x, c1 % c0 == 0) || + rewrite(broadcast(x, c0) + broadcast(y, c1), broadcast(x + broadcast(y, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite(broadcast(y, c1) + broadcast(x, c0), broadcast(x + broadcast(y, fold(c1 / c0)), c0), c1 % c0 == 0) || + + rewrite((x + broadcast(y, c0)) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite((x + broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) + x) + broadcast(z, c1), x + broadcast(y + broadcast(z, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(z, c1) + x) + broadcast(y, c0), x + broadcast(y + broadcast(z, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite((x - broadcast(y, c0)) + broadcast(z, c1), x + broadcast(broadcast(z, fold(c1 / c0)) - y, c0), c1 % c0 == 0) || + rewrite((x - broadcast(z, c1)) + broadcast(y, c0), x + broadcast(y - broadcast(z, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite((broadcast(y, c0) - x) + broadcast(z, c1), broadcast(y + broadcast(z, fold(c1 / c0)), c0) - x, c1 % c0 == 0) || + rewrite((broadcast(z, c1) - x) + broadcast(y, c0), broadcast(y + broadcast(z, fold(c1 / c0)), c0) - x, c1 % c0 == 0) || rewrite(select(x, y, z) + select(x, w, u), select(x, y + w, z + u)) || rewrite(select(x, c0, c1) + c2, select(x, fold(c0 + c2), fold(c1 + c2))) || rewrite(select(x, y + c0, c1) + c2, select(x, y + fold(c0 + c2), fold(c1 + c2))) || @@ -63,15 +62,15 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { rewrite(select(x, y, z + c0) + c1, select(x, y + c1, z), (c0 + c1) == 0) || rewrite(select(x, c0 - y, c1) + c2, fold(c0 + c2) - select(x, y, fold(c0 - c1))) || - rewrite(x + y*(-1), x - y) || - rewrite(x*(-1) + y, y - x) || + rewrite(x + y * (-1), x - y) || + rewrite(x * (-1) + y, y - x) || rewrite((x + c0) + c1, x + fold(c0 + c1)) || rewrite((x + c0) + y, (x + y) + c0) || rewrite(x + (y + c0), (x + y) + c0) || rewrite((c0 - x) + c1, fold(c0 + c1) - x) || rewrite((c0 - x) + y, (y - x) + c0) || - rewrite(max(x, y*c0 + z) + (u - y)*c0, max(x - y*c0, z) + u*c0) || + rewrite(max(x, y * c0 + z) + (u - y) * c0, max(x - y * c0, z) + u * c0) || rewrite((x - y) + y, x) || rewrite(x + (y - x), y) || @@ -100,18 +99,18 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { rewrite(((0 - x) - y) + z, z - (x + y)) || rewrite(((c0 - x) - y) + c1, (fold(c0 + c1) - y) - x) || - rewrite(x*y + z*y, (x + z)*y) || - rewrite(x*y + y*z, (x + z)*y) || - rewrite(y*x + z*y, y*(x + z)) || - rewrite(y*x + y*z, y*(x + z)) || + rewrite(x * y + z * y, (x + z) * y) || + rewrite(x * y + y * z, (x + z) * y) || + rewrite(y * x + z * y, y * (x + z)) || + rewrite(y * x + y * z, y * (x + z)) || - rewrite((x*y) + (z - (w*x)), z + (x*(y - w))) || - rewrite((x*y) + (z - (w*y)), z + (y*(x - w))) || - rewrite((x*y) + (z - (x*w)), z + (x*(y - w))) || - rewrite((x*y) + (z - (y*w)), z + (y*(x - w))) || + rewrite((x * y) + (z - (w * x)), z + (x * (y - w))) || + rewrite((x * y) + (z - (w * y)), z + (y * (x - w))) || + rewrite((x * y) + (z - (x * w)), z + (x * (y - w))) || + rewrite((x * y) + (z - (y * w)), z + (y * (x - w))) || - rewrite(x*c0 + y*c1, (x + y*fold(c1/c0)) * c0, c1 % c0 == 0) || - rewrite(x*c0 + y*c1, (x*fold(c0/c1) + y) * c1, c0 % c1 == 0) || + rewrite(x * c0 + y * c1, (x + y * fold(c1 / c0)) * c0, c1 % c0 == 0) || + rewrite(x * c0 + y * c1, (x * fold(c0 / c1) + y) * c1, c0 % c1 == 0) || // Hoist shuffles. The Shuffle visitor wants to sink // extract_elements to the leaves, and those count as degenerate @@ -123,18 +122,18 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) - z), slice(x + y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || (no_overflow(op->type) && - (rewrite(x + x*y, x * (y + 1)) || - rewrite(x + y*x, (y + 1) * x) || - rewrite(x*y + x, x * (y + 1)) || - rewrite(y*x + x, (y + 1) * x, !is_const(x)) || - rewrite(x + (x + y)/c0, (fold(c0 + 1)*x + y)/c0, c0 != 0) || - rewrite(x + (y + x)/c0, (fold(c0 + 1)*x + y)/c0, c0 != 0) || - rewrite(x + (y - x)/c0, (fold(c0 - 1)*x + y)/c0, c0 != 0) || - rewrite(x + (x - y)/c0, (fold(c0 + 1)*x - y)/c0, c0 != 0) || - rewrite((x - y)/c0 + x, (fold(c0 + 1)*x - y)/c0, c0 != 0) || - rewrite((y - x)/c0 + x, (y + fold(c0 - 1)*x)/c0, c0 != 0) || - rewrite((x + y)/c0 + x, (fold(c0 + 1)*x + y)/c0, c0 != 0) || - rewrite((y + x)/c0 + x, (y + fold(c0 + 1)*x)/c0, c0 != 0) || + (rewrite(x + x * y, x * (y + 1)) || + rewrite(x + y * x, (y + 1) * x) || + rewrite(x * y + x, x * (y + 1)) || + rewrite(y * x + x, (y + 1) * x, !is_const(x)) || + rewrite(x + (x + y) / c0, (fold(c0 + 1) * x + y) / c0, c0 != 0) || + rewrite(x + (y + x) / c0, (fold(c0 + 1) * x + y) / c0, c0 != 0) || + rewrite(x + (y - x) / c0, (fold(c0 - 1) * x + y) / c0, c0 != 0) || + rewrite(x + (x - y) / c0, (fold(c0 + 1) * x - y) / c0, c0 != 0) || + rewrite((x - y) / c0 + x, (fold(c0 + 1) * x - y) / c0, c0 != 0) || + rewrite((y - x) / c0 + x, (y + fold(c0 - 1) * x) / c0, c0 != 0) || + rewrite((x + y) / c0 + x, (fold(c0 + 1) * x + y) / c0, c0 != 0) || + rewrite((y + x) / c0 + x, (y + fold(c0 + 1) * x) / c0, c0 != 0) || rewrite(min(x, y - z) + z, min(x + z, y)) || rewrite(min(y - z, x) + z, min(y, x + z)) || rewrite(min(x, y + c0) + c1, min(x + c1, y), c0 + c1 == 0) || @@ -150,49 +149,48 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) { rewrite(max(x, y) + min(x, y), x + y) || rewrite(max(x, y) + min(y, x), x + y) || - rewrite(min(x, y + (z*c0)) + (z*c1), min(x + (z*c1), y), (c0 + c1) == 0) || - rewrite(min(x, (y*c0) + z) + (y*c1), min(x + (y*c1), z), (c0 + c1) == 0) || - rewrite(min(x, y*c0) + (y*c1), min(x + (y*c1), 0), (c0 + c1) == 0) || - rewrite(min(x + (y*c0), z) + (y*c1), min((y*c1) + z, x), (c0 + c1) == 0) || - rewrite(min((x*c0) + y, z) + (x*c1), min(y, (x*c1) + z), (c0 + c1) == 0) || - rewrite(min(x*c0, y) + (x*c1), min((x*c1) + y, 0), (c0 + c1) == 0) || - rewrite(max(x, y + (z*c0)) + (z*c1), max(x + (z*c1), y), (c0 + c1) == 0) || - rewrite(max(x, (y*c0) + z) + (y*c1), max(x + (y*c1), z), (c0 + c1) == 0) || - rewrite(max(x, y*c0) + (y*c1), max(x + (y*c1), 0), (c0 + c1) == 0) || - rewrite(max(x + (y*c0), z) + (y*c1), max(x, (y*c1) + z), (c0 + c1) == 0) || - rewrite(max((x*c0) + y, z) + (x*c1), max((x*c1) + z, y), (c0 + c1) == 0) || - rewrite(max(x*c0, y) + (x*c1), max((x*c1) + y, 0), (c0 + c1) == 0) || + rewrite(min(x, y + (z * c0)) + (z * c1), min(x + (z * c1), y), (c0 + c1) == 0) || + rewrite(min(x, (y * c0) + z) + (y * c1), min(x + (y * c1), z), (c0 + c1) == 0) || + rewrite(min(x, y * c0) + (y * c1), min(x + (y * c1), 0), (c0 + c1) == 0) || + rewrite(min(x + (y * c0), z) + (y * c1), min((y * c1) + z, x), (c0 + c1) == 0) || + rewrite(min((x * c0) + y, z) + (x * c1), min(y, (x * c1) + z), (c0 + c1) == 0) || + rewrite(min(x * c0, y) + (x * c1), min((x * c1) + y, 0), (c0 + c1) == 0) || + rewrite(max(x, y + (z * c0)) + (z * c1), max(x + (z * c1), y), (c0 + c1) == 0) || + rewrite(max(x, (y * c0) + z) + (y * c1), max(x + (y * c1), z), (c0 + c1) == 0) || + rewrite(max(x, y * c0) + (y * c1), max(x + (y * c1), 0), (c0 + c1) == 0) || + rewrite(max(x + (y * c0), z) + (y * c1), max(x, (y * c1) + z), (c0 + c1) == 0) || + rewrite(max((x * c0) + y, z) + (x * c1), max((x * c1) + z, y), (c0 + c1) == 0) || + rewrite(max(x * c0, y) + (x * c1), max((x * c1) + y, 0), (c0 + c1) == 0) || false)) || (no_overflow_int(op->type) && - (rewrite((x*(y/x)) + (y % x), select(x == 0, 0, y)) || - rewrite(((x/y)*y) + (x % y), select(y == 0, 0, x)) || - rewrite(w*(z + x/w) + x%w, select(w == 0, 0, z*w + x)) || - rewrite((z + x/w)*w + x%w, select(w == 0, 0, z*w + x)) || - rewrite(w*(x/w + z) + x%w, select(w == 0, 0, x + z*w)) || - rewrite((x/w + z)*w + x%w, select(w == 0, 0, x + z*w)) || - rewrite(x%w + (w*(x/w) + z), select(w == 0, 0, x) + z) || - rewrite(x%w + ((x/w)*w + z), select(w == 0, 0, x) + z) || - rewrite(x%w + (w*(x/w) - z), select(w == 0, 0, x) - z) || - rewrite(x%w + ((x/w)*w - z), select(w == 0, 0, x) - z) || - rewrite(x%w + (z + w*(x/w)), select(w == 0, 0, x) + z) || - rewrite(x%w + (z + (x/w)*w), select(w == 0, 0, x) + z) || - rewrite(w*(x/w) + (x%w + z), select(w == 0, 0, x) + z) || - rewrite((x/w)*w + (x%w + z), select(w == 0, 0, x) + z) || - rewrite(w*(x/w) + (x%w - z), select(w == 0, 0, x) - z) || - rewrite((x/w)*w + (x%w - z), select(w == 0, 0, x) - z) || - rewrite(w*(x/w) + (z + x%w), select(w == 0, 0, x) + z) || - rewrite((x/w)*w + (z + x%w), select(w == 0, 0, x) + z) || - rewrite(x/2 + x%2, (x + 1) / 2) || - - rewrite(x + ((c0 - x)/c1)*c1, c0 - ((c0 - x) % c1), c1 > 0) || - rewrite(x + ((c0 - x)/c1 + y)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || - rewrite(x + (y + (c0 - x)/c1)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || + (rewrite((x * (y / x)) + (y % x), select(x == 0, 0, y)) || + rewrite(((x / y) * y) + (x % y), select(y == 0, 0, x)) || + rewrite(w * (z + x / w) + x % w, select(w == 0, 0, z * w + x)) || + rewrite((z + x / w) * w + x % w, select(w == 0, 0, z * w + x)) || + rewrite(w * (x / w + z) + x % w, select(w == 0, 0, x + z * w)) || + rewrite((x / w + z) * w + x % w, select(w == 0, 0, x + z * w)) || + rewrite(x % w + (w * (x / w) + z), select(w == 0, 0, x) + z) || + rewrite(x % w + ((x / w) * w + z), select(w == 0, 0, x) + z) || + rewrite(x % w + (w * (x / w) - z), select(w == 0, 0, x) - z) || + rewrite(x % w + ((x / w) * w - z), select(w == 0, 0, x) - z) || + rewrite(x % w + (z + w * (x / w)), select(w == 0, 0, x) + z) || + rewrite(x % w + (z + (x / w) * w), select(w == 0, 0, x) + z) || + rewrite(w * (x / w) + (x % w + z), select(w == 0, 0, x) + z) || + rewrite((x / w) * w + (x % w + z), select(w == 0, 0, x) + z) || + rewrite(w * (x / w) + (x % w - z), select(w == 0, 0, x) - z) || + rewrite((x / w) * w + (x % w - z), select(w == 0, 0, x) - z) || + rewrite(w * (x / w) + (z + x % w), select(w == 0, 0, x) + z) || + rewrite((x / w) * w + (z + x % w), select(w == 0, 0, x) + z) || + rewrite(x / 2 + x % 2, (x + 1) / 2) || + + rewrite(x + ((c0 - x) / c1) * c1, c0 - ((c0 - x) % c1), c1 > 0) || + rewrite(x + ((c0 - x) / c1 + y) * c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || + rewrite(x + (y + (c0 - x) / c1) * c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) || false)))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp index 30cc60995dc6..3a1f0f6901ee 100644 --- a/src/Simplify_And.cpp +++ b/src/Simplify_And.cpp @@ -26,10 +26,8 @@ Expr Simplify::visit(const And *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::and_op(a, b), op->type); - // clang-format off - // Cases that fold to a constant - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(x && false, false) || rewrite(false && x, false) || rewrite(x && neg(x), false) || @@ -82,7 +80,7 @@ Expr Simplify::visit(const And *op, ExprInfo *info) { } // Cases that fold to one of the args - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(x && true, a) || rewrite(x && x, a) || rewrite(x && (x && y), b) || @@ -141,7 +139,7 @@ Expr Simplify::visit(const And *op, ExprInfo *info) { return rewrite.result; } - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(broadcast(x, c0) && broadcast(y, c0), broadcast(x && y, c0)) || rewrite((x && broadcast(y, c0)) && broadcast(z, c0), x && broadcast(y && z, c0)) || rewrite((broadcast(x, c0) && y) && broadcast(z, c0), broadcast(x && z, c0) && y) || @@ -272,7 +270,6 @@ Expr Simplify::visit(const And *op, ExprInfo *info) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { diff --git a/src/Simplify_Div.cpp b/src/Simplify_Div.cpp index 09528525c3d4..be78d10a34f8 100644 --- a/src/Simplify_Div.cpp +++ b/src/Simplify_Div.cpp @@ -64,36 +64,35 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { int a_mod = a_info.alignment.modulus; int a_rem = a_info.alignment.remainder; - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(c0 / c1, fold(c0 / c1)) || (!op->type.is_float() && rewrite(x / 0, 0)) || (!op->type.is_float() && denominator_non_zero && rewrite(x / x, 1)) || rewrite(broadcast(x, c0) / broadcast(y, c0), broadcast(x / y, c0)) || - rewrite(select(x, c0, c1) / c2, select(x, fold(c0/c2), fold(c1/c2))) || + rewrite(select(x, c0, c1) / c2, select(x, fold(c0 / c2), fold(c1 / c2))) || (!op->type.is_float() && rewrite(x / x, select(x == 0, 0, 1))) || (no_overflow(op->type) && - (// Fold repeated division - rewrite((x / c0) / c2, x / fold(c0 * c2), c0 > 0 && c2 > 0 && !overflows(c0 * c2)) || - rewrite((x / c0 + c1) / c2, (x + fold(c1 * c0)) / fold(c0 * c2), c0 > 0 && c2 > 0 && !overflows(c0 * c2) && !overflows(c0 * c1)) || - rewrite((x * c0) / c1, x / fold(c1 / c0), c1 % c0 == 0 && c0 > 0 && c1 / c0 != 0) || + // Fold repeated division + (rewrite((x / c0) / c2, x / fold(c0 * c2), c0 > 0 && c2 > 0 && !overflows(c0 * c2)) || + rewrite((x / c0 + c1) / c2, (x + fold(c1 * c0)) / fold(c0 * c2), c0 > 0 && c2 > 0 && !overflows(c0 * c2) && !overflows(c0 * c1)) || + rewrite((x * c0) / c1, x / fold(c1 / c0), c1 % c0 == 0 && c0 > 0 && c1 / c0 != 0) || // Pull out terms that are a multiple of the denominator - rewrite((x * c0) / c1, x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || + rewrite((x * c0) / c1, x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(min((x * c0), c1) / c2, min(x * fold(c0 / c2), fold(c1 / c2)), c0 % c2 == 0 && c2 > 0) || rewrite(max((x * c0), c1) / c2, max(x * fold(c0 / c2), fold(c1 / c2)), c0 % c2 == 0 && c2 > 0) || - rewrite((x * c0 + y) / c1, y / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || + rewrite((x * c0 + y) / c1, y / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite((x * c0 - y) / c0, x + (0 - y) / c0) || - rewrite((x * c1 - y) / c0, (0 - y) / c0 - x, c0 + c1 == 0) || - rewrite((y + x * c0) / c1, y / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || - rewrite((y - x * c0) / c1, y / c1 - x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || + rewrite((x * c1 - y) / c0, (0 - y) / c0 - x, c0 + c1 == 0) || + rewrite((y + x * c0) / c1, y / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || + rewrite((y - x * c0) / c1, y / c1 - x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(((x * c0 + y) + z) / c1, (y + z) / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(((x * c0 - y) + z) / c1, (z - y) / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(((x * c0 + y) - z) / c1, (y - z) / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(((x * c0 - y) - z) / c0, x + (0 - y - z) / c0) || - rewrite(((x * c1 - y) - z) / c0, (0 - y - z) / c0 - x, c0 + c1 == 0) || + rewrite(((x * c1 - y) - z) / c0, (0 - y - z) / c0 - x, c0 + c1 == 0) || rewrite(((y + x * c0) + z) / c1, (y + z) / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || rewrite(((y + x * c0) - z) / c1, (y - z) / c1 + x * fold(c0 / c1), c0 % c1 == 0 && c1 > 0) || @@ -162,7 +161,7 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { c2 == a_rem / c1 + (c0 - a_rem) / c1 */ - rewrite((c0 - x)/c1, fold(a_rem / c1 + (c0 - a_rem) / c1) - x / c1, a_mod % c1 == 0) || + rewrite((c0 - x) / c1, fold(a_rem / c1 + (c0 - a_rem) / c1) - x / c1, a_mod % c1 == 0) || // We can also pull it out when the constant is a // multiple of the denominator. @@ -170,34 +169,32 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { rewrite((c0 - x) / c1, fold(c0 / c1) - x / c1, (c0 + 1) % c1 == 0))) || (denominator_non_zero && - (rewrite((x + y)/x, y/x + 1) || - rewrite((y + x)/x, y/x + 1) || - rewrite((x - y)/x, (-y)/x + 1) || - rewrite((y - x)/x, y/x - 1) || - rewrite(((x + y) + z)/x, (y + z)/x + 1) || - rewrite(((y + x) + z)/x, (y + z)/x + 1) || - rewrite((z + (x + y))/x, (z + y)/x + 1) || - rewrite((z + (y + x))/x, (z + y)/x + 1) || - rewrite((x*y)/x, y) || - rewrite((y*x)/x, y) || - rewrite((x*y + z)/x, y + z/x) || - rewrite((y*x + z)/x, y + z/x) || - rewrite((z + x*y)/x, z/x + y) || - rewrite((z + y*x)/x, z/x + y) || - rewrite((x*y - z)/x, y + (-z)/x) || - rewrite((y*x - z)/x, y + (-z)/x) || - rewrite((z - x*y)/x, z/x - y) || - rewrite((z - y*x)/x, z/x - y) || + (rewrite((x + y) / x, y / x + 1) || + rewrite((y + x) / x, y / x + 1) || + rewrite((x - y) / x, (-y) / x + 1) || + rewrite((y - x) / x, y / x - 1) || + rewrite(((x + y) + z) / x, (y + z) / x + 1) || + rewrite(((y + x) + z) / x, (y + z) / x + 1) || + rewrite((z + (x + y)) / x, (z + y) / x + 1) || + rewrite((z + (y + x)) / x, (z + y) / x + 1) || + rewrite((x * y) / x, y) || + rewrite((y * x) / x, y) || + rewrite((x * y + z) / x, y + z / x) || + rewrite((y * x + z) / x, y + z / x) || + rewrite((z + x * y) / x, z / x + y) || + rewrite((z + y * x) / x, z / x + y) || + rewrite((x * y - z) / x, y + (-z) / x) || + rewrite((y * x - z) / x, y + (-z) / x) || + rewrite((z - x * y) / x, z / x - y) || + rewrite((z - y * x) / x, z / x - y) || false)) || - (op->type.is_float() && rewrite(x/c0, x * fold(1/c0))))) || + (op->type.is_float() && rewrite(x / c0, x * fold(1 / c0))))) || (no_overflow_int(op->type) && - ( - rewrite(ramp(x, c0, lanes) / broadcast(c1, lanes), ramp(x / c1, fold(c0 / c1), lanes), (c0 % c1 == 0)) || + (rewrite(ramp(x, c0, lanes) / broadcast(c1, lanes), ramp(x / c1, fold(c0 / c1), lanes), (c0 % c1 == 0)) || rewrite(ramp(x, c0, lanes) / broadcast(c1, lanes), broadcast(x / c1, lanes), // First and last lanes are the same when... - can_prove((x % c1 + c0 * (lanes - 1)) / c1 == 0, this)) - )) || + can_prove((x % c1 + c0 * (lanes - 1)) / c1 == 0, this)))) || (no_overflow_scalar_int(op->type) && (rewrite(x / -1, -x) || (denominator_non_zero && rewrite(c0 / y, select(y < 0, fold(-c0), c0), c0 == -1)) || @@ -211,7 +208,6 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) { rewrite((x % 2 + c0) / 2, x % 2 + fold(c0 / 2), c0 % 2 == 1))))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_EQ.cpp b/src/Simplify_EQ.cpp index 252e6e1de7bc..994d14cd4cee 100644 --- a/src/Simplify_EQ.cpp +++ b/src/Simplify_EQ.cpp @@ -59,8 +59,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::eq(a, b), op->type, a.type()); - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(x == x, true) || rewrite(c0 == c1, fold(c0 == c1)) || rewrite(min(x, y) == min(y, x), true) || @@ -197,7 +196,7 @@ Expr Simplify::visit(const EQ *op, ExprInfo *info) { rewrite(slice(x, c0, c1, c2) == slice(y, c0, c1, c2) + z, slice(x - y, c0, c1, c2) == z, c2 > 1 && lanes_of(x) == lanes_of(y)) || false) || - (no_overflow(a.type()) && EVAL_IN_LAMBDA + (no_overflow(a.type()) && EVAL_IN_LAMBDA // (rewrite(x * y == 0, (x == 0) || (y == 0)) || rewrite(x * y == x, (x == 0) || (y == 1)) || rewrite(x == x * y, (x == 0) || (y == 1)) || @@ -320,7 +319,6 @@ Expr Simplify::visit(const EQ *op, ExprInfo *info) { false) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_LT.cpp b/src/Simplify_LT.cpp index 5dd234fa1664..ca5f7fe307e9 100644 --- a/src/Simplify_LT.cpp +++ b/src/Simplify_LT.cpp @@ -30,8 +30,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::lt(a, b), op->type, ty); - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(c0 < c1, fold(c0 < c1)) || rewrite(x < x, false) || rewrite(x < ty.min(), false) || @@ -56,7 +55,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { rewrite(x % c0 < c1, x % c0 != fold(c0 - 1), c1 + 1 == c0 && c0 > 0) || rewrite(c0 < x % c1, x % c1 == fold(c1 - 1), c0 + 2 == c1 && c1 > 0)) || - (no_overflow(ty) && EVAL_IN_LAMBDA + (no_overflow(ty) && EVAL_IN_LAMBDA // (rewrite(ramp(x, y, c0) < ramp(z, y, c0), broadcast(x < z, c0)) || // Move constants to the RHS rewrite(x + c0 < y, x < y + fold(-c0)) || @@ -153,7 +152,7 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { rewrite(x * c0 < y * c0, x < y, c0 > 0) || rewrite(x * c0 < y * c0, y < x, c0 < 0) || - (ty.is_int() && rewrite(x * c0 < c1, x < fold((c1 + c0 - 1) / c0), c0 > 0)) || + (ty.is_int() && rewrite(x * c0 < c1, x < fold((c1 + c0 - 1) / c0), c0 > 0)) || (ty.is_float() && rewrite(x * c0 < c1, x < fold(c1 / c0), c0 > 0)) || (ty.is_float() && rewrite(x * c0 < c1, fold(c1 / c0) < x, c0 < 0)) || rewrite(c1 < x * c0, fold(c1 / c0) < x, c0 > 0) || @@ -323,12 +322,12 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { can_prove(z >= x + fold(max(0, c1 * (lanes - 1))), this)) || false)) || - (no_overflow_int(ty) && EVAL_IN_LAMBDA + (no_overflow_int(ty) && EVAL_IN_LAMBDA // (rewrite(x * c0 < y * c1, x < y * fold(c1 / c0), c1 % c0 == 0 && c0 > 0) || rewrite(x * c0 < y * c1, x * fold(c0 / c1) < y, c0 % c1 == 0 && c1 > 0) || - rewrite(x * c0 < y * c0 + c1, x < y + fold((c1 + c0 - 1)/c0), c0 > 0) || - rewrite(x * c0 + c1 < y * c0, x + fold(c1/c0) < y, c0 > 0) || + rewrite(x * c0 < y * c0 + c1, x < y + fold((c1 + c0 - 1) / c0), c0 > 0) || + rewrite(x * c0 + c1 < y * c0, x + fold(c1 / c0) < y, c0 > 0) || // Comparison of stair-step functions. The basic transformation is: // ((x + y)/c1)*c1 < x @@ -338,140 +337,140 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { // This cancels x but duplicates y, so we only do it when y is a constant. // A more general version with extra terms w and z - rewrite(((x + c0)/c1)*c1 + w < x + z, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || - rewrite(w + ((x + c0)/c1)*c1 < x + z, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || - rewrite(((x + c0)/c1)*c1 + w < z + x, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || - rewrite(w + ((x + c0)/c1)*c1 < z + x, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || - rewrite(x + z < ((x + c0)/c1)*c1 + w, ((x + c0) % c1) + z < w + c0, c1 > 0) || - rewrite(x + z < w + ((x + c0)/c1)*c1, ((x + c0) % c1) + z < w + c0, c1 > 0) || - rewrite(z + x < ((x + c0)/c1)*c1 + w, ((x + c0) % c1) + z < w + c0, c1 > 0) || - rewrite(z + x < w + ((x + c0)/c1)*c1, ((x + c0) % c1) + z < w + c0, c1 > 0) || + rewrite(((x + c0) / c1) * c1 + w < x + z, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || + rewrite(w + ((x + c0) / c1) * c1 < x + z, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || + rewrite(((x + c0) / c1) * c1 + w < z + x, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || + rewrite(w + ((x + c0) / c1) * c1 < z + x, (w + c0) < ((x + c0) % c1) + z, c1 > 0) || + rewrite(x + z < ((x + c0) / c1) * c1 + w, ((x + c0) % c1) + z < w + c0, c1 > 0) || + rewrite(x + z < w + ((x + c0) / c1) * c1, ((x + c0) % c1) + z < w + c0, c1 > 0) || + rewrite(z + x < ((x + c0) / c1) * c1 + w, ((x + c0) % c1) + z < w + c0, c1 > 0) || + rewrite(z + x < w + ((x + c0) / c1) * c1, ((x + c0) % c1) + z < w + c0, c1 > 0) || // w = 0 - rewrite(((x + c0)/c1)*c1 < x + z, c0 < ((x + c0) % c1) + z, c1 > 0) || - rewrite(((x + c0)/c1)*c1 < z + x, c0 < ((x + c0) % c1) + z, c1 > 0) || - rewrite(x + z < ((x + c0)/c1)*c1, ((x + c0) % c1) + z < c0, c1 > 0) || - rewrite(z + x < ((x + c0)/c1)*c1, ((x + c0) % c1) + z < c0, c1 > 0) || + rewrite(((x + c0) / c1) * c1 < x + z, c0 < ((x + c0) % c1) + z, c1 > 0) || + rewrite(((x + c0) / c1) * c1 < z + x, c0 < ((x + c0) % c1) + z, c1 > 0) || + rewrite(x + z < ((x + c0) / c1) * c1, ((x + c0) % c1) + z < c0, c1 > 0) || + rewrite(z + x < ((x + c0) / c1) * c1, ((x + c0) % c1) + z < c0, c1 > 0) || // z = 0 - rewrite(((x + c0)/c1)*c1 + w < x, (w + c0) < ((x + c0) % c1), c1 > 0) || - rewrite(w + ((x + c0)/c1)*c1 < x, (w + c0) < ((x + c0) % c1), c1 > 0) || - rewrite(x < ((x + c0)/c1)*c1 + w, ((x + c0) % c1) < w + c0, c1 > 0) || - rewrite(x < w + ((x + c0)/c1)*c1, ((x + c0) % c1) < w + c0, c1 > 0) || + rewrite(((x + c0) / c1) * c1 + w < x, (w + c0) < ((x + c0) % c1), c1 > 0) || + rewrite(w + ((x + c0) / c1) * c1 < x, (w + c0) < ((x + c0) % c1), c1 > 0) || + rewrite(x < ((x + c0) / c1) * c1 + w, ((x + c0) % c1) < w + c0, c1 > 0) || + rewrite(x < w + ((x + c0) / c1) * c1, ((x + c0) % c1) < w + c0, c1 > 0) || // c0 = 0 - rewrite((x/c1)*c1 + w < x + z, w < (x % c1) + z, c1 > 0) || - rewrite(w + (x/c1)*c1 < x + z, w < (x % c1) + z, c1 > 0) || - rewrite((x/c1)*c1 + w < z + x, w < (x % c1) + z, c1 > 0) || - rewrite(w + (x/c1)*c1 < z + x, w < (x % c1) + z, c1 > 0) || - rewrite(x + z < (x/c1)*c1 + w, (x % c1) + z < w, c1 > 0) || - rewrite(x + z < w + (x/c1)*c1, (x % c1) + z < w, c1 > 0) || - rewrite(z + x < (x/c1)*c1 + w, (x % c1) + z < w, c1 > 0) || - rewrite(z + x < w + (x/c1)*c1, (x % c1) + z < w, c1 > 0) || + rewrite((x / c1) * c1 + w < x + z, w < (x % c1) + z, c1 > 0) || + rewrite(w + (x / c1) * c1 < x + z, w < (x % c1) + z, c1 > 0) || + rewrite((x / c1) * c1 + w < z + x, w < (x % c1) + z, c1 > 0) || + rewrite(w + (x / c1) * c1 < z + x, w < (x % c1) + z, c1 > 0) || + rewrite(x + z < (x / c1) * c1 + w, (x % c1) + z < w, c1 > 0) || + rewrite(x + z < w + (x / c1) * c1, (x % c1) + z < w, c1 > 0) || + rewrite(z + x < (x / c1) * c1 + w, (x % c1) + z < w, c1 > 0) || + rewrite(z + x < w + (x / c1) * c1, (x % c1) + z < w, c1 > 0) || // w = 0, z = 0 - rewrite(((x + c0)/c1)*c1 < x, c0 < ((x + c0) % c1), c1 > 0) || - rewrite(x < ((x + c0)/c1)*c1, ((x + c0) % c1) < c0, c1 > 0) || + rewrite(((x + c0) / c1) * c1 < x, c0 < ((x + c0) % c1), c1 > 0) || + rewrite(x < ((x + c0) / c1) * c1, ((x + c0) % c1) < c0, c1 > 0) || // w = 0, c0 = 0 - rewrite((x/c1)*c1 < x + z, 0 < (x % c1) + z, c1 > 0) || - rewrite((x/c1)*c1 < z + x, 0 < (x % c1) + z, c1 > 0) || - rewrite(x + z < (x/c1)*c1, (x % c1) + z < 0, c1 > 0) || - rewrite(z + x < (x/c1)*c1, (x % c1) + z < 0, c1 > 0) || + rewrite((x / c1) * c1 < x + z, 0 < (x % c1) + z, c1 > 0) || + rewrite((x / c1) * c1 < z + x, 0 < (x % c1) + z, c1 > 0) || + rewrite(x + z < (x / c1) * c1, (x % c1) + z < 0, c1 > 0) || + rewrite(z + x < (x / c1) * c1, (x % c1) + z < 0, c1 > 0) || // z = 0, c0 = 0 - rewrite((x/c1)*c1 + w < x, w < (x % c1), c1 > 0) || - rewrite(w + (x/c1)*c1 < x, w < (x % c1), c1 > 0) || - rewrite(x < (x/c1)*c1 + w, (x % c1) < w, c1 > 0) || - rewrite(x < w + (x/c1)*c1, (x % c1) < w, c1 > 0) || + rewrite((x / c1) * c1 + w < x, w < (x % c1), c1 > 0) || + rewrite(w + (x / c1) * c1 < x, w < (x % c1), c1 > 0) || + rewrite(x < (x / c1) * c1 + w, (x % c1) < w, c1 > 0) || + rewrite(x < w + (x / c1) * c1, (x % c1) < w, c1 > 0) || // z = 0, c0 = 0, w = 0 - rewrite((x/c1)*c1 < x, (x % c1) != 0, c1 > 0) || - rewrite(x < (x/c1)*c1, false, c1 > 0) || + rewrite((x / c1) * c1 < x, (x % c1) != 0, c1 > 0) || + rewrite(x < (x / c1) * c1, false, c1 > 0) || // Cancel a division - rewrite((x + c1)/c0 < (x + c2)/c0, false, c0 > 0 && c1 >= c2) || - rewrite((x + c1)/c0 < (x + c2)/c0, true, c0 > 0 && c1 <= c2 - c0) || + rewrite((x + c1) / c0 < (x + c2) / c0, false, c0 > 0 && c1 >= c2) || + rewrite((x + c1) / c0 < (x + c2) / c0, true, c0 > 0 && c1 <= c2 - c0) || // c1 == 0 - rewrite(x/c0 < (x + c2)/c0, false, c0 > 0 && 0 >= c2) || - rewrite(x/c0 < (x + c2)/c0, true, c0 > 0 && 0 <= c2 - c0) || + rewrite(x / c0 < (x + c2) / c0, false, c0 > 0 && 0 >= c2) || + rewrite(x / c0 < (x + c2) / c0, true, c0 > 0 && 0 <= c2 - c0) || // c2 == 0 - rewrite((x + c1)/c0 < x/c0, false, c0 > 0 && c1 >= 0) || - rewrite((x + c1)/c0 < x/c0, true, c0 > 0 && c1 <= 0 - c0) || + rewrite((x + c1) / c0 < x / c0, false, c0 > 0 && c1 >= 0) || + rewrite((x + c1) / c0 < x / c0, true, c0 > 0 && c1 <= 0 - c0) || // The addition on the right could be outside - rewrite((x + c1)/c0 < x/c0 + c2, false, c0 > 0 && c1 >= c2 * c0) || - rewrite((x + c1)/c0 < x/c0 + c2, true, c0 > 0 && c1 <= c2 * c0 - c0) || + rewrite((x + c1) / c0 < x / c0 + c2, false, c0 > 0 && c1 >= c2 * c0) || + rewrite((x + c1) / c0 < x / c0 + c2, true, c0 > 0 && c1 <= c2 * c0 - c0) || // With a confounding max or min - rewrite((x + c1)/c0 < (min(x/c0, y) + c2), false, c0 > 0 && c1 >= c2 * c0) || - rewrite((x + c1)/c0 < (max(x/c0, y) + c2), true, c0 > 0 && c1 <= c2 * c0 - c0) || - rewrite((x + c1)/c0 < min((x + c2)/c0, y), false, c0 > 0 && c1 >= c2) || - rewrite((x + c1)/c0 < max((x + c2)/c0, y), true, c0 > 0 && c1 <= c2 - c0) || - rewrite((x + c1)/c0 < min(x/c0, y), false, c0 > 0 && c1 >= 0) || - rewrite((x + c1)/c0 < max(x/c0, y), true, c0 > 0 && c1 <= 0 - c0) || - - rewrite((x + c1)/c0 < (min(y, x/c0) + c2), false, c0 > 0 && c1 >= c2 * c0) || - rewrite((x + c1)/c0 < (max(y, x/c0) + c2), true, c0 > 0 && c1 <= c2 * c0 - c0) || - rewrite((x + c1)/c0 < min(y, (x + c2)/c0), false, c0 > 0 && c1 >= c2) || - rewrite((x + c1)/c0 < max(y, (x + c2)/c0), true, c0 > 0 && c1 <= c2 - c0) || - rewrite((x + c1)/c0 < min(y, x/c0), false, c0 > 0 && c1 >= 0) || - rewrite((x + c1)/c0 < max(y, x/c0), true, c0 > 0 && c1 <= 0 - c0) || - - rewrite(max((x + c2)/c0, y) < (x + c1)/c0, false, c0 > 0 && c2 >= c1) || - rewrite(min((x + c2)/c0, y) < (x + c1)/c0, true, c0 > 0 && c2 <= c1 - c0) || - rewrite(max(x/c0, y) < (x + c1)/c0, false, c0 > 0 && 0 >= c1) || - rewrite(min(x/c0, y) < (x + c1)/c0, true, c0 > 0 && 0 <= c1 - c0) || - rewrite(max(y, (x + c2)/c0) < (x + c1)/c0, false, c0 > 0 && c2 >= c1) || - rewrite(min(y, (x + c2)/c0) < (x + c1)/c0, true, c0 > 0 && c2 <= c1 - c0) || - rewrite(max(y, x/c0) < (x + c1)/c0, false, c0 > 0 && 0 >= c1) || - rewrite(min(y, x/c0) < (x + c1)/c0, true, c0 > 0 && 0 <= c1 - c0) || + rewrite((x + c1) / c0 < (min(x / c0, y) + c2), false, c0 > 0 && c1 >= c2 * c0) || + rewrite((x + c1) / c0 < (max(x / c0, y) + c2), true, c0 > 0 && c1 <= c2 * c0 - c0) || + rewrite((x + c1) / c0 < min((x + c2) / c0, y), false, c0 > 0 && c1 >= c2) || + rewrite((x + c1) / c0 < max((x + c2) / c0, y), true, c0 > 0 && c1 <= c2 - c0) || + rewrite((x + c1) / c0 < min(x / c0, y), false, c0 > 0 && c1 >= 0) || + rewrite((x + c1) / c0 < max(x / c0, y), true, c0 > 0 && c1 <= 0 - c0) || + + rewrite((x + c1) / c0 < (min(y, x / c0) + c2), false, c0 > 0 && c1 >= c2 * c0) || + rewrite((x + c1) / c0 < (max(y, x / c0) + c2), true, c0 > 0 && c1 <= c2 * c0 - c0) || + rewrite((x + c1) / c0 < min(y, (x + c2) / c0), false, c0 > 0 && c1 >= c2) || + rewrite((x + c1) / c0 < max(y, (x + c2) / c0), true, c0 > 0 && c1 <= c2 - c0) || + rewrite((x + c1) / c0 < min(y, x / c0), false, c0 > 0 && c1 >= 0) || + rewrite((x + c1) / c0 < max(y, x / c0), true, c0 > 0 && c1 <= 0 - c0) || + + rewrite(max((x + c2) / c0, y) < (x + c1) / c0, false, c0 > 0 && c2 >= c1) || + rewrite(min((x + c2) / c0, y) < (x + c1) / c0, true, c0 > 0 && c2 <= c1 - c0) || + rewrite(max(x / c0, y) < (x + c1) / c0, false, c0 > 0 && 0 >= c1) || + rewrite(min(x / c0, y) < (x + c1) / c0, true, c0 > 0 && 0 <= c1 - c0) || + rewrite(max(y, (x + c2) / c0) < (x + c1) / c0, false, c0 > 0 && c2 >= c1) || + rewrite(min(y, (x + c2) / c0) < (x + c1) / c0, true, c0 > 0 && c2 <= c1 - c0) || + rewrite(max(y, x / c0) < (x + c1) / c0, false, c0 > 0 && 0 >= c1) || + rewrite(min(y, x / c0) < (x + c1) / c0, true, c0 > 0 && 0 <= c1 - c0) || // Same as above with c1 outside the division, with redundant cases removed. - rewrite(max((x + c2)/c0, y) < x/c0 + c1, false, c0 > 0 && c2 >= c1 * c0) || - rewrite(min((x + c2)/c0, y) < x/c0 + c1, true, c0 > 0 && c2 <= c1 * c0 - c0) || - rewrite(max(y, (x + c2)/c0) < x/c0 + c1, false, c0 > 0 && c2 >= c1 * c0) || - rewrite(min(y, (x + c2)/c0) < x/c0 + c1, true, c0 > 0 && c2 <= c1 * c0 - c0) || + rewrite(max((x + c2) / c0, y) < x / c0 + c1, false, c0 > 0 && c2 >= c1 * c0) || + rewrite(min((x + c2) / c0, y) < x / c0 + c1, true, c0 > 0 && c2 <= c1 * c0 - c0) || + rewrite(max(y, (x + c2) / c0) < x / c0 + c1, false, c0 > 0 && c2 >= c1 * c0) || + rewrite(min(y, (x + c2) / c0) < x / c0 + c1, true, c0 > 0 && c2 <= c1 * c0 - c0) || // Same as above with c1 = 0 and the predicates and redundant cases simplified accordingly. - rewrite(x/c0 < min((x + c2)/c0, y), false, c0 > 0 && c2 < 0) || - rewrite(x/c0 < max((x + c2)/c0, y), true, c0 > 0 && c0 <= c2) || - rewrite(x/c0 < min(y, (x + c2)/c0), false, c0 > 0 && c2 < 0) || - rewrite(x/c0 < max(y, (x + c2)/c0), true, c0 > 0 && c0 <= c2) || - rewrite(max((x + c2)/c0, y) < x/c0, false, c0 > 0 && c2 >= 0) || - rewrite(min((x + c2)/c0, y) < x/c0, true, c0 > 0 && c2 + c0 <= 0) || - rewrite(max(y, (x + c2)/c0) < x/c0, false, c0 > 0 && c2 >= 0) || - rewrite(min(y, (x + c2)/c0) < x/c0, true, c0 > 0 && c2 + c0 <= 0) || - - rewrite(((max(x, (y*c0) + c1) + c2)/c0) < y, ((x + c2)/c0) < y, c0 > 0 && (c1 + c2) < 0) || - rewrite(((max(x, (y*c0) + c1) + c2)/c0) < y, false, c0 > 0 && (c1 + c2) >= 0) || - rewrite(((max(x, y*c0) + c1)/c0) < y, ((x + c1)/c0) < y, c0 > 0 && c1 < 0) || - rewrite(((max(x, y*c0) + c1)/c0) < y, false, c0 > 0 && c1 >= 0) || - rewrite(((max((x*c0) + c1, y) + c2)/c0) < x, ((y + c2)/c0) < x, c0 > 0 && (c1 + c2) < 0) || - rewrite(((max((x*c0) + c1, y) + c2)/c0) < x, false, c0 > 0 && (c1 + c2) >= 0) || - rewrite(((max(x*c0, y) + c1)/c0) < x, ((y + c1)/c0) < x, c0 > 0 && c1 < 0) || - rewrite(((max(x*c0, y) + c1)/c0) < x, false, c0 > 0 && c1 >= 0) || - rewrite((max(x, (y*c0) + c1)/c0) < y, (x/c0) < y, c0 > 0 && c1 < 0) || - rewrite((max(x, (y*c0) + c1)/c0) < y, false, c0 > 0 && c1 >= 0) || - rewrite((max(x, y*c0)/c0) < y, false, c0 > 0) || - rewrite((max((x*c0) + c1, y)/c0) < x, (y/c0) < x, c0 > 0 && c1 < 0) || - rewrite((max((x*c0) + c1, y)/c0) < x, false, c0 > 0 && c1 >= 0) || - rewrite((max(x*c0, y)/c0) < x, false, c0 > 0) || - - rewrite(((min(x, (y*c0) + c1) + c2)/c0) < y, true, c0 > 0 && (c1 + c2) < 0) || - rewrite(((min(x, (y*c0) + c1) + c2)/c0) < y, ((x + c2)/c0) < y, c0 > 0 && (c1 + c2) >= 0) || - rewrite(((min(x, y*c0) + c1)/c0) < y, true, c0 > 0 && c1 < 0) || - rewrite(((min(x, y*c0) + c1)/c0) < y, ((x + c1)/c0) < y, c0 > 0 && c1 >= 0) || - rewrite(((min((x*c0) + c1, y) + c2)/c0) < x, true, c0 > 0 && (c1 + c2) < 0) || - rewrite(((min((x*c0) + c1, y) + c2)/c0) < x, ((y + c2)/c0) < x, c0 > 0 && (c1 + c2) >= 0) || - rewrite(((min(x*c0, y) + c1)/c0) < x, true, c0 > 0 && c1 < 0) || - rewrite(((min(x*c0, y) + c1)/c0) < x, ((y + c1)/c0) < x, c0 > 0 && c1 >= 0) || - rewrite((min(x, (y*c0) + c1)/c0) < y, true, c0 > 0 && c1 < 0) || - rewrite((min(x, (y*c0) + c1)/c0) < y, (x/c0) < y, c0 > 0 && c1 >= 0) || - rewrite((min(x, y*c0)/c0) < y, (x/c0) < y, c0 > 0) || - rewrite((min((x*c0) + c1, y)/c0) < x, true, c0 > 0 && c1 < 0) || - rewrite((min((x*c0) + c1, y)/c0) < x, (y/c0) < x, c0 > 0 && c1 >= 0) || - rewrite((min(x*c0, y)/c0) < x, (y/c0) < x, c0 > 0) || + rewrite(x / c0 < min((x + c2) / c0, y), false, c0 > 0 && c2 < 0) || + rewrite(x / c0 < max((x + c2) / c0, y), true, c0 > 0 && c0 <= c2) || + rewrite(x / c0 < min(y, (x + c2) / c0), false, c0 > 0 && c2 < 0) || + rewrite(x / c0 < max(y, (x + c2) / c0), true, c0 > 0 && c0 <= c2) || + rewrite(max((x + c2) / c0, y) < x / c0, false, c0 > 0 && c2 >= 0) || + rewrite(min((x + c2) / c0, y) < x / c0, true, c0 > 0 && c2 + c0 <= 0) || + rewrite(max(y, (x + c2) / c0) < x / c0, false, c0 > 0 && c2 >= 0) || + rewrite(min(y, (x + c2) / c0) < x / c0, true, c0 > 0 && c2 + c0 <= 0) || + + rewrite(((max(x, (y * c0) + c1) + c2) / c0) < y, ((x + c2) / c0) < y, c0 > 0 && (c1 + c2) < 0) || + rewrite(((max(x, (y * c0) + c1) + c2) / c0) < y, false, c0 > 0 && (c1 + c2) >= 0) || + rewrite(((max(x, y * c0) + c1) / c0) < y, ((x + c1) / c0) < y, c0 > 0 && c1 < 0) || + rewrite(((max(x, y * c0) + c1) / c0) < y, false, c0 > 0 && c1 >= 0) || + rewrite(((max((x * c0) + c1, y) + c2) / c0) < x, ((y + c2) / c0) < x, c0 > 0 && (c1 + c2) < 0) || + rewrite(((max((x * c0) + c1, y) + c2) / c0) < x, false, c0 > 0 && (c1 + c2) >= 0) || + rewrite(((max(x * c0, y) + c1) / c0) < x, ((y + c1) / c0) < x, c0 > 0 && c1 < 0) || + rewrite(((max(x * c0, y) + c1) / c0) < x, false, c0 > 0 && c1 >= 0) || + rewrite((max(x, (y * c0) + c1) / c0) < y, (x / c0) < y, c0 > 0 && c1 < 0) || + rewrite((max(x, (y * c0) + c1) / c0) < y, false, c0 > 0 && c1 >= 0) || + rewrite((max(x, y * c0) / c0) < y, false, c0 > 0) || + rewrite((max((x * c0) + c1, y) / c0) < x, (y / c0) < x, c0 > 0 && c1 < 0) || + rewrite((max((x * c0) + c1, y) / c0) < x, false, c0 > 0 && c1 >= 0) || + rewrite((max(x * c0, y) / c0) < x, false, c0 > 0) || + + rewrite(((min(x, (y * c0) + c1) + c2) / c0) < y, true, c0 > 0 && (c1 + c2) < 0) || + rewrite(((min(x, (y * c0) + c1) + c2) / c0) < y, ((x + c2) / c0) < y, c0 > 0 && (c1 + c2) >= 0) || + rewrite(((min(x, y * c0) + c1) / c0) < y, true, c0 > 0 && c1 < 0) || + rewrite(((min(x, y * c0) + c1) / c0) < y, ((x + c1) / c0) < y, c0 > 0 && c1 >= 0) || + rewrite(((min((x * c0) + c1, y) + c2) / c0) < x, true, c0 > 0 && (c1 + c2) < 0) || + rewrite(((min((x * c0) + c1, y) + c2) / c0) < x, ((y + c2) / c0) < x, c0 > 0 && (c1 + c2) >= 0) || + rewrite(((min(x * c0, y) + c1) / c0) < x, true, c0 > 0 && c1 < 0) || + rewrite(((min(x * c0, y) + c1) / c0) < x, ((y + c1) / c0) < x, c0 > 0 && c1 >= 0) || + rewrite((min(x, (y * c0) + c1) / c0) < y, true, c0 > 0 && c1 < 0) || + rewrite((min(x, (y * c0) + c1) / c0) < y, (x / c0) < y, c0 > 0 && c1 >= 0) || + rewrite((min(x, y * c0) / c0) < y, (x / c0) < y, c0 > 0) || + rewrite((min((x * c0) + c1, y) / c0) < x, true, c0 > 0 && c1 < 0) || + rewrite((min((x * c0) + c1, y) / c0) < x, (y / c0) < x, c0 > 0 && c1 >= 0) || + rewrite((min(x * c0, y) / c0) < x, (y / c0) < x, c0 > 0) || // Comparison of two mins/maxes that don't cancel when subtracted rewrite(min(x, c0) < min(x, c1), false, c0 >= c1) || @@ -481,20 +480,18 @@ Expr Simplify::visit(const LT *op, ExprInfo *info) { // Comparison of aligned ramps can simplify to a comparison of the base rewrite(ramp(x * c3 + c2, c1, lanes) < broadcast(z * c0, lanes), - broadcast(x * fold(c3/c0) + fold(c2/c0) < z, lanes), - c0 > 0 && (c3 % c0 == 0) && - (c2 % c0) + c1 * (lanes - 1) < c0 && - (c2 % c0) + c1 * (lanes - 1) >= 0) || + broadcast(x * fold(c3 / c0) + fold(c2 / c0) < z, lanes), + (c0 > 0 && (c3 % c0 == 0) && + (c2 % c0) + c1 * (lanes - 1) < c0 && + (c2 % c0) + c1 * (lanes - 1) >= 0)) || // c2 = 0 rewrite(ramp(x * c3, c1, lanes) < broadcast(z * c0, lanes), - broadcast(x * fold(c3/c0) < z, lanes), + broadcast(x * fold(c3 / c0) < z, lanes), c0 > 0 && (c3 % c0 == 0) && - c1 * (lanes - 1) < c0 && - c1 * (lanes - 1) >= 0) - ))) { + c1 * (lanes - 1) < c0 && + c1 * (lanes - 1) >= 0)))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index 1d9889042ed3..1926bc9a069e 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -65,15 +65,14 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { return rewrite.result; } - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(max(x, x), a) || rewrite(max(c0, c1), fold(max(c0, c1))) || // Cases where one side dominates: rewrite(max(x, c0), b, is_max_value(c0)) || rewrite(max(x, c0), a, is_min_value(c0)) || - rewrite(max((x/c0)*c0, x), b, c0 > 0) || - rewrite(max(x, (x/c0)*c0), a, c0 > 0) || + rewrite(max((x / c0) * c0, x), b, c0 > 0) || + rewrite(max(x, (x / c0) * c0), a, c0 > 0) || rewrite(max(max(x, y), x), a) || rewrite(max(max(x, y), y), a) || rewrite(max(max(max(x, y), z), x), a) || @@ -117,20 +116,20 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { rewrite(max(ramp(x, y, lanes), broadcast(z, lanes)), b, can_prove(x + y * (lanes - 1) <= z && x <= z, this)) || // Compare x to a stair-step function in x - rewrite(max(((x + c0)/c1)*c1 + c2, x), a, c1 > 0 && c0 + c2 >= c1 - 1) || - rewrite(max(x, ((x + c0)/c1)*c1 + c2), b, c1 > 0 && c0 + c2 >= c1 - 1) || - rewrite(max(((x + c0)/c1)*c1 + c2, x), b, c1 > 0 && c0 + c2 <= 0) || - rewrite(max(x, ((x + c0)/c1)*c1 + c2), a, c1 > 0 && c0 + c2 <= 0) || - rewrite(max((x/c0)*c0, (x/c1)*c1 + c2), b, c2 >= c1 && c1 > 0 && c0 != 0) || + rewrite(max(((x + c0) / c1) * c1 + c2, x), a, c1 > 0 && c0 + c2 >= c1 - 1) || + rewrite(max(x, ((x + c0) / c1) * c1 + c2), b, c1 > 0 && c0 + c2 >= c1 - 1) || + rewrite(max(((x + c0) / c1) * c1 + c2, x), b, c1 > 0 && c0 + c2 <= 0) || + rewrite(max(x, ((x + c0) / c1) * c1 + c2), a, c1 > 0 && c0 + c2 <= 0) || + rewrite(max((x / c0) * c0, (x / c1) * c1 + c2), b, c2 >= c1 && c1 > 0 && c0 != 0) || // Special cases where c0 or c2 is zero - rewrite(max((x/c1)*c1 + c2, x), a, c1 > 0 && c2 >= c1 - 1) || - rewrite(max(x, (x/c1)*c1 + c2), b, c1 > 0 && c2 >= c1 - 1) || - rewrite(max(((x + c0)/c1)*c1, x), a, c1 > 0 && c0 >= c1 - 1) || - rewrite(max(x, ((x + c0)/c1)*c1), b, c1 > 0 && c0 >= c1 - 1) || - rewrite(max((x/c1)*c1 + c2, x), b, c1 > 0 && c2 <= 0) || - rewrite(max(x, (x/c1)*c1 + c2), a, c1 > 0 && c2 <= 0) || - rewrite(max(((x + c0)/c1)*c1, x), b, c1 > 0 && c0 <= 0) || - rewrite(max(x, ((x + c0)/c1)*c1), a, c1 > 0 && c0 <= 0) || + rewrite(max((x / c1) * c1 + c2, x), a, c1 > 0 && c2 >= c1 - 1) || + rewrite(max(x, (x / c1) * c1 + c2), b, c1 > 0 && c2 >= c1 - 1) || + rewrite(max(((x + c0) / c1) * c1, x), a, c1 > 0 && c0 >= c1 - 1) || + rewrite(max(x, ((x + c0) / c1) * c1), b, c1 > 0 && c0 >= c1 - 1) || + rewrite(max((x / c1) * c1 + c2, x), b, c1 > 0 && c2 <= 0) || + rewrite(max(x, (x / c1) * c1 + c2), a, c1 > 0 && c2 <= 0) || + rewrite(max(((x + c0) / c1) * c1, x), b, c1 > 0 && c0 <= 0) || + rewrite(max(x, ((x + c0) / c1) * c1), a, c1 > 0 && c0 <= 0) || rewrite(max(x, min(x, y) + c0), a, c0 <= 0) || rewrite(max(x, min(y, x) + c0), a, c0 <= 0) || @@ -139,8 +138,8 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { rewrite(max(min(x, y + c0), y), b, c0 <= 0) || (no_overflow_int(op->type) && - (rewrite(max(min(c0 - x, x), c1), b, 2*c1 >= c0 - 1) || - rewrite(max(min(x, c0 - x), c1), b, 2*c1 >= c0 - 1))) || + (rewrite(max(min(c0 - x, x), c1), b, 2 * c1 >= c0 - 1) || + rewrite(max(min(x, c0 - x), c1), b, 2 * c1 >= c0 - 1))) || false)))) { @@ -156,10 +155,8 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { return rewrite.result; } - // clang-format on - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(max(max(x, c0), c1), max(x, fold(max(c0, c1)))) || rewrite(max(max(x, c0), y), max(max(x, y), c0)) || rewrite(max(max(x, y), max(x, z)), max(max(y, z), x)) || @@ -246,10 +243,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { rewrite(max(y + x, x), max(y, 0) + x) || rewrite(max(x + y, x), x + max(y, 0)) || - rewrite(max((x*c0 + y)*c1, x*c2 + z), max(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(max((y + x*c0)*c1, x*c2 + z), max(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(max((x*c0 + y)*c1, z + x*c2), max(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(max((y + x*c0)*c1, z + x*c2), max(y*c1, z) + x*c2, c0 * c1 == c2) || + rewrite(max((x * c0 + y) * c1, x * c2 + z), max(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(max((y + x * c0) * c1, x * c2 + z), max(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(max((x * c0 + y) * c1, z + x * c2), max(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(max((y + x * c0) * c1, z + x * c2), max(y * c1, z) + x * c2, c0 * c1 == c2) || rewrite(max(max(x + y, z), x + w), max(x + max(y, w), z)) || rewrite(max(max(z, x + y), x + w), max(x + max(y, w), z)) || @@ -311,19 +308,18 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) { rewrite(max(((x + c0) / c1) * c1, x + c2), ((x + c0) / c1) * c1, c1 > 0 && c0 + 1 >= c1 + c2) || - rewrite(max((x + c0)/c1, ((x + c2)/c3)*c4), (x + c0)/c1, c2 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max((x + c0)/c1, ((x + c2)/c3)*c4), ((x + c2)/c3)*c4, c0 + c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max(x/c1, ((x + c2)/c3)*c4), x/c1, c2 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max(x/c1, ((x + c2)/c3)*c4), ((x + c2)/c3)*c4, c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max((x + c0)/c1, (x/c3)*c4), (x + c0)/c1, 0 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max((x + c0)/c1, (x/c3)*c4), (x/c3)*c4, c0 + c3 - c1 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(max(x/c1, (x/c3)*c4), x/c1, c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max((x + c0) / c1, ((x + c2) / c3) * c4), (x + c0) / c1, c2 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max((x + c0) / c1, ((x + c2) / c3) * c4), ((x + c2) / c3) * c4, c0 + c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max(x / c1, ((x + c2) / c3) * c4), x / c1, c2 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max(x / c1, ((x + c2) / c3) * c4), ((x + c2) / c3) * c4, c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max((x + c0) / c1, (x / c3) * c4), (x + c0) / c1, 0 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max((x + c0) / c1, (x / c3) * c4), (x / c3) * c4, c0 + c3 - c1 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(max(x / c1, (x / c3) * c4), x / c1, c1 > 0 && c3 > 0 && c1 * c4 == c3) || rewrite(max(c0 - x, c1), c0 - min(x, fold(c0 - c1))))))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index a51bdaa0d0ef..3f6084c6c4f1 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -66,15 +66,14 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { return rewrite.result; } - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(min(x, x), a) || rewrite(min(c0, c1), fold(min(c0, c1))) || // Cases where one side dominates: rewrite(min(x, c0), b, is_min_value(c0)) || rewrite(min(x, c0), a, is_max_value(c0)) || - rewrite(min((x/c0)*c0, x), a, c0 > 0) || - rewrite(min(x, (x/c0)*c0), b, c0 > 0) || + rewrite(min((x / c0) * c0, x), a, c0 > 0) || + rewrite(min(x, (x / c0) * c0), b, c0 > 0) || rewrite(min(min(x, y), x), a) || rewrite(min(min(x, y), y), a) || rewrite(min(min(min(x, y), z), x), a) || @@ -118,20 +117,20 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { rewrite(min(ramp(x, y, lanes), broadcast(z, lanes)), b, can_prove(x + y * (lanes - 1) >= z && x >= z, this)) || // Compare x to a stair-step function in x - rewrite(min(((x + c0)/c1)*c1 + c2, x), b, c1 > 0 && c0 + c2 >= c1 - 1) || - rewrite(min(x, ((x + c0)/c1)*c1 + c2), a, c1 > 0 && c0 + c2 >= c1 - 1) || - rewrite(min(((x + c0)/c1)*c1 + c2, x), a, c1 > 0 && c0 + c2 <= 0) || - rewrite(min(x, ((x + c0)/c1)*c1 + c2), b, c1 > 0 && c0 + c2 <= 0) || - rewrite(min((x/c0)*c0, (x/c1)*c1 + c2), a, c2 >= c1 && c1 > 0 && c0 != 0) || + rewrite(min(((x + c0) / c1) * c1 + c2, x), b, c1 > 0 && c0 + c2 >= c1 - 1) || + rewrite(min(x, ((x + c0) / c1) * c1 + c2), a, c1 > 0 && c0 + c2 >= c1 - 1) || + rewrite(min(((x + c0) / c1) * c1 + c2, x), a, c1 > 0 && c0 + c2 <= 0) || + rewrite(min(x, ((x + c0) / c1) * c1 + c2), b, c1 > 0 && c0 + c2 <= 0) || + rewrite(min((x / c0) * c0, (x / c1) * c1 + c2), a, c2 >= c1 && c1 > 0 && c0 != 0) || // Special cases where c0 or c2 is zero - rewrite(min((x/c1)*c1 + c2, x), b, c1 > 0 && c2 >= c1 - 1) || - rewrite(min(x, (x/c1)*c1 + c2), a, c1 > 0 && c2 >= c1 - 1) || - rewrite(min(((x + c0)/c1)*c1, x), b, c1 > 0 && c0 >= c1 - 1) || - rewrite(min(x, ((x + c0)/c1)*c1), a, c1 > 0 && c0 >= c1 - 1) || - rewrite(min((x/c1)*c1 + c2, x), a, c1 > 0 && c2 <= 0) || - rewrite(min(x, (x/c1)*c1 + c2), b, c1 > 0 && c2 <= 0) || - rewrite(min(((x + c0)/c1)*c1, x), a, c1 > 0 && c0 <= 0) || - rewrite(min(x, ((x + c0)/c1)*c1), b, c1 > 0 && c0 <= 0) || + rewrite(min((x / c1) * c1 + c2, x), b, c1 > 0 && c2 >= c1 - 1) || + rewrite(min(x, (x / c1) * c1 + c2), a, c1 > 0 && c2 >= c1 - 1) || + rewrite(min(((x + c0) / c1) * c1, x), b, c1 > 0 && c0 >= c1 - 1) || + rewrite(min(x, ((x + c0) / c1) * c1), a, c1 > 0 && c0 >= c1 - 1) || + rewrite(min((x / c1) * c1 + c2, x), a, c1 > 0 && c2 <= 0) || + rewrite(min(x, (x / c1) * c1 + c2), b, c1 > 0 && c2 <= 0) || + rewrite(min(((x + c0) / c1) * c1, x), a, c1 > 0 && c0 <= 0) || + rewrite(min(x, ((x + c0) / c1) * c1), b, c1 > 0 && c0 <= 0) || rewrite(min(x, max(x, y) + c0), a, 0 <= c0) || rewrite(min(x, max(y, x) + c0), a, 0 <= c0) || @@ -140,8 +139,8 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { rewrite(min(max(x, y + c0), y), b, 0 <= c0) || (no_overflow_int(op->type) && - (rewrite(min(max(c0 - x, x), c1), b, 2*c1 <= c0 + 1) || - rewrite(min(max(x, c0 - x), c1), b, 2*c1 <= c0 + 1))) || + (rewrite(min(max(c0 - x, x), c1), b, 2 * c1 <= c0 + 1) || + rewrite(min(max(x, c0 - x), c1), b, 2 * c1 <= c0 + 1))) || false)))) { // One of the cancellation rules above may give us tighter bounds @@ -155,10 +154,8 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { } return rewrite.result; } - // clang-format on - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(min(min(x, c0), c1), min(x, fold(min(c0, c1)))) || rewrite(min(min(x, c0), y), min(min(x, y), c0)) || rewrite(min(min(x, y), min(x, z)), min(min(y, z), x)) || @@ -249,10 +246,10 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { rewrite(min(y + x, x), min(y, 0) + x) || rewrite(min(x + y, x), x + min(y, 0)) || - rewrite(min((x*c0 + y)*c1, x*c2 + z), min(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(min((y + x*c0)*c1, x*c2 + z), min(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(min((x*c0 + y)*c1, z + x*c2), min(y*c1, z) + x*c2, c0 * c1 == c2) || - rewrite(min((y + x*c0)*c1, z + x*c2), min(y*c1, z) + x*c2, c0 * c1 == c2) || + rewrite(min((x * c0 + y) * c1, x * c2 + z), min(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(min((y + x * c0) * c1, x * c2 + z), min(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(min((x * c0 + y) * c1, z + x * c2), min(y * c1, z) + x * c2, c0 * c1 == c2) || + rewrite(min((y + x * c0) * c1, z + x * c2), min(y * c1, z) + x * c2, c0 * c1 == c2) || rewrite(min(min(x + y, z), x + w), min(x + min(y, w), z)) || rewrite(min(min(z, x + y), x + w), min(x + min(y, w), z)) || @@ -317,23 +314,22 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) { rewrite(min(c0 - x, c1), c0 - max(x, fold(c0 - c1))) || // Required for nested GuardWithIf tilings - rewrite(min((min(((y + c0)/c1), x)*c1), y + c2), min(x * c1, y + c2), c1 > 0 && c1 + c2 <= c0 + 1) || - rewrite(min((min(((y + c0)/c1), x)*c1) + c2, y), min(x * c1 + c2, y), c1 > 0 && c1 <= c0 + c2 + 1) || - rewrite(min(min(((y + c0)/c1), x)*c1, y), min(x * c1, y), c1 > 0 && c1 <= c0 + 1) || + rewrite(min((min(((y + c0) / c1), x) * c1), y + c2), min(x * c1, y + c2), c1 > 0 && c1 + c2 <= c0 + 1) || + rewrite(min((min(((y + c0) / c1), x) * c1) + c2, y), min(x * c1 + c2, y), c1 > 0 && c1 <= c0 + c2 + 1) || + rewrite(min(min(((y + c0) / c1), x) * c1, y), min(x * c1, y), c1 > 0 && c1 <= c0 + 1) || + + rewrite(min((x + c0) / c1, ((x + c2) / c3) * c4), (x + c0) / c1, c0 + c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min((x + c0) / c1, ((x + c2) / c3) * c4), ((x + c2) / c3) * c4, c2 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min(x / c1, ((x + c2) / c3) * c4), x / c1, c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min(x / c1, ((x + c2) / c3) * c4), ((x + c2) / c3) * c4, c2 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min((x + c0) / c1, (x / c3) * c4), (x + c0) / c1, c0 + c3 - c1 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min((x + c0) / c1, (x / c3) * c4), (x / c3) * c4, 0 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || + rewrite(min(x / c1, (x / c3) * c4), (x / c3) * c4, c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min((x + c0)/c1, ((x + c2)/c3)*c4), (x + c0)/c1, c0 + c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min((x + c0)/c1, ((x + c2)/c3)*c4), ((x + c2)/c3)*c4, c2 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min(x/c1, ((x + c2)/c3)*c4), x/c1, c3 - c1 <= c2 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min(x/c1, ((x + c2)/c3)*c4), ((x + c2)/c3)*c4, c2 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min((x + c0)/c1, (x/c3)*c4), (x + c0)/c1, c0 + c3 - c1 <= 0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min((x + c0)/c1, (x/c3)*c4), (x/c3)*c4, 0 <= c0 && c1 > 0 && c3 > 0 && c1 * c4 == c3) || - rewrite(min(x/c1, (x/c3)*c4), (x/c3)*c4, c1 > 0 && c3 > 0 && c1 * c4 == c3) || - - false )))) { + false)))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_Mod.cpp b/src/Simplify_Mod.cpp index 30ef375fa3c9..7e5232da0975 100644 --- a/src/Simplify_Mod.cpp +++ b/src/Simplify_Mod.cpp @@ -43,10 +43,8 @@ Expr Simplify::visit(const Mod *op, ExprInfo *info) { return rewrite.result; } - // clang-format off - if (EVAL_IN_LAMBDA - ( - rewrite(c0 % c1, fold(c0 % c1)) || + if (EVAL_IN_LAMBDA // + (rewrite(c0 % c1, fold(c0 % c1)) || rewrite(0 % x, 0) || rewrite(x % x, 0) || rewrite(x % 0, 0) || @@ -55,12 +53,12 @@ Expr Simplify::visit(const Mod *op, ExprInfo *info) { (no_overflow_int(op->type) && (rewrite((x * c0) % c1, (x * fold(c0 % c1)) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) || rewrite((x + c0) % c1, (x + fold(c0 % c1)) % c1, c1 > 0 && (c0 >= c1 || c0 < 0)) || - rewrite((x * c0) % c1, (x % fold(c1/c0)) * c0, c0 > 0 && c1 % c0 == 0) || + rewrite((x * c0) % c1, (x % fold(c1 / c0)) * c0, c0 > 0 && c1 % c0 == 0) || rewrite((x * c0 + y) % c1, y % c1, c0 % c1 == 0) || rewrite((y + x * c0) % c1, y % c1, c0 % c1 == 0) || rewrite((x * c0 - y) % c1, (-y) % c1, c0 % c1 == 0) || rewrite((y - x * c0) % c1, y % c1, c0 % c1 == 0) || - rewrite((x - y) % 2, (x + y) % 2) || // Addition and subtraction are the same modulo 2, because -1 == 1 + rewrite((x - y) % 2, (x + y) % 2) || // Addition and subtraction are the same modulo 2, because -1 == 1 rewrite(ramp(x, c0, c2) % broadcast(c1, c2), broadcast(x, c2) % broadcast(c1, c2), (c0 % c1 == 0)) || rewrite(ramp(x, c0, lanes) % broadcast(c1, lanes), ramp(x % c1, c0, lanes), @@ -80,7 +78,6 @@ Expr Simplify::visit(const Mod *op, ExprInfo *info) { c0 % c1 == 0))))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_Or.cpp b/src/Simplify_Or.cpp index 7537f4b44ff9..36fc0d4b334b 100644 --- a/src/Simplify_Or.cpp +++ b/src/Simplify_Or.cpp @@ -25,10 +25,8 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::or_op(a, b), op->type); - // clang-format off - // Cases that fold to a constant - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(x || true, true) || rewrite(true || x, true) || rewrite(x || neg(x), true) || @@ -74,7 +72,7 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) { } // Cases that fold to one of the args - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(x || false, a) || rewrite(x || x, a) || @@ -131,9 +129,8 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) { return rewrite.result; } - // Cases that need remutation - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(broadcast(x, c0) || broadcast(y, c0), broadcast(x || y, c0)) || rewrite((x || broadcast(y, c0)) || broadcast(z, c0), x || broadcast(y || z, c0)) || rewrite((broadcast(x, c0) || y) || broadcast(z, c0), broadcast(x || z, c0) || y) || @@ -265,8 +262,6 @@ Expr Simplify::visit(const Or *op, ExprInfo *info) { return mutate(rewrite.result, info); } - // clang-format on - if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Simplify_Select.cpp b/src/Simplify_Select.cpp index 2a4aa953e072..3bc4507fc74b 100644 --- a/src/Simplify_Select.cpp +++ b/src/Simplify_Select.cpp @@ -18,8 +18,7 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::select(condition, true_value, false_value), op->type); - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(select(IRMatcher::likely(true), x, y), true_value) || rewrite(select(IRMatcher::likely(false), x, y), false_value) || rewrite(select(IRMatcher::likely_if_innermost(true), x, y), true_value) || @@ -41,10 +40,8 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) { } return rewrite.result; } - // clang-format on - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(select(x != y, z, w), select(x == y, w, z)) || rewrite(select(x <= y, z, w), select(y < x, w, z)) || rewrite(select(x, select(y, z, w), z), select(x && !y, w, z)) || @@ -160,9 +157,9 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) { rewrite(select(x, max(w, y), max(z, w)), max(w, select(x, y, z))) || rewrite(select(x, max(w, y), max(w, z)), max(w, select(x, y, z))) || - rewrite(select(0 < x, min(x*c0, c1), x*c0), min(x*c0, c1), c1 >= 0 && c0 >= 0) || + rewrite(select(0 < x, min(x * c0, c1), x * c0), min(x * c0, c1), c1 >= 0 && c0 >= 0) || rewrite(select(x < c0, 0, min(x, c0) + c1), 0, c0 == -c1) || - rewrite(select(0 < x, ((x*c0) + c1) / x, y), select(0 < x, c0 - 1, y), c1 == -1) || + rewrite(select(0 < x, ((x * c0) + c1) / x, y), select(0 < x, c0 - 1, y), c1 == -1) || rewrite(select(x, select(y, z, min(w, z)), min(u, z)), min(select(x, select(y, z, w), u), z)) || rewrite(select(x, select(y, min(w, z), z), min(u, z)), min(select(x, select(y, w, z), u), z)) || @@ -233,7 +230,6 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) { rewrite(select(x, true, y), x || y))))) { return mutate(rewrite.result, info); } - // clang-format on if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && diff --git a/src/Simplify_Sub.cpp b/src/Simplify_Sub.cpp index faeb4a17ff61..828cbeec2b2d 100644 --- a/src/Simplify_Sub.cpp +++ b/src/Simplify_Sub.cpp @@ -26,18 +26,17 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { return rewrite.result; } - // clang-format off - if (EVAL_IN_LAMBDA + if (EVAL_IN_LAMBDA // (rewrite(c0 - c1, fold(c0 - c1)) || (!op->type.is_uint() && rewrite(x - c0, x + fold(-c0), !overflows(-c0))) || - rewrite(x - x, 0) || // We want to remutate this just to get better bounds + rewrite(x - x, 0) || // We want to remutate this just to get better bounds rewrite(ramp(x, y, c0) - ramp(z, w, c0), ramp(x - z, y - w, c0)) || rewrite(ramp(x, y, c0) - broadcast(z, c0), ramp(x - z, y, c0)) || rewrite(broadcast(x, c0) - ramp(z, w, c0), ramp(x - z, -w, c0)) || rewrite(broadcast(x, c0) - broadcast(y, c0), broadcast(x - y, c0)) || - rewrite(broadcast(x, c0) - broadcast(y, c1), broadcast(x - broadcast(y, fold(c1/c0)), c0), c1 % c0 == 0) || - rewrite(broadcast(y, c1) - broadcast(x, c0), broadcast(broadcast(y, fold(c1/c0)) - x, c0), c1 % c0 == 0) || + rewrite(broadcast(x, c0) - broadcast(y, c1), broadcast(x - broadcast(y, fold(c1 / c0)), c0), c1 % c0 == 0) || + rewrite(broadcast(y, c1) - broadcast(x, c0), broadcast(broadcast(y, fold(c1 / c0)) - x, c0), c1 % c0 == 0) || rewrite((x - broadcast(y, c0)) - broadcast(z, c0), x - broadcast(y + z, c0)) || rewrite((x + broadcast(y, c0)) - broadcast(z, c0), x + broadcast(y - z, c0)) || @@ -84,45 +83,45 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite((c0 - x) - (c1 - y), (y - x) + fold(c0 - c1)) || rewrite((c0 - x) - (y + c1), fold(c0 - c1) - (x + y)) || rewrite(x - (y - z), x + (z - y)) || - rewrite(x - y*c0, x + y*fold(-c0), c0 < 0 && -c0 > 0) || + rewrite(x - y * c0, x + y * fold(-c0), c0 < 0 && -c0 > 0) || rewrite(x - (y + c0), (x - y) - c0) || rewrite((c0 - x) - c1, fold(c0 - c1) - x) || - rewrite(x*y - z*y, (x - z)*y) || - rewrite(x*y - y*z, (x - z)*y) || - rewrite(y*x - z*y, y*(x - z)) || - rewrite(y*x - y*z, y*(x - z)) || - rewrite((u + x*y) - z*y, u + (x - z)*y) || - rewrite((u + x*y) - y*z, u + (x - z)*y) || - rewrite((u + y*x) - z*y, u + y*(x - z)) || - rewrite((u + y*x) - y*z, u + y*(x - z)) || - rewrite((u - x*y) - z*y, u - (x + z)*y) || - rewrite((u - x*y) - y*z, u - (x + z)*y) || - rewrite((u - y*x) - z*y, u - y*(x + z)) || - rewrite((u - y*x) - y*z, u - y*(x + z)) || - rewrite((x*y + u) - z*y, u + (x - z)*y) || - rewrite((x*y + u) - y*z, u + (x - z)*y) || - rewrite((y*x + u) - z*y, u + y*(x - z)) || - rewrite((y*x + u) - y*z, u + y*(x - z)) || - rewrite((x*y - u) - z*y, (x - z)*y - u) || - rewrite((x*y - u) - y*z, (x - z)*y - u) || - rewrite((y*x - u) - z*y, y*(x - z) - u) || - rewrite((y*x - u) - y*z, y*(x - z) - u) || - rewrite(x*y - (u + z*y), (x - z)*y - u) || - rewrite(x*y - (u + y*z), (x - z)*y - u) || - rewrite(y*x - (u + z*y), y*(x - z) - u) || - rewrite(y*x - (u + y*z), y*(x - z) - u) || - rewrite(x*y - (u - z*y), (x + z)*y - u) || - rewrite(x*y - (u - y*z), (x + z)*y - u) || - rewrite(y*x - (u - z*y), y*(x + z) - u) || - rewrite(y*x - (u - y*z), y*(x + z) - u) || - rewrite(x*y - (z*y + u), (x - z)*y - u) || - rewrite(x*y - (y*z + u), (x - z)*y - u) || - rewrite(y*x - (z*y + u), y*(x - z) - u) || - rewrite(y*x - (y*z + u), y*(x - z) - u) || - rewrite(x*y - (z*y - u), (x - z)*y + u) || - rewrite(x*y - (y*z - u), (x - z)*y + u) || - rewrite(y*x - (z*y - u), y*(x - z) + u) || - rewrite(y*x - (y*z - u), y*(x - z) + u) || + rewrite(x * y - z * y, (x - z) * y) || + rewrite(x * y - y * z, (x - z) * y) || + rewrite(y * x - z * y, y * (x - z)) || + rewrite(y * x - y * z, y * (x - z)) || + rewrite((u + x * y) - z * y, u + (x - z) * y) || + rewrite((u + x * y) - y * z, u + (x - z) * y) || + rewrite((u + y * x) - z * y, u + y * (x - z)) || + rewrite((u + y * x) - y * z, u + y * (x - z)) || + rewrite((u - x * y) - z * y, u - (x + z) * y) || + rewrite((u - x * y) - y * z, u - (x + z) * y) || + rewrite((u - y * x) - z * y, u - y * (x + z)) || + rewrite((u - y * x) - y * z, u - y * (x + z)) || + rewrite((x * y + u) - z * y, u + (x - z) * y) || + rewrite((x * y + u) - y * z, u + (x - z) * y) || + rewrite((y * x + u) - z * y, u + y * (x - z)) || + rewrite((y * x + u) - y * z, u + y * (x - z)) || + rewrite((x * y - u) - z * y, (x - z) * y - u) || + rewrite((x * y - u) - y * z, (x - z) * y - u) || + rewrite((y * x - u) - z * y, y * (x - z) - u) || + rewrite((y * x - u) - y * z, y * (x - z) - u) || + rewrite(x * y - (u + z * y), (x - z) * y - u) || + rewrite(x * y - (u + y * z), (x - z) * y - u) || + rewrite(y * x - (u + z * y), y * (x - z) - u) || + rewrite(y * x - (u + y * z), y * (x - z) - u) || + rewrite(x * y - (u - z * y), (x + z) * y - u) || + rewrite(x * y - (u - y * z), (x + z) * y - u) || + rewrite(y * x - (u - z * y), y * (x + z) - u) || + rewrite(y * x - (u - y * z), y * (x + z) - u) || + rewrite(x * y - (z * y + u), (x - z) * y - u) || + rewrite(x * y - (y * z + u), (x - z) * y - u) || + rewrite(y * x - (z * y + u), y * (x - z) - u) || + rewrite(y * x - (y * z + u), y * (x - z) - u) || + rewrite(x * y - (z * y - u), (x - z) * y + u) || + rewrite(x * y - (y * z - u), (x - z) * y + u) || + rewrite(y * x - (z * y - u), y * (x - z) + u) || + rewrite(y * x - (y * z - u), y * (x - z) + u) || rewrite((x + y) - (x + z), y - z) || rewrite((x + y) - (z + x), y - z) || rewrite((y + x) - (x + z), y - z) || @@ -167,8 +166,8 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite(0 - ((x - y) + z), y - (x + z)) || rewrite(((x - y) - z) - x, 0 - (y + z)) || - rewrite(x - x%c0, (x/c0)*c0) || - rewrite(x - ((x + c0)/c1)*c1, (x + c0)%c1 - c0, c1 > 0) || + rewrite(x - x % c0, (x / c0) * c0) || + rewrite(x - ((x + c0) / c1) * c1, (x + c0) % c1 - c0, c1 > 0) || // Hoist shuffles. The Shuffle visitor wants to sink // extract_elements to the leaves, and those count as degenerate @@ -179,7 +178,7 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite((slice(x, c0, c1, c2) - z) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || rewrite((z - slice(x, c0, c1, c2)) - slice(y, c0, c1, c2), z - slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || - (no_overflow(op->type) && EVAL_IN_LAMBDA + (no_overflow(op->type) && EVAL_IN_LAMBDA // (rewrite(max(x, y) - x, max(y - x, 0)) || rewrite(min(x, y) - x, min(y - x, 0)) || rewrite(max(x, y) - y, max(x - y, 0)) || @@ -208,10 +207,10 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite(z - max(min(x - y, c0), c1), z + min(max(y - x, fold(-c0)), fold(-c1))) || rewrite(z - min(max(x - y, c0), c1), z + max(min(y - x, fold(-c0)), fold(-c1))) || - rewrite(x*y - x, x*(y - 1)) || - rewrite(x*y - y, (x - 1)*y) || - rewrite(x - x*y, x*(1 - y)) || - rewrite(x - y*x, (1 - y)*x) || + rewrite(x * y - x, x * (y - 1)) || + rewrite(x * y - y, (x - 1) * y) || + rewrite(x - x * y, x * (1 - y)) || + rewrite(x - y * x, (1 - y) * x) || // Cancel a term from one side of a min or max. Some of // these rules introduce a new constant zero, so we require @@ -248,7 +247,7 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite(min(x, y) - min(y, x), 0) || rewrite(min(x, y) - min(z, w), y - w, can_prove(x - y == z - w, this)) || rewrite(min(x, y) - min(w, z), y - w, can_prove(x - y == z - w, this)) || - rewrite(min(x*c0, c1) - min(x, c2)*c0, min(c1 - min(x, c2)*c0, 0), c0 > 0 && c1 <= c2*c0) || + rewrite(min(x * c0, c1) - min(x, c2) * c0, min(c1 - min(x, c2) * c0, 0), c0 > 0 && c1 <= c2 * c0) || rewrite((x - max(z, (x + y))), (0 - max(z - x, y)), !is_const(x)) || rewrite((x - max(z, (y + x))), (0 - max(z - x, y)), !is_const(x)) || @@ -382,38 +381,38 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite(max(y, x + c0) - max(x + c1, w), max(y - max(x + c1, w), fold(c0 - c1)), can_prove(y + c1 >= w + c0, this)) || rewrite(max(y, x + c0) - max(x + c1, w), min(max(x + c0, y) - w, fold(c0 - c1)), can_prove(y + c1 <= w + c0, this)))) || - (no_overflow_int(op->type) && EVAL_IN_LAMBDA - (rewrite(c0 - (c1 - x)/c2, (fold(c0*c2 - c1 + c2 - 1) + x)/c2, c2 > 0) || - rewrite(c0 - (x + c1)/c2, (fold(c0*c2 - c1 + c2 - 1) - x)/c2, c2 > 0) || - rewrite(x - (x + y)/c0, (x*fold(c0 - 1) - y + fold(c0 - 1))/c0, c0 > 0) || - rewrite(x - (x - y)/c0, (x*fold(c0 - 1) + y + fold(c0 - 1))/c0, c0 > 0) || - rewrite(x - (y + x)/c0, (x*fold(c0 - 1) - y + fold(c0 - 1))/c0, c0 > 0) || - rewrite(x - (y - x)/c0, (x*fold(c0 + 1) - y + fold(c0 - 1))/c0, c0 > 0) || - rewrite((x + y)/c0 - x, (x*fold(1 - c0) + y)/c0) || - rewrite((y + x)/c0 - x, (y + x*fold(1 - c0))/c0) || - rewrite((x - y)/c0 - x, (x*fold(1 - c0) - y)/c0) || - rewrite((y - x)/c0 - x, (y - x*fold(1 + c0))/c0) || - - rewrite((x/c0)*c0 - x, -(x % c0), c0 > 0) || - rewrite(x - (x/c0)*c0, x % c0, c0 > 0) || - rewrite(((x + c0)/c1)*c1 - x, (-x) % c1, c1 > 0 && c0 + 1 == c1) || - rewrite(x - ((x + c0)/c1)*c1, ((x + c0) % c1) + fold(-c0), c1 > 0 && c0 + 1 == c1) || + (no_overflow_int(op->type) && EVAL_IN_LAMBDA // + (rewrite(c0 - (c1 - x) / c2, (fold(c0 * c2 - c1 + c2 - 1) + x) / c2, c2 > 0) || + rewrite(c0 - (x + c1) / c2, (fold(c0 * c2 - c1 + c2 - 1) - x) / c2, c2 > 0) || + rewrite(x - (x + y) / c0, (x * fold(c0 - 1) - y + fold(c0 - 1)) / c0, c0 > 0) || + rewrite(x - (x - y) / c0, (x * fold(c0 - 1) + y + fold(c0 - 1)) / c0, c0 > 0) || + rewrite(x - (y + x) / c0, (x * fold(c0 - 1) - y + fold(c0 - 1)) / c0, c0 > 0) || + rewrite(x - (y - x) / c0, (x * fold(c0 + 1) - y + fold(c0 - 1)) / c0, c0 > 0) || + rewrite((x + y) / c0 - x, (x * fold(1 - c0) + y) / c0) || + rewrite((y + x) / c0 - x, (y + x * fold(1 - c0)) / c0) || + rewrite((x - y) / c0 - x, (x * fold(1 - c0) - y) / c0) || + rewrite((y - x) / c0 - x, (y - x * fold(1 + c0)) / c0) || + + rewrite((x / c0) * c0 - x, -(x % c0), c0 > 0) || + rewrite(x - (x / c0) * c0, x % c0, c0 > 0) || + rewrite(((x + c0) / c1) * c1 - x, (-x) % c1, c1 > 0 && c0 + 1 == c1) || + rewrite(x - ((x + c0) / c1) * c1, ((x + c0) % c1) + fold(-c0), c1 > 0 && c0 + 1 == c1) || rewrite(x * c0 - y * c1, (x * fold(c0 / c1) - y) * c1, c0 % c1 == 0) || rewrite(x * c0 - y * c1, (x - y * fold(c1 / c0)) * c0, c1 % c0 == 0) || // Various forms of (x +/- a)/c - (x +/- b)/c. We can // *almost* cancel the x. The right thing to do depends // on which of a or b is a constant, and we also need to // catch the cases where that constant is zero. - rewrite(((x + y) + z)/c0 - ((y + x) + w)/c0, ((x + y) + z)/c0 - ((x + y) + w)/c0, c0 > 0) || - rewrite((x + y)/c0 - (y + x)/c0, 0, c0 != 0) || - rewrite((x + y)/c0 - (x + c1)/c0, (((x + fold(c1 % c0)) % c0) + (y - c1))/c0, c0 > 0) || - rewrite((x + c1)/c0 - (x + y)/c0, ((fold(c0 + c1 - 1) - y) - ((x + fold(c1 % c0)) % c0))/c0, c0 > 0) || - rewrite((x - y)/c0 - (x + c1)/c0, (((x + fold(c1 % c0)) % c0) - y - c1)/c0, c0 > 0) || - rewrite((x + c1)/c0 - (x - y)/c0, ((y + fold(c0 + c1 - 1)) - ((x + fold(c1 % c0)) % c0))/c0, c0 > 0) || - rewrite(x/c0 - (x + y)/c0, ((fold(c0 - 1) - y) - (x % c0))/c0, c0 > 0) || - rewrite((x + y)/c0 - x/c0, ((x % c0) + y)/c0, c0 > 0) || - rewrite(x/c0 - (x - y)/c0, ((y + fold(c0 - 1)) - (x % c0))/c0, c0 > 0) || - rewrite((x - y)/c0 - x/c0, ((x % c0) - y)/c0, c0 > 0) || + rewrite(((x + y) + z) / c0 - ((y + x) + w) / c0, ((x + y) + z) / c0 - ((x + y) + w) / c0, c0 > 0) || + rewrite((x + y) / c0 - (y + x) / c0, 0, c0 != 0) || + rewrite((x + y) / c0 - (x + c1) / c0, (((x + fold(c1 % c0)) % c0) + (y - c1)) / c0, c0 > 0) || + rewrite((x + c1) / c0 - (x + y) / c0, ((fold(c0 + c1 - 1) - y) - ((x + fold(c1 % c0)) % c0)) / c0, c0 > 0) || + rewrite((x - y) / c0 - (x + c1) / c0, (((x + fold(c1 % c0)) % c0) - y - c1) / c0, c0 > 0) || + rewrite((x + c1) / c0 - (x - y) / c0, ((y + fold(c0 + c1 - 1)) - ((x + fold(c1 % c0)) % c0)) / c0, c0 > 0) || + rewrite(x / c0 - (x + y) / c0, ((fold(c0 - 1) - y) - (x % c0)) / c0, c0 > 0) || + rewrite((x + y) / c0 - x / c0, ((x % c0) + y) / c0, c0 > 0) || + rewrite(x / c0 - (x - y) / c0, ((y + fold(c0 - 1)) - (x % c0)) / c0, c0 > 0) || + rewrite((x - y) / c0 - x / c0, ((x % c0) - y) / c0, c0 > 0) || // Simplification of bounds code for various tail // strategies requires cancellations of the form: @@ -429,21 +428,20 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) { rewrite(min(min(x + z, y), w) - x, min(min(y, w) - x, z)) || rewrite(min(min(y, x + z), w) - x, min(min(y, w) - x, z)) || - rewrite(min((x + y)*u + z, w) - x*u, min(w - x*u, y*u + z)) || - rewrite(min((y + x)*u + z, w) - x*u, min(w - x*u, y*u + z)) || + rewrite(min((x + y) * u + z, w) - x * u, min(w - x * u, y * u + z)) || + rewrite(min((y + x) * u + z, w) - x * u, min(w - x * u, y * u + z)) || // Splits can introduce confounding divisions - rewrite(min(x*c0 + y, z) / c1 - x*c2, min(y, z - x*c0) / c1, c0 == c1 * c2) || - rewrite(min(z, x*c0 + y) / c1 - x*c2, min(y, z - x*c0) / c1, c0 == c1 * c2) || + rewrite(min(x * c0 + y, z) / c1 - x * c2, min(y, z - x * c0) / c1, c0 == c1 * c2) || + rewrite(min(z, x * c0 + y) / c1 - x * c2, min(y, z - x * c0) / c1, c0 == c1 * c2) || // There could also be an addition inside the division (e.g. if it's division rounding up) - rewrite((min(x*c0 + y, z) + w) / c1 - x*c2, (min(y, z - x*c0) + w) / c1, c0 == c1 * c2) || - rewrite((min(z, x*c0 + y) + w) / c1 - x*c2, (min(z - x*c0, y) + w) / c1, c0 == c1 * c2) || + rewrite((min(x * c0 + y, z) + w) / c1 - x * c2, (min(y, z - x * c0) + w) / c1, c0 == c1 * c2) || + rewrite((min(z, x * c0 + y) + w) / c1 - x * c2, (min(z - x * c0, y) + w) / c1, c0 == c1 * c2) || false)))) { return mutate(rewrite.result, info); } - // clang-format on if (a.same_as(op->a) && b.same_as(op->b)) { return op; diff --git a/src/Target.cpp b/src/Target.cpp index 875ce4ae164e..7ec3294cd2c6 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -1566,7 +1566,6 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result) // (b) must be included if both targets have the feature (intersection) // (c) must match across both targets; it is an error if one target has the feature and the other doesn't - // clang-format off const std::vector union_features = {{ // These are true union features. CUDA, @@ -1609,9 +1608,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result) ARMv88a, ARMv89a, }}; - // clang-format on - // clang-format off const std::vector intersection_features = {{ ARMv7s, AVX, @@ -1630,9 +1627,7 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result) SSE41, VSX, }}; - // clang-format on - // clang-format off const std::vector matching_features = {{ ASAN, Debug, @@ -1646,7 +1641,6 @@ bool Target::get_runtime_compatible_target(const Target &other, Target &result) SanitizerCoverage, Simulator, }}; - // clang-format on // bitsets need to be the same width. decltype(result.features) union_mask; diff --git a/src/Type.h b/src/Type.h index d6143f38b6de..0b18b86353a6 100644 --- a/src/Type.h +++ b/src/Type.h @@ -218,12 +218,10 @@ template (is_const ? Const : 0) | (is_volatile ? Volatile : 0)); - // clang-format off constexpr ReferenceType ref_type = (is_lvalue_reference ? LValueReference : is_rvalue_reference ? RValueReference : NotReference); - // clang-format on using TNonCVBase = std::remove_cv_t; constexpr bool known_type = halide_c_type_to_name::known_type; diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index 716359088129..55ffa7325e22 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -385,11 +385,10 @@ std::vector compile_to_wasm(const Module &module, const std::string &fn_na // variants *will* be instantiated (increasing code size), so this approach // should only be used when strictly necessary. -// clang-format off template class Functor, typename... Args> -auto dynamic_type_dispatch(const halide_type_t &type, Args &&... args) -> decltype(std::declval>()(std::forward(args)...)) { +auto dynamic_type_dispatch(const halide_type_t &type, Args &&...args) -> decltype(std::declval>()(std::forward(args)...)) { -#define HANDLE_CASE(CODE, BITS, TYPE) \ +#define HANDLE_CASE(CODE, BITS, TYPE) \ case halide_type_t(CODE, BITS).as_u32(): \ return Functor()(std::forward(args)...); @@ -416,7 +415,6 @@ auto dynamic_type_dispatch(const halide_type_t &type, Args &&... args) -> declty #undef HANDLE_CASE } -// clang-format on // ----------------------- // extern callback helper code @@ -2118,23 +2116,21 @@ void add_extern_callbacks(const Local &context, #endif // WITH_V8 -// clang-format off - #if WITH_WABT using HostCallbackMap = std::unordered_map; -#define DEFINE_CALLBACK(f) { #f, wabt_jit_##f##_callback }, -#define DEFINE_POSIX_MATH_CALLBACK(t, f) { #f, wabt_posix_math_1 }, -#define DEFINE_POSIX_MATH_CALLBACK2(t, f) { #f, wabt_posix_math_2 }, +#define DEFINE_CALLBACK(f) {#f, wabt_jit_##f##_callback} +#define DEFINE_POSIX_MATH_CALLBACK(t, f) {#f, wabt_posix_math_1} +#define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wabt_posix_math_2} #endif #ifdef WITH_V8 using HostCallbackMap = std::unordered_map; -#define DEFINE_CALLBACK(f) { #f, wasm_jit_##f##_callback }, -#define DEFINE_POSIX_MATH_CALLBACK(t, f) { #f, wasm_jit_posix_math_callback }, -#define DEFINE_POSIX_MATH_CALLBACK2(t, f) { #f, wasm_jit_posix_math2_callback }, +#define DEFINE_CALLBACK(f) {#f, wasm_jit_##f##_callback} +#define DEFINE_POSIX_MATH_CALLBACK(t, f) {#f, wasm_jit_posix_math_callback} +#define DEFINE_POSIX_MATH_CALLBACK2(t, f) {#f, wasm_jit_posix_math2_callback} #endif const HostCallbackMap &get_host_callback_map() { @@ -2142,69 +2138,68 @@ const HostCallbackMap &get_host_callback_map() { static HostCallbackMap m = { // General runtime functions. - DEFINE_CALLBACK(__cxa_atexit) - DEFINE_CALLBACK(__extendhfsf2) - DEFINE_CALLBACK(__truncsfhf2) - DEFINE_CALLBACK(abort) - DEFINE_CALLBACK(fclose) - DEFINE_CALLBACK(fileno) - DEFINE_CALLBACK(fopen) - DEFINE_CALLBACK(free) - DEFINE_CALLBACK(fwrite) - DEFINE_CALLBACK(getenv) - DEFINE_CALLBACK(halide_error) - DEFINE_CALLBACK(halide_print) - DEFINE_CALLBACK(halide_trace_helper) - DEFINE_CALLBACK(malloc) - DEFINE_CALLBACK(memcmp) - DEFINE_CALLBACK(memcpy) - DEFINE_CALLBACK(memmove) - DEFINE_CALLBACK(memset) - DEFINE_CALLBACK(strlen) - DEFINE_CALLBACK(write) + DEFINE_CALLBACK(__cxa_atexit), + DEFINE_CALLBACK(__extendhfsf2), + DEFINE_CALLBACK(__truncsfhf2), + DEFINE_CALLBACK(abort), + DEFINE_CALLBACK(fclose), + DEFINE_CALLBACK(fileno), + DEFINE_CALLBACK(fopen), + DEFINE_CALLBACK(free), + DEFINE_CALLBACK(fwrite), + DEFINE_CALLBACK(getenv), + DEFINE_CALLBACK(halide_error), + DEFINE_CALLBACK(halide_print), + DEFINE_CALLBACK(halide_trace_helper), + DEFINE_CALLBACK(malloc), + DEFINE_CALLBACK(memcmp), + DEFINE_CALLBACK(memcpy), + DEFINE_CALLBACK(memmove), + DEFINE_CALLBACK(memset), + DEFINE_CALLBACK(strlen), + DEFINE_CALLBACK(write), // Posix math. - DEFINE_POSIX_MATH_CALLBACK(double, acos) - DEFINE_POSIX_MATH_CALLBACK(double, acosh) - DEFINE_POSIX_MATH_CALLBACK(double, asin) - DEFINE_POSIX_MATH_CALLBACK(double, asinh) - DEFINE_POSIX_MATH_CALLBACK(double, atan) - DEFINE_POSIX_MATH_CALLBACK(double, atanh) - DEFINE_POSIX_MATH_CALLBACK(double, cos) - DEFINE_POSIX_MATH_CALLBACK(double, cosh) - DEFINE_POSIX_MATH_CALLBACK(double, exp) - DEFINE_POSIX_MATH_CALLBACK(double, log) - DEFINE_POSIX_MATH_CALLBACK(double, round) - DEFINE_POSIX_MATH_CALLBACK(double, sin) - DEFINE_POSIX_MATH_CALLBACK(double, sinh) - DEFINE_POSIX_MATH_CALLBACK(double, tan) - DEFINE_POSIX_MATH_CALLBACK(double, tanh) - - DEFINE_POSIX_MATH_CALLBACK(float, acosf) - DEFINE_POSIX_MATH_CALLBACK(float, acoshf) - DEFINE_POSIX_MATH_CALLBACK(float, asinf) - DEFINE_POSIX_MATH_CALLBACK(float, asinhf) - DEFINE_POSIX_MATH_CALLBACK(float, atanf) - DEFINE_POSIX_MATH_CALLBACK(float, atanhf) - DEFINE_POSIX_MATH_CALLBACK(float, cosf) - DEFINE_POSIX_MATH_CALLBACK(float, coshf) - DEFINE_POSIX_MATH_CALLBACK(float, expf) - DEFINE_POSIX_MATH_CALLBACK(float, logf) - DEFINE_POSIX_MATH_CALLBACK(float, roundf) - DEFINE_POSIX_MATH_CALLBACK(float, sinf) - DEFINE_POSIX_MATH_CALLBACK(float, sinhf) - DEFINE_POSIX_MATH_CALLBACK(float, tanf) - DEFINE_POSIX_MATH_CALLBACK(float, tanhf) - - DEFINE_POSIX_MATH_CALLBACK2(float, atan2f) - DEFINE_POSIX_MATH_CALLBACK2(double, atan2) - DEFINE_POSIX_MATH_CALLBACK2(float, fminf) - DEFINE_POSIX_MATH_CALLBACK2(double, fmin) - DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf) - DEFINE_POSIX_MATH_CALLBACK2(double, fmax) - DEFINE_POSIX_MATH_CALLBACK2(float, powf) - DEFINE_POSIX_MATH_CALLBACK2(double, pow) - }; + DEFINE_POSIX_MATH_CALLBACK(double, acos), + DEFINE_POSIX_MATH_CALLBACK(double, acosh), + DEFINE_POSIX_MATH_CALLBACK(double, asin), + DEFINE_POSIX_MATH_CALLBACK(double, asinh), + DEFINE_POSIX_MATH_CALLBACK(double, atan), + DEFINE_POSIX_MATH_CALLBACK(double, atanh), + DEFINE_POSIX_MATH_CALLBACK(double, cos), + DEFINE_POSIX_MATH_CALLBACK(double, cosh), + DEFINE_POSIX_MATH_CALLBACK(double, exp), + DEFINE_POSIX_MATH_CALLBACK(double, log), + DEFINE_POSIX_MATH_CALLBACK(double, round), + DEFINE_POSIX_MATH_CALLBACK(double, sin), + DEFINE_POSIX_MATH_CALLBACK(double, sinh), + DEFINE_POSIX_MATH_CALLBACK(double, tan), + DEFINE_POSIX_MATH_CALLBACK(double, tanh), + + DEFINE_POSIX_MATH_CALLBACK(float, acosf), + DEFINE_POSIX_MATH_CALLBACK(float, acoshf), + DEFINE_POSIX_MATH_CALLBACK(float, asinf), + DEFINE_POSIX_MATH_CALLBACK(float, asinhf), + DEFINE_POSIX_MATH_CALLBACK(float, atanf), + DEFINE_POSIX_MATH_CALLBACK(float, atanhf), + DEFINE_POSIX_MATH_CALLBACK(float, cosf), + DEFINE_POSIX_MATH_CALLBACK(float, coshf), + DEFINE_POSIX_MATH_CALLBACK(float, expf), + DEFINE_POSIX_MATH_CALLBACK(float, logf), + DEFINE_POSIX_MATH_CALLBACK(float, roundf), + DEFINE_POSIX_MATH_CALLBACK(float, sinf), + DEFINE_POSIX_MATH_CALLBACK(float, sinhf), + DEFINE_POSIX_MATH_CALLBACK(float, tanf), + DEFINE_POSIX_MATH_CALLBACK(float, tanhf), + + DEFINE_POSIX_MATH_CALLBACK2(float, atan2f), + DEFINE_POSIX_MATH_CALLBACK2(double, atan2), + DEFINE_POSIX_MATH_CALLBACK2(float, fminf), + DEFINE_POSIX_MATH_CALLBACK2(double, fmin), + DEFINE_POSIX_MATH_CALLBACK2(float, fmaxf), + DEFINE_POSIX_MATH_CALLBACK2(double, fmax), + DEFINE_POSIX_MATH_CALLBACK2(float, powf), + DEFINE_POSIX_MATH_CALLBACK2(double, pow)}; return m; } @@ -2215,8 +2210,6 @@ const HostCallbackMap &get_host_callback_map() { #undef DEFINE_POSIX_MATH_CALLBACK #undef DEFINE_POSIX_MATH_CALLBACK2 -// clang-format on - #endif // WITH_WABT || WITH_V8 struct WasmModuleContents { @@ -2259,22 +2252,22 @@ struct WasmModuleContents { ~WasmModuleContents() = default; }; -// clang-format off WasmModuleContents::WasmModuleContents( const Module &halide_module, const std::vector &arguments, const std::string &fn_name, const std::map &jit_externs, const std::vector &extern_deps) - : target(halide_module.target()) - , arguments(arguments) - , jit_externs(jit_externs) - , extern_deps(extern_deps) - , trampolines(JITModule::make_trampolines_module(get_host_target(), jit_externs, kTrampolineSuffix, extern_deps)) + : target(halide_module.target()), // + arguments(arguments), // + jit_externs(jit_externs), // + extern_deps(extern_deps), // + trampolines(JITModule::make_trampolines_module(get_host_target(), jit_externs, kTrampolineSuffix, extern_deps)) #if WITH_WABT - , store(wabt::interp::Store(calc_features(halide_module.target()))) + , + store(wabt::interp::Store(calc_features(halide_module.target()))) #endif -// clang-format on + { #if WITH_WABT || WITH_V8 diff --git a/src/runtime/HalideBuffer.h b/src/runtime/HalideBuffer.h index 05a68ead98b4..9366db58d566 100644 --- a/src/runtime/HalideBuffer.h +++ b/src/runtime/HalideBuffer.h @@ -64,55 +64,53 @@ static_assert(((HALIDE_RUNTIME_BUFFER_ALLOCATION_ALIGNMENT & (HALIDE_RUNTIME_BUF // we found supports the former but not the latter.) #ifndef HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC -// clang-format off #ifdef _WIN32 - // Windows (regardless of which compiler) doesn't implement aligned_alloc(), - // even in C++17 mode, and has stated they probably never will, as the issue - // is in the incompatibility that free() needs to be able to free both pointers - // returned by malloc() and aligned_alloc(). So, always default it off here. - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 +// Windows (regardless of which compiler) doesn't implement aligned_alloc(), +// even in C++17 mode, and has stated they probably never will, as the issue +// is in the incompatibility that free() needs to be able to free both pointers +// returned by malloc() and aligned_alloc(). So, always default it off here. +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 #elif defined(__ANDROID_API__) && __ANDROID_API__ < 28 - // Android doesn't provide aligned_alloc until API 28 - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 +// Android doesn't provide aligned_alloc until API 28 +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 #elif defined(__APPLE__) - #if TARGET_OS_OSX && (__MAC_OS_X_VERSION_MIN_REQUIRED < __MAC_10_15) +#if TARGET_OS_OSX && (__MAC_OS_X_VERSION_MIN_REQUIRED < __MAC_10_15) - // macOS doesn't provide aligned_alloc until 10.15 - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 +// macOS doesn't provide aligned_alloc until 10.15 +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 - #elif TARGET_OS_IPHONE && (__IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_14_0) +#elif TARGET_OS_IPHONE && (__IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_14_0) - // iOS doesn't provide aligned_alloc until 14.0 - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 +// iOS doesn't provide aligned_alloc until 14.0 +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 - #else +#else - // Assume it's ok on all other Apple targets - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 1 +// Assume it's ok on all other Apple targets +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 1 - #endif +#endif #else - #if defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) +#if defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC) - // ARM GNU-A baremetal compiler doesn't provide aligned_alloc as of 12.2 - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 +// ARM GNU-A baremetal compiler doesn't provide aligned_alloc as of 12.2 +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 0 - #else +#else - // Not Windows, Android, or Apple: just assume it's ok - #define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 1 +// Not Windows, Android, or Apple: just assume it's ok +#define HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC 1 - #endif +#endif #endif -// clang-format on #endif // HALIDE_RUNTIME_BUFFER_USE_ALIGNED_ALLOC diff --git a/src/runtime/mini_qurt.h b/src/runtime/mini_qurt.h index 0ef1840b6771..e7100126085b 100644 --- a/src/runtime/mini_qurt.h +++ b/src/runtime/mini_qurt.h @@ -21,8 +21,6 @@ typedef unsigned int qurt_thread_t; Macros for QuRT thread attributes. */ -// clang-format off - /** * \defgroup qurt_thread_macros QURT threading macros * @{ @@ -46,8 +44,6 @@ typedef unsigned int qurt_thread_t; #define QURT_THREAD_ATTR_TIMETEST_ID_DEFAULT (-2) /**< */ /** @} */ -// clang-format on - /** Thread attributes */ typedef struct _qurt_thread_attr { /** @cond */ diff --git a/src/runtime/mini_webgpu.h b/src/runtime/mini_webgpu.h index 3d6bf862f0b7..f43dbc46ea10 100644 --- a/src/runtime/mini_webgpu.h +++ b/src/runtime/mini_webgpu.h @@ -33,7 +33,7 @@ #ifndef WEBGPU_H_ #define WEBGPU_H_ -// clang-format off + #if defined(WGPU_SHARED_LIBRARY) # if defined(_WIN32) @@ -2628,6 +2628,6 @@ WGPU_EXPORT void wgpuTextureViewRelease(WGPUTextureView textureView) WGPU_FUNCTI } // extern "C" #endif -// clang-format on + #endif // WEBGPU_H_ diff --git a/src/runtime/profiler_inlined.cpp b/src/runtime/profiler_inlined.cpp index fc5bda1b1bcc..325a7169e46d 100644 --- a/src/runtime/profiler_inlined.cpp +++ b/src/runtime/profiler_inlined.cpp @@ -9,11 +9,10 @@ WEAK_INLINE int halide_profiler_set_current_func(halide_profiler_instance_state // Use empty volatile asm blocks to prevent code motion. Otherwise // llvm reorders or elides the stores. volatile int *ptr = &(instance->current_func); - // clang-format off - asm volatile ("":::); + + asm volatile("" :::); *ptr = func; - asm volatile ("":::); - // clang-format on + asm volatile("" :::); } return 0; } diff --git a/src/runtime/vulkan_interface.h b/src/runtime/vulkan_interface.h index 6dd6888fb54e..d93f2aebb060 100644 --- a/src/runtime/vulkan_interface.h +++ b/src/runtime/vulkan_interface.h @@ -69,7 +69,7 @@ extern "C" WEAK void *halide_vulkan_get_symbol(void *user_context, const char *n } // Declare all the function pointers for the Vulkan API methods that will be resolved dynamically -// clang-format off + #define VULKAN_FN(fn) WEAK PFN_##fn fn = nullptr; #define HL_USE_VULKAN_LOADER_FNS #define HL_USE_VULKAN_INSTANCE_FNS @@ -79,7 +79,6 @@ extern "C" WEAK void *halide_vulkan_get_symbol(void *user_context, const char *n #undef HL_USE_VULKAN_INSTANCE_FNS #undef HL_USE_VULKAN_LOADER_FNS #undef VULKAN_FN -// clang-format on // Get the function pointers to the Vulkan loader (to find all available instances) void WEAK vk_load_vulkan_loader_functions(void *user_context) { diff --git a/test/correctness/float16_t_neon_op_check.cpp b/test/correctness/float16_t_neon_op_check.cpp index 33d2541cbd4a..e67883dfcba3 100644 --- a/test/correctness/float16_t_neon_op_check.cpp +++ b/test/correctness/float16_t_neon_op_check.cpp @@ -68,12 +68,10 @@ class SimdOpCheck : public SimdOpCheckTest { std::vector> vl_params; Expr f_1, f_2, f_3, u_1, i_1; }; - // clang-format off + TestParams test_params[2] = { - {32, in_f32, {{1, "s"}, {2, ".2s"}, {4, ".4s"}, { 8, ".4s"}}, f32_1, f32_2, f32_3, u32_1, i32_1}, - {16, in_f16, {{1, "h"}, {4, ".4h"}, {8, ".8h"}, {16, ".8h"}}, f16_1, f16_2, f16_3, u16_1, i16_1} - }; - // clang-format on + {32, in_f32, {{1, "s"}, {2, ".2s"}, {4, ".4s"}, {8, ".4s"}}, f32_1, f32_2, f32_3, u32_1, i32_1}, + {16, in_f16, {{1, "h"}, {4, ".4h"}, {8, ".8h"}, {16, ".8h"}}, f16_1, f16_2, f16_3, u16_1, i16_1}}; for (auto &test_param : test_params) { // outer loop for {fp32, fp16} const int bits = test_param.bits; diff --git a/test/correctness/math.cpp b/test/correctness/math.cpp index 68ff3c0e56e8..25864a0c3314 100644 --- a/test/correctness/math.cpp +++ b/test/correctness/math.cpp @@ -184,54 +184,50 @@ struct TestArgs { } \ } -// clang-format off - -#define fun_1_float_types(name) \ - fun_1(float, float, name, name) \ - fun_1(double, double, name, name) - -#define fun_2_float_types(name) \ - fun_2(float, float, name, name) \ - fun_2(double, double, name, name) - -fun_1_float_types(sqrt) -fun_1_float_types(sin) -fun_1_float_types(cos) -fun_1_float_types(exp) -fun_1_float_types(log) -fun_1_float_types(floor) -fun_1_float_types(ceil) -fun_1_float_types(trunc) -fun_1_float_types(asin) -fun_1_float_types(acos) -fun_1_float_types(tan) -fun_1_float_types(atan) -fun_1_float_types(sinh) -fun_1_float_types(cosh) -fun_1_float_types(tanh) -fun_1_float_types(asinh) -fun_1_float_types(acosh) -fun_1_float_types(atanh) -fun_1_float_types(round) - -fun_2_float_types(pow) -fun_2_float_types(atan2) - -fun_1(float, float, abs, fabsf) -fun_1(double, double, abs, fabs) -fun_1(uint8_t, int8_t, abs, abs) -fun_1(uint16_t, int16_t, abs, abs) -fun_1(uint32_t, int32_t, abs, abs) - -fun_2_float_types(absd) -fun_2(uint8_t, int8_t, absd, absd) -fun_2(uint16_t, int16_t, absd, absd) -fun_2(uint32_t, int32_t, absd, absd) -fun_2(uint8_t, uint8_t, absd, absd) -fun_2(uint16_t, uint16_t, absd, absd) -fun_2(uint32_t, uint32_t, absd, absd) - -// clang-format on +#define fun_1_float_types(name) \ + fun_1(float, float, name, name) \ + fun_1(double, double, name, name) + +#define fun_2_float_types(name) \ + fun_2(float, float, name, name) \ + fun_2(double, double, name, name) + +fun_1_float_types(sqrt); +fun_1_float_types(sin); +fun_1_float_types(cos); +fun_1_float_types(exp); +fun_1_float_types(log); +fun_1_float_types(floor); +fun_1_float_types(ceil); +fun_1_float_types(trunc); +fun_1_float_types(asin); +fun_1_float_types(acos); +fun_1_float_types(tan); +fun_1_float_types(atan); +fun_1_float_types(sinh); +fun_1_float_types(cosh); +fun_1_float_types(tanh); +fun_1_float_types(asinh); +fun_1_float_types(acosh); +fun_1_float_types(atanh); +fun_1_float_types(round); + +fun_2_float_types(pow); +fun_2_float_types(atan2); + +fun_1(float, float, abs, fabsf); +fun_1(double, double, abs, fabs); +fun_1(uint8_t, int8_t, abs, abs); +fun_1(uint16_t, int16_t, abs, abs); +fun_1(uint32_t, int32_t, abs, abs); + +fun_2_float_types(absd); +fun_2(uint8_t, int8_t, absd, absd); +fun_2(uint16_t, int16_t, absd, absd); +fun_2(uint32_t, int32_t, absd, absd); +fun_2(uint8_t, uint8_t, absd, absd); +fun_2(uint16_t, uint16_t, absd, absd); +fun_2(uint32_t, uint32_t, absd, absd); // Note this test is more oriented toward making sure the paths // through to math functions all work on a given target rather diff --git a/test/correctness/simd_op_check_sve2.cpp b/test/correctness/simd_op_check_sve2.cpp index 5dda276e8c80..cc2eae5c0f80 100644 --- a/test/correctness/simd_op_check_sve2.cpp +++ b/test/correctness/simd_op_check_sve2.cpp @@ -69,16 +69,16 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { private: void check_arm_integer() { - // clang-format off + vector> test_params{ - {8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat}, - {16, in_i16, in_u16, in_f16, in_i32, in_u32, i16, i16_sat, i32, i8, i8_sat, u16, u16_sat, u32, u8, u8_sat}, - {32, in_i32, in_u32, in_f32, in_i64, in_u64, i32, i32_sat, i64, i16, i16_sat, u32, u32_sat, u64, u16, u16_sat}, - {64, in_i64, in_u64, in_f64, in_i64, in_u64, i64, i64_sat, i64, i32, i32_sat, u64, u64_sat, u64, u32, u32_sat}, - }; - // clang-format on + CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> + test_params{ + {8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat}, + {16, in_i16, in_u16, in_f16, in_i32, in_u32, i16, i16_sat, i32, i8, i8_sat, u16, u16_sat, u32, u8, u8_sat}, + {32, in_i32, in_u32, in_f32, in_i64, in_u64, i32, i32_sat, i64, i16, i16_sat, u32, u32_sat, u64, u16, u16_sat}, + {64, in_i64, in_u64, in_f64, in_i64, in_u64, i64, i64_sat, i64, i32, i32_sat, u64, u64_sat, u64, u32, u32_sat}, + }; for (const auto &[bits, in_i, in_u, in_f, in_i_wide, in_u_wide, cast_i, satcast_i, widen_i, narrow_i, satnarrow_i, @@ -874,7 +874,6 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { {32, in_i32, in_u32, i64, i64, u64, u64}, {64, in_i64, in_u64, i64, i64, u64, u64}, }; - // clang-format on for (const auto &[bits, in_i, in_u, widen_i, widenx4_i, widen_u, widenx4_u] : test_params) { @@ -975,13 +974,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // Tests for Float type { - // clang-format off + vector> test_params{ {16, in_f16}, {32, in_f32}, {64, in_f64}, }; - // clang-format on + if (!has_sve()) { for (const auto &[bits, in_f] : test_params) { for (auto &total_bits : {64, 128}) { diff --git a/test/generator/metadata_tester_aottest.cpp b/test/generator/metadata_tester_aottest.cpp index 6aa2e7243ae4..92d3e392840a 100644 --- a/test/generator/metadata_tester_aottest.cpp +++ b/test/generator/metadata_tester_aottest.cpp @@ -1364,7 +1364,6 @@ constexpr size_t count_buffers(const std::array<::HalideFunctionInfo::ArgumentIn return buffer_count; } -// clang-format off constexpr char arginfo_to_sigchar(::HalideFunctionInfo::ArgumentInfo arg) { if (arg.kind == HalideFunctionInfo::InputBuffer) { return '@'; @@ -1372,9 +1371,9 @@ constexpr char arginfo_to_sigchar(::HalideFunctionInfo::ArgumentInfo arg) { return '#'; } else { - #define HANDLE_CASE(CODE, BITS, CHAR) \ - case halide_type_t(CODE, BITS).as_u32(): \ - return (CHAR); +#define HANDLE_CASE(CODE, BITS, CHAR) \ + case halide_type_t(CODE, BITS).as_u32(): \ + return (CHAR); switch (arg.type.as_u32()) { HANDLE_CASE(halide_type_bfloat, 16, '!') @@ -1393,13 +1392,12 @@ constexpr char arginfo_to_sigchar(::HalideFunctionInfo::ArgumentInfo arg) { HANDLE_CASE(halide_type_handle, 64, 'P') } - #undef HANDLE_CASE +#undef HANDLE_CASE } // Shouldn't ever get here, but if we do, we'll fail at *compile* time abort(); } -// clang-format on template constexpr std::array compute_signature_impl(const std::array<::HalideFunctionInfo::ArgumentInfo, arg_count> args,