From 76a7dd4f7fb538deaf7c2ade56c02bc84e5221e8 Mon Sep 17 00:00:00 2001
From: Zalman Stern
Date: Fri, 15 Mar 2024 13:01:51 -0700
Subject: [PATCH] Support for ARM SVE2. (#8051)

* Checkpoint SVE2 restart.
* Remove dead code. Add new test.
* Update cmake for new file.
* Checkpoint progress on SVE2.
* Checkpoint ARM SVE2 support. Passes correctness_simd_op_check_sve2 test at 128 and 256 bits.
* Remove an opportunity for RISC V codegen to change due to SVE2 support.
* Ensure SVE intrinsics get vscale vectors and non-SVE ones get fixed vectors. Use proper prefix for neon intrinsics. Comment cleanups.
* Checkpoint SVE2 work. Generally passes test, though using both NEON and SVE2 with simd_op_check_sve2 fails as both possibilities need to be allowed for 128-bit or smaller operations.
* Remove an unfavored implementation possibility.
* Fix opcode recognition in test to handle some cases that show up. Change name of test class to avoid confusion.
* Formatting fixes. Replace internal_error with a no-op return for CodeGen_LLVM::match_vector_type_scalable called on a scalar.
* Formatting fix.
* Limit SVE2 test to LLVM 19. Remove dead code.
* Fix a degenerate case asking for zero-sized vectors via a Halide type with lanes of zero, which is not correct.
* Fix confusion about Neon64/Neon128 and make it clear this is just the width multiplier applied to intrinsics.
* Remove extraneous commented-out line.
* Address some review feedback. Mostly comment fixes.
* Fix missed conflict resolution.
* Fix some TODOs in SVE code. Move utility function to Util.h and common code the other obvious use.
* Formatting.
* Add missed refactor change.
* Add issue to TODO comment.
* Remove TODOs that don't seem necessary.
* Add issue for TODO.
* Add issue for TODO.
* Remove dubious-looking FP-to-int code that was ifdef'ed out. Doesn't look like a TODO is needed anymore.
* Add issues for TODOs.
* Update simd_op_check_sve2.cpp
* Make a deep copy of each piece of test IR so that we can parallelize.
* Fix two clang-tidy warnings.
* Remove try/catch block from simd-op-check-sve2.
* Don't try to run SVE2 code if vector_bits doesn't match host.
* Add support for fcvtm/p, make scalars go through pattern matching too (#8151)
* Don't do arm neon instruction selection on scalars. This revealed a bug: FindIntrinsics was not enabled for scalars anyway, so it was semi-pointless.
--------- Co-authored-by: Zalman Stern Co-authored-by: Steven Johnson Co-authored-by: Andrew Adams --- src/CodeGen_ARM.cpp | 1388 ++++++++++++++++++----- src/CodeGen_LLVM.cpp | 230 +++- src/CodeGen_LLVM.h | 7 + src/Function.cpp | 6 +- src/IR.cpp | 1 + src/IR.h | 2 + src/IRMatch.cpp | 3 + src/LLVM_Output.cpp | 6 + src/StorageFolding.cpp | 5 +- src/Util.h | 11 + src/WasmExecutor.cpp | 11 +- src/runtime/HalideRuntime.h | 6 +- src/runtime/aarch64.ll | 76 +- src/runtime/errors.cpp | 8 + src/runtime/posix_math.ll | 28 +- src/runtime/runtime_api.cpp | 1 + test/correctness/CMakeLists.txt | 1 + test/correctness/simd_op_check_arm.cpp | 7 + test/correctness/simd_op_check_sve2.cpp | 1387 ++++++++++++++++++++++ 19 files changed, 2836 insertions(+), 348 deletions(-) create mode 100644 test/correctness/simd_op_check_sve2.cpp diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 7852532183bf..d0538d6ccca8 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -105,17 +105,30 @@ class CodeGen_ARM : public CodeGen_Posix { CodeGen_ARM(const Target &); protected: + using codegen_func_t = std::function &)>; using CodeGen_Posix::visit; - /** Assuming 'inner' is a function that takes two vector arguments, define a wrapper that - * takes one vector argument and splits it into two to call inner. */ - llvm::Function *define_concat_args_wrapper(llvm::Function *inner, const string &name); + /** Similar to llvm_type_of, but allows providing a VectorTypeConstraint to + * force Fixed or VScale vector results. */ + llvm::Type *llvm_type_with_constraint(const Type &t, bool scalars_are_vectors, VectorTypeConstraint constraint); + + /** Define a wrapper LLVM func that takes some arguments which Halide defines + * and call inner LLVM intrinsic with an additional argument which LLVM requires. */ + llvm::Function *define_intrin_wrapper(const std::string &inner_name, + const Type &ret_type, + const std::string &mangled_name, + const std::vector &arg_types, + int intrinsic_flags, + bool sve_intrinsic); void init_module() override; void compile_func(const LoweredFunc &f, const std::string &simple_name, const std::string &extern_name) override; - /** Nodes for which we want to emit specific neon intrinsics */ + void begin_func(LinkageType linkage, const std::string &simple_name, + const std::string &extern_name, const std::vector &args) override; + + /** Nodes for which we want to emit specific ARM vector intrinsics */ // @{ void visit(const Cast *) override; void visit(const Add *) override; @@ -125,15 +138,25 @@ class CodeGen_ARM : public CodeGen_Posix { void visit(const Store *) override; void visit(const Load *) override; void visit(const Shuffle *) override; + void visit(const Ramp *) override; void visit(const Call *) override; void visit(const LT *) override; void visit(const LE *) override; void codegen_vector_reduce(const VectorReduce *, const Expr &) override; + bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &); + bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &); + bool codegen_across_vector_reduce(const VectorReduce *, const Expr &); // @} Type upgrade_type_for_arithmetic(const Type &t) const override; Type upgrade_type_for_argument_passing(const Type &t) const override; Type upgrade_type_for_storage(const Type &t) const override; + /** Helper function to perform codegen of vector operation in a way that + * total_lanes are divided into slices, codegen is performed for each slice + * and results are concatenated into total_lanes. 
+ */ + Value *codegen_with_lanes(int slice_lanes, int total_lanes, const std::vector &args, codegen_func_t &cg_func); + /** Various patterns to peephole match against */ struct Pattern { string intrin; ///< Name of the intrinsic @@ -150,10 +173,12 @@ class CodeGen_ARM : public CodeGen_Posix { string mattrs() const override; bool use_soft_float_abi() const override; int native_vector_bits() const override; + int target_vscale() const override; // NEON can be disabled for older processors. - bool neon_intrinsics_disabled() { - return target.has_feature(Target::NoNEON); + bool simd_intrinsics_disabled() { + return target.has_feature(Target::NoNEON) && + !target.has_feature(Target::SVE2); } bool is_float16_and_has_feature(const Type &t) const { @@ -161,11 +186,28 @@ class CodeGen_ARM : public CodeGen_Posix { return t.code() == Type::Float && t.bits() == 16 && target.has_feature(Target::ARMFp16); } bool supports_call_as_float16(const Call *op) const override; + + /** Make predicate vector which starts with consecutive true followed by consecutive false */ + Expr make_vector_predicate_1s_0s(int true_lanes, int false_lanes) { + internal_assert((true_lanes + false_lanes) != 0) << "CodeGen_ARM::make_vector_predicate_1s_0s called with total of 0 lanes.\n"; + if (true_lanes == 0) { + return const_false(false_lanes); + } else if (false_lanes == 0) { + return const_true(true_lanes); + } else { + return Shuffle::make_concat({const_true(true_lanes), const_false(false_lanes)}); + } + } }; CodeGen_ARM::CodeGen_ARM(const Target &target) : CodeGen_Posix(target) { + // TODO(https://github.com/halide/Halide/issues/8088): See if + // use_llvm_vp_intrinsics can replace architecture specific code in this + // file, specifically in Load and Store visitors. Depends on quality of + // LLVM aarch64 backend lowering for these intrinsics on SVE2. + // RADDHN - Add and narrow with rounding // These must come before other narrowing rounding shift patterns casts.emplace_back("rounding_add_narrow", i8(rounding_shift_right(wild_i16x_ + wild_i16x_, 8))); @@ -211,6 +253,12 @@ CodeGen_ARM::CodeGen_ARM(const Target &target) casts.emplace_back("shift_right_narrow", i32(wild_i64x_ >> wild_u64_)); casts.emplace_back("shift_right_narrow", u32(wild_u64x_ >> wild_u64_)); + // VCVTP/M + casts.emplace_back("fp_to_int_floor", i32(floor(wild_f32x_))); + casts.emplace_back("fp_to_int_floor", u32(floor(wild_f32x_))); + casts.emplace_back("fp_to_int_ceil", i32(ceil(wild_f32x_))); + casts.emplace_back("fp_to_int_ceil", u32(ceil(wild_f32x_))); + // SQRSHL, UQRSHL - Saturating rounding shift left (by signed vector) // TODO: We need to match rounding shift right, and negate the RHS. @@ -299,26 +347,66 @@ struct ArmIntrinsic { SplitArg0 = 1 << 6, // This intrinsic requires splitting the argument into the low and high halves. NoPrefix = 1 << 7, // Don't prefix the intrinsic with llvm.* RequireFp16 = 1 << 8, // Available only if Target has ARMFp16 feature + Neon64Unavailable = 1 << 9, // Unavailable for 64 bit NEON + SveUnavailable = 1 << 10, // Unavailable for SVE + SveNoPredicate = 1 << 11, // In SVE intrinsics, additional predicate argument is required as default, unless this flag is set. + SveInactiveArg = 1 << 12, // This intrinsic needs the additional argument for fallback value for the lanes inactivated by predicate. + SveRequired = 1 << 13, // This intrinsic requires SVE. 
}; }; // clang-format off const ArmIntrinsic intrinsic_defs[] = { - {"vabs", "abs", UInt(8, 8), "abs", {Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vabs", "abs", UInt(16, 4), "abs", {Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vabs", "abs", UInt(32, 2), "abs", {Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"llvm.fabs", "llvm.fabs", Float(32, 2), "abs", {Float(32, 2)}, ArmIntrinsic::HalfWidth}, - {"llvm.fabs", "llvm.fabs", Float(16, 4), "abs", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, - - {"llvm.sqrt", "llvm.sqrt", Float(32, 2), "sqrt_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth}, - {"llvm.sqrt", "llvm.sqrt", Float(64, 2), "sqrt_f64", {Float(64, 2)}}, - - {"llvm.roundeven", "llvm.roundeven", Float(16, 8), "round", {Float(16, 8)}, ArmIntrinsic::RequireFp16}, - {"llvm.roundeven", "llvm.roundeven", Float(32, 4), "round", {Float(32, 4)}}, - {"llvm.roundeven", "llvm.roundeven", Float(64, 2), "round", {Float(64, 2)}}, - {"llvm.roundeven.f16", "llvm.roundeven.f16", Float(16), "round", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle}, - {"llvm.roundeven.f32", "llvm.roundeven.f32", Float(32), "round", {Float(32)}, ArmIntrinsic::NoMangle}, - {"llvm.roundeven.f64", "llvm.roundeven.f64", Float(64), "round", {Float(64)}, ArmIntrinsic::NoMangle}, + // TODO(https://github.com/halide/Halide/issues/8093): + // Some of the Arm intrinsics have the same name between Neon and SVE2 but different behavior. For example, + // widening, narrowing and pair-wise operations work on the even ("bottom") and odd ("top") lanes in SVE, + // but on the low and high halves of the vector in Neon. Therefore, peep-hole code-gen with those SVE2 intrinsics is not enabled for now, + // because additional interleaving/deinterleaving would be required to restore the element order in a vector. 
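// [Editorial sketch -- not part of the patch] A scalar C++ model of the lane-ordering mismatch
// described in the TODO above, using a widening op as the example: NEON widens the low and high
// halves of a vector (e.g. SMULL/SMULL2), which preserves element order, while the SVE2
// "bottom"/"top" forms (e.g. SMULLB/SMULLT) widen the even and odd lanes, so the concatenated
// partial results would need an extra zip/unzip to restore the original order.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
    const std::array<int16_t, 8> a{0, 1, 2, 3, 4, 5, 6, 7};
    std::array<int32_t, 8> neon_order{}, sve_order{};

    for (int i = 0; i < 4; i++) {
        // NEON-style: widen the low half, then the high half (a plain widening copy
        // stands in for the real widening op here).
        neon_order[i] = a[i];
        neon_order[i + 4] = a[i + 4];

        // SVE2-style: widen the even ("bottom") lanes, then the odd ("top") lanes.
        sve_order[i] = a[2 * i];
        sve_order[i + 4] = a[2 * i + 1];
    }

    // Prints neon_order = 0 1 2 3 4 5 6 7 but sve_order = 0 2 4 6 1 3 5 7.
    for (int i = 0; i < 8; i++) {
        std::printf("%d %d\n", neon_order[i], sve_order[i]);
    }
    return 0;
}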
+ + {"vabs", "abs", UInt(8, 8), "abs", {Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"vabs", "abs", UInt(16, 4), "abs", {Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"vabs", "abs", UInt(32, 2), "abs", {Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"llvm.fabs", "llvm.fabs", Float(16, 4), "abs", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs", "llvm.fabs", Float(32, 2), "abs", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs", "llvm.fabs", Float(64, 2), "abs", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs.f16", "llvm.fabs.f16", Float(16), "abs", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs.f32", "llvm.fabs.f32", Float(32), "abs", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.fabs.f64", "llvm.fabs.f64", Float(64), "abs", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + + {"llvm.sqrt", "llvm.sqrt", Float(16, 4), "sqrt_f16", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.sqrt", "llvm.sqrt", Float(32, 2), "sqrt_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.sqrt", "llvm.sqrt", Float(64, 2), "sqrt_f64", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.sqrt.f16", "llvm.sqrt.f16", Float(16), "sqrt_f16", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.sqrt.f32", "llvm.sqrt.f32", Float(32), "sqrt_f32", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.sqrt.f64", "llvm.sqrt.f64", Float(64), "sqrt_f64", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + + {"llvm.floor", "llvm.floor", Float(16, 4), "floor_f16", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.floor", "llvm.floor", Float(32, 2), "floor_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.floor", "llvm.floor", Float(64, 2), "floor_f64", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.floor.f16", "llvm.floor.f16", Float(16), "floor_f16", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.floor.f32", "llvm.floor.f32", Float(32), "floor_f32", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.floor.f64", "llvm.floor.f64", Float(64), "floor_f64", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + + {"llvm.ceil", "llvm.ceil", Float(16, 4), "ceil_f16", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.ceil", "llvm.ceil", Float(32, 2), "ceil_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.ceil", "llvm.ceil", Float(64, 2), "ceil_f64", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.ceil.f16", "llvm.ceil.f16", Float(16), "ceil_f16", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.ceil.f32", "llvm.ceil.f32", Float(32), "ceil_f32", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.ceil.f64", "llvm.ceil.f64", Float(64), "ceil_f64", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + + {"llvm.trunc", 
"llvm.trunc", Float(16, 4), "trunc_f16", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.trunc", "llvm.trunc", Float(32, 2), "trunc_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.trunc", "llvm.trunc", Float(64, 2), "trunc_f64", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.trunc.f16", "llvm.trunc.f16", Float(16), "trunc_f16", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.trunc.f32", "llvm.trunc.f32", Float(32), "trunc_f32", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.trunc.f64", "llvm.trunc.f64", Float(64), "trunc_f64", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + + {"llvm.roundeven", "llvm.roundeven", Float(16, 4), "round", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveNoPredicate}, + {"llvm.roundeven", "llvm.roundeven", Float(32, 2), "round", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"llvm.roundeven", "llvm.roundeven", Float(64, 2), "round", {Float(64, 2)}, ArmIntrinsic::SveNoPredicate}, + {"llvm.roundeven.f16", "llvm.roundeven.f16", Float(16), "round", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.roundeven.f32", "llvm.roundeven.f32", Float(32), "round", {Float(32)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, + {"llvm.roundeven.f64", "llvm.roundeven.f64", Float(64), "round", {Float(64)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate}, // SABD, UABD - Absolute difference {"vabds", "sabd", UInt(8, 8), "absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, @@ -329,12 +417,12 @@ const ArmIntrinsic intrinsic_defs[] = { {"vabdu", "uabd", UInt(32, 2), "absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::HalfWidth}, // SMULL, UMULL - Widening multiply - {"vmulls", "smull", Int(16, 8), "widening_mul", {Int(8, 8), Int(8, 8)}}, - {"vmullu", "umull", UInt(16, 8), "widening_mul", {UInt(8, 8), UInt(8, 8)}}, - {"vmulls", "smull", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}}, - {"vmullu", "umull", UInt(32, 4), "widening_mul", {UInt(16, 4), UInt(16, 4)}}, - {"vmulls", "smull", Int(64, 2), "widening_mul", {Int(32, 2), Int(32, 2)}}, - {"vmullu", "umull", UInt(64, 2), "widening_mul", {UInt(32, 2), UInt(32, 2)}}, + {"vmulls", "smull", Int(16, 8), "widening_mul", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::SveUnavailable}, + {"vmullu", "umull", UInt(16, 8), "widening_mul", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::SveUnavailable}, + {"vmulls", "smull", Int(32, 4), "widening_mul", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::SveUnavailable}, + {"vmullu", "umull", UInt(32, 4), "widening_mul", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::SveUnavailable}, + {"vmulls", "smull", Int(64, 2), "widening_mul", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::SveUnavailable}, + {"vmullu", "umull", UInt(64, 2), "widening_mul", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::SveUnavailable}, // SQADD, UQADD - Saturating add // On arm32, the ARM version of this seems to be missing on some configurations. 
@@ -385,12 +473,30 @@ const ArmIntrinsic intrinsic_defs[] = { {"vminu", "umin", UInt(16, 4), "min", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::HalfWidth}, {"vmins", "smin", Int(32, 2), "min", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, {"vminu", "umin", UInt(32, 2), "min", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vmins", "fmin", Float(32, 2), "min", {Float(32, 2), Float(32, 2)}, ArmIntrinsic::HalfWidth}, + {nullptr, "smin", Int(64, 2), "min", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::Neon64Unavailable}, + {nullptr, "umin", UInt(64, 2), "min", {UInt(64, 2), UInt(64, 2)}, ArmIntrinsic::Neon64Unavailable}, {"vmins", "fmin", Float(16, 4), "min", {Float(16, 4), Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, + {"vmins", "fmin", Float(32, 2), "min", {Float(32, 2), Float(32, 2)}, ArmIntrinsic::HalfWidth}, + {nullptr, "fmin", Float(64, 2), "min", {Float(64, 2), Float(64, 2)}}, // FCVTZS, FCVTZU - {nullptr, "fcvtzs", Int(16, 4), "fp_to_int", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::RequireFp16}, - {nullptr, "fcvtzu", UInt(16, 4), "fp_to_int", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::RequireFp16}, + {nullptr, "fcvtzs", Int(16, 4), "fp_to_int", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveInactiveArg}, + {nullptr, "fcvtzu", UInt(16, 4), "fp_to_int", {Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveInactiveArg}, + {nullptr, "fcvtzs", Int(32, 2), "fp_to_int", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveInactiveArg}, + {nullptr, "fcvtzu", UInt(32, 2), "fp_to_int", {Float(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveInactiveArg}, + {nullptr, "fcvtzs", Int(64, 2), "fp_to_int", {Float(64, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveInactiveArg}, + {nullptr, "fcvtzu", UInt(64, 2), "fp_to_int", {Float(64, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveInactiveArg}, + + // FCVTP/M. These only exist in armv8 and onwards, so we just skip them for + // arm-32. LLVM doesn't seem to have intrinsics for them for SVE. 
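// [Editorial sketch -- not part of the patch] The kind of Halide pipeline the new
// "fp_to_int_floor"/"fp_to_int_ceil" patterns and the FCVTM*/FCVTP* entries below target:
// a float-to-int cast wrapped around floor()/ceil(), which on AArch64 can become a single
// round-toward-minus/plus-infinity convert instead of a rounding op followed by FCVTZ*.
// The function and buffer names here are illustrative, not from the patch.
#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam in(Float(32), 1, "in");
    Func out("out");
    Var x("x");

    out(x) = cast<int32_t>(floor(in(x)));  // matches the i32(floor(wild_f32x_)) pattern added above
    out.vectorize(x, 4);

    // Plain AArch64 NEON target; per the comment above, LLVM has no SVE intrinsic for these,
    // so the entries below are marked SveUnavailable and only the NEON path uses them.
    out.compile_to_assembly("fp_to_int_floor.s", {in}, Target("arm-64-linux"));
    return 0;
}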
+ {nullptr, "fcvtpu", UInt(32, 4), "fp_to_int_ceil", {Float(32, 4)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtmu", UInt(32, 4), "fp_to_int_floor", {Float(32, 4)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtps", Int(32, 4), "fp_to_int_ceil", {Float(32, 4)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtms", Int(32, 4), "fp_to_int_floor", {Float(32, 4)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtpu", UInt(32, 2), "fp_to_int_ceil", {Float(32, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtmu", UInt(32, 2), "fp_to_int_floor", {Float(32, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtps", Int(32, 2), "fp_to_int_ceil", {Float(32, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {nullptr, "fcvtms", Int(32, 2), "fp_to_int_floor", {Float(32, 2)}, ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, // SMAX, UMAX, FMAX - Max {"vmaxs", "smax", Int(8, 8), "max", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, @@ -399,25 +505,34 @@ const ArmIntrinsic intrinsic_defs[] = { {"vmaxu", "umax", UInt(16, 4), "max", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::HalfWidth}, {"vmaxs", "smax", Int(32, 2), "max", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, {"vmaxu", "umax", UInt(32, 2), "max", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vmaxs", "fmax", Float(32, 2), "max", {Float(32, 2), Float(32, 2)}, ArmIntrinsic::HalfWidth}, + {nullptr, "smax", Int(64, 2), "max", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::Neon64Unavailable}, + {nullptr, "umax", UInt(64, 2), "max", {UInt(64, 2), UInt(64, 2)}, ArmIntrinsic::Neon64Unavailable}, {"vmaxs", "fmax", Float(16, 4), "max", {Float(16, 4), Float(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, + {"vmaxs", "fmax", Float(32, 2), "max", {Float(32, 2), Float(32, 2)}, ArmIntrinsic::HalfWidth}, + {nullptr, "fmax", Float(64, 2), "max", {Float(64, 2), Float(64, 2)}}, + + // NEG, FNEG + {nullptr, "neg", Int(8, 16), "negate", {Int(8, 16)}, ArmIntrinsic::SveInactiveArg | ArmIntrinsic::Neon64Unavailable}, + {nullptr, "neg", Int(16, 8), "negate", {Int(16, 8)}, ArmIntrinsic::SveInactiveArg | ArmIntrinsic::Neon64Unavailable}, + {nullptr, "neg", Int(32, 4), "negate", {Int(32, 4)}, ArmIntrinsic::SveInactiveArg | ArmIntrinsic::Neon64Unavailable}, + {nullptr, "neg", Int(64, 2), "negate", {Int(64, 2)}, ArmIntrinsic::SveInactiveArg | ArmIntrinsic::Neon64Unavailable}, // SQNEG, UQNEG - Saturating negation - {"vqneg", "sqneg", Int(8, 8), "saturating_negate", {Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vqneg", "sqneg", Int(16, 4), "saturating_negate", {Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vqneg", "sqneg", Int(32, 2), "saturating_negate", {Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vqneg", "sqneg", Int(64, 2), "saturating_negate", {Int(64, 2)}}, + {"vqneg", "sqneg", Int(8, 8), "saturating_negate", {Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"vqneg", "sqneg", Int(16, 4), "saturating_negate", {Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"vqneg", "sqneg", Int(32, 2), "saturating_negate", {Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveInactiveArg}, + {"vqneg", "sqneg", Int(64, 2), "saturating_negate", {Int(64, 2)}, ArmIntrinsic::SveInactiveArg}, // SQXTN, UQXTN, SQXTUN - Saturating narrowing - {"vqmovns", "sqxtn", Int(8, 8), "saturating_narrow", {Int(16, 8)}}, - 
{"vqmovnu", "uqxtn", UInt(8, 8), "saturating_narrow", {UInt(16, 8)}}, - {"vqmovnsu", "sqxtun", UInt(8, 8), "saturating_narrow", {Int(16, 8)}}, - {"vqmovns", "sqxtn", Int(16, 4), "saturating_narrow", {Int(32, 4)}}, - {"vqmovnu", "uqxtn", UInt(16, 4), "saturating_narrow", {UInt(32, 4)}}, - {"vqmovnsu", "sqxtun", UInt(16, 4), "saturating_narrow", {Int(32, 4)}}, - {"vqmovns", "sqxtn", Int(32, 2), "saturating_narrow", {Int(64, 2)}}, - {"vqmovnu", "uqxtn", UInt(32, 2), "saturating_narrow", {UInt(64, 2)}}, - {"vqmovnsu", "sqxtun", UInt(32, 2), "saturating_narrow", {Int(64, 2)}}, + {"vqmovns", "sqxtn", Int(8, 8), "saturating_narrow", {Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnu", "uqxtn", UInt(8, 8), "saturating_narrow", {UInt(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnsu", "sqxtun", UInt(8, 8), "saturating_narrow", {Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqmovns", "sqxtn", Int(16, 4), "saturating_narrow", {Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnu", "uqxtn", UInt(16, 4), "saturating_narrow", {UInt(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnsu", "sqxtun", UInt(16, 4), "saturating_narrow", {Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqmovns", "sqxtn", Int(32, 2), "saturating_narrow", {Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnu", "uqxtn", UInt(32, 2), "saturating_narrow", {UInt(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqmovnsu", "sqxtun", UInt(32, 2), "saturating_narrow", {Int(64, 2)}, ArmIntrinsic::SveUnavailable}, // RSHRN - Rounding shift right narrow (by immediate in [1, output bits]) // arm32 expects a vector RHS of the same type as the LHS except signed. @@ -440,52 +555,52 @@ const ArmIntrinsic intrinsic_defs[] = { // LLVM pattern matches these. // SQRSHL, UQRSHL - Saturating rounding shift left (by signed vector) - {"vqrshifts", "sqrshl", Int(8, 8), "saturating_rounding_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vqrshiftu", "uqrshl", UInt(8, 8), "saturating_rounding_shift_left", {UInt(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vqrshifts", "sqrshl", Int(16, 4), "saturating_rounding_shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vqrshiftu", "uqrshl", UInt(16, 4), "saturating_rounding_shift_left", {UInt(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vqrshifts", "sqrshl", Int(32, 2), "saturating_rounding_shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vqrshiftu", "uqrshl", UInt(32, 2), "saturating_rounding_shift_left", {UInt(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vqrshifts", "sqrshl", Int(64, 2), "saturating_rounding_shift_left", {Int(64, 2), Int(64, 2)}}, - {"vqrshiftu", "uqrshl", UInt(64, 2), "saturating_rounding_shift_left", {UInt(64, 2), Int(64, 2)}}, + {"vqrshifts", "sqrshl", Int(8, 8), "saturating_rounding_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftu", "uqrshl", UInt(8, 8), "saturating_rounding_shift_left", {UInt(8, 8), Int(8, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqrshifts", "sqrshl", Int(16, 4), "saturating_rounding_shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftu", "uqrshl", UInt(16, 4), "saturating_rounding_shift_left", {UInt(16, 4), Int(16, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqrshifts", "sqrshl", Int(32, 2), "saturating_rounding_shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftu", "uqrshl", UInt(32, 2), "saturating_rounding_shift_left", {UInt(32, 2), Int(32, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqrshifts", "sqrshl", 
Int(64, 2), "saturating_rounding_shift_left", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftu", "uqrshl", UInt(64, 2), "saturating_rounding_shift_left", {UInt(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, // SQRSHRN, UQRSHRN, SQRSHRUN - Saturating rounding narrowing shift right (by immediate in [1, output bits]) // arm32 expects a vector RHS of the same type as the LHS except signed. - {"vqrshiftns", nullptr, Int(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), Int(16, 8)}}, - {"vqrshiftnu", nullptr, UInt(8, 8), "saturating_rounding_shift_right_narrow", {UInt(16, 8), Int(16, 8)}}, - {"vqrshiftnsu", nullptr, UInt(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), Int(16, 8)}}, - {"vqrshiftns", nullptr, Int(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), Int(32, 4)}}, - {"vqrshiftnu", nullptr, UInt(16, 4), "saturating_rounding_shift_right_narrow", {UInt(32, 4), Int(32, 4)}}, - {"vqrshiftnsu", nullptr, UInt(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), Int(32, 4)}}, - {"vqrshiftns", nullptr, Int(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), Int(64, 2)}}, - {"vqrshiftnu", nullptr, UInt(32, 2), "saturating_rounding_shift_right_narrow", {UInt(64, 2), Int(64, 2)}}, - {"vqrshiftnsu", nullptr, UInt(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), Int(64, 2)}}, + {"vqrshiftns", nullptr, Int(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnu", nullptr, UInt(8, 8), "saturating_rounding_shift_right_narrow", {UInt(16, 8), Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnsu", nullptr, UInt(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftns", nullptr, Int(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnu", nullptr, UInt(16, 4), "saturating_rounding_shift_right_narrow", {UInt(32, 4), Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnsu", nullptr, UInt(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftns", nullptr, Int(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnu", nullptr, UInt(32, 2), "saturating_rounding_shift_right_narrow", {UInt(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vqrshiftnsu", nullptr, UInt(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, // arm64 expects a 32-bit constant. 
- {nullptr, "sqrshrn", Int(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), UInt(32)}}, - {nullptr, "uqrshrn", UInt(8, 8), "saturating_rounding_shift_right_narrow", {UInt(16, 8), UInt(32)}}, - {nullptr, "sqrshrun", UInt(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), UInt(32)}}, - {nullptr, "sqrshrn", Int(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), UInt(32)}}, - {nullptr, "uqrshrn", UInt(16, 4), "saturating_rounding_shift_right_narrow", {UInt(32, 4), UInt(32)}}, - {nullptr, "sqrshrun", UInt(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), UInt(32)}}, - {nullptr, "sqrshrn", Int(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), UInt(32)}}, - {nullptr, "uqrshrn", UInt(32, 2), "saturating_rounding_shift_right_narrow", {UInt(64, 2), UInt(32)}}, - {nullptr, "sqrshrun", UInt(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), UInt(32)}}, + {nullptr, "sqrshrn", Int(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqrshrn", UInt(8, 8), "saturating_rounding_shift_right_narrow", {UInt(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqrshrun", UInt(8, 8), "saturating_rounding_shift_right_narrow", {Int(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqrshrn", Int(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqrshrn", UInt(16, 4), "saturating_rounding_shift_right_narrow", {UInt(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqrshrun", UInt(16, 4), "saturating_rounding_shift_right_narrow", {Int(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqrshrn", Int(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqrshrn", UInt(32, 2), "saturating_rounding_shift_right_narrow", {UInt(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqrshrun", UInt(32, 2), "saturating_rounding_shift_right_narrow", {Int(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, // SQSHL, UQSHL, SQSHLU - Saturating shift left by signed register. // There is also an immediate version of this - hopefully LLVM does this matching when appropriate. 
{"vqshifts", "sqshl", Int(8, 8), "saturating_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, {"vqshiftu", "uqshl", UInt(8, 8), "saturating_shift_left", {UInt(8, 8), Int(8, 8)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, - {"vqshiftsu", "sqshlu", UInt(8, 8), "saturating_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, + {"vqshiftsu", "sqshlu", UInt(8, 8), "saturating_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, {"vqshifts", "sqshl", Int(16, 4), "saturating_shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, {"vqshiftu", "uqshl", UInt(16, 4), "saturating_shift_left", {UInt(16, 4), Int(16, 4)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, - {"vqshiftsu", "sqshlu", UInt(16, 4), "saturating_shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, + {"vqshiftsu", "sqshlu", UInt(16, 4), "saturating_shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, {"vqshifts", "sqshl", Int(32, 2), "saturating_shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, {"vqshiftu", "uqshl", UInt(32, 2), "saturating_shift_left", {UInt(32, 2), Int(32, 2)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, - {"vqshiftsu", "sqshlu", UInt(32, 2), "saturating_shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth}, + {"vqshiftsu", "sqshlu", UInt(32, 2), "saturating_shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, {"vqshifts", "sqshl", Int(64, 2), "saturating_shift_left", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::AllowUnsignedOp1}, {"vqshiftu", "uqshl", UInt(64, 2), "saturating_shift_left", {UInt(64, 2), Int(64, 2)}, ArmIntrinsic::AllowUnsignedOp1}, - {"vqshiftsu", "sqshlu", UInt(64, 2), "saturating_shift_left", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::AllowUnsignedOp1}, + {"vqshiftsu", "sqshlu", UInt(64, 2), "saturating_shift_left", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::AllowUnsignedOp1 | ArmIntrinsic::SveUnavailable}, // SQSHRN, UQSHRN, SQRSHRUN Saturating narrowing shift right by an (by immediate in [1, output bits]) // arm32 expects a vector RHS of the same type as the LHS. @@ -500,15 +615,15 @@ const ArmIntrinsic intrinsic_defs[] = { {"vqshiftnsu", nullptr, UInt(32, 2), "saturating_shift_right_narrow", {Int(64, 2), Int(64, 2)}}, // arm64 expects a 32-bit constant. 
- {nullptr, "sqshrn", Int(8, 8), "saturating_shift_right_narrow", {Int(16, 8), UInt(32)}}, - {nullptr, "uqshrn", UInt(8, 8), "saturating_shift_right_narrow", {UInt(16, 8), UInt(32)}}, - {nullptr, "sqshrn", Int(16, 4), "saturating_shift_right_narrow", {Int(32, 4), UInt(32)}}, - {nullptr, "uqshrn", UInt(16, 4), "saturating_shift_right_narrow", {UInt(32, 4), UInt(32)}}, - {nullptr, "sqshrn", Int(32, 2), "saturating_shift_right_narrow", {Int(64, 2), UInt(32)}}, - {nullptr, "uqshrn", UInt(32, 2), "saturating_shift_right_narrow", {UInt(64, 2), UInt(32)}}, - {nullptr, "sqshrun", UInt(8, 8), "saturating_shift_right_narrow", {Int(16, 8), UInt(32)}}, - {nullptr, "sqshrun", UInt(16, 4), "saturating_shift_right_narrow", {Int(32, 4), UInt(32)}}, - {nullptr, "sqshrun", UInt(32, 2), "saturating_shift_right_narrow", {Int(64, 2), UInt(32)}}, + {nullptr, "sqshrn", Int(8, 8), "saturating_shift_right_narrow", {Int(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqshrn", UInt(8, 8), "saturating_shift_right_narrow", {UInt(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqshrn", Int(16, 4), "saturating_shift_right_narrow", {Int(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqshrn", UInt(16, 4), "saturating_shift_right_narrow", {UInt(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqshrn", Int(32, 2), "saturating_shift_right_narrow", {Int(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "uqshrn", UInt(32, 2), "saturating_shift_right_narrow", {UInt(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqshrun", UInt(8, 8), "saturating_shift_right_narrow", {Int(16, 8), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqshrun", UInt(16, 4), "saturating_shift_right_narrow", {Int(32, 4), UInt(32)}, ArmIntrinsic::SveUnavailable}, + {nullptr, "sqshrun", UInt(32, 2), "saturating_shift_right_narrow", {Int(64, 2), UInt(32)}, ArmIntrinsic::SveUnavailable}, // SRSHL, URSHL - Rounding shift left (by signed vector) {"vrshifts", "srshl", Int(8, 8), "rounding_shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, @@ -521,14 +636,15 @@ const ArmIntrinsic intrinsic_defs[] = { {"vrshiftu", "urshl", UInt(64, 2), "rounding_shift_left", {UInt(64, 2), Int(64, 2)}}, // SSHL, USHL - Shift left (by signed vector) - {"vshifts", "sshl", Int(8, 8), "shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vshiftu", "ushl", UInt(8, 8), "shift_left", {UInt(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth}, - {"vshifts", "sshl", Int(16, 4), "shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vshiftu", "ushl", UInt(16, 4), "shift_left", {UInt(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vshifts", "sshl", Int(32, 2), "shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vshiftu", "ushl", UInt(32, 2), "shift_left", {UInt(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, - {"vshifts", "sshl", Int(64, 2), "shift_left", {Int(64, 2), Int(64, 2)}}, - {"vshiftu", "ushl", UInt(64, 2), "shift_left", {UInt(64, 2), Int(64, 2)}}, + // In SVE, no equivalent is found, though there are rounding, saturating, or widening versions. 
+ {"vshifts", "sshl", Int(8, 8), "shift_left", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshiftu", "ushl", UInt(8, 8), "shift_left", {UInt(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshifts", "sshl", Int(16, 4), "shift_left", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshiftu", "ushl", UInt(16, 4), "shift_left", {UInt(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshifts", "sshl", Int(32, 2), "shift_left", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshiftu", "ushl", UInt(32, 2), "shift_left", {UInt(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {"vshifts", "sshl", Int(64, 2), "shift_left", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vshiftu", "ushl", UInt(64, 2), "shift_left", {UInt(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, // SRSHR, URSHR - Rounding shift right (by immediate in [1, output bits]) // LLVM wants these expressed as SRSHL by negative amounts. @@ -537,28 +653,28 @@ const ArmIntrinsic intrinsic_defs[] = { // LLVM pattern matches these for us. // RADDHN - Add and narrow with rounding. - {"vraddhn", "raddhn", Int(8, 8), "rounding_add_narrow", {Int(16, 8), Int(16, 8)}}, - {"vraddhn", "raddhn", UInt(8, 8), "rounding_add_narrow", {UInt(16, 8), UInt(16, 8)}}, - {"vraddhn", "raddhn", Int(16, 4), "rounding_add_narrow", {Int(32, 4), Int(32, 4)}}, - {"vraddhn", "raddhn", UInt(16, 4), "rounding_add_narrow", {UInt(32, 4), UInt(32, 4)}}, - {"vraddhn", "raddhn", Int(32, 2), "rounding_add_narrow", {Int(64, 2), Int(64, 2)}}, - {"vraddhn", "raddhn", UInt(32, 2), "rounding_add_narrow", {UInt(64, 2), UInt(64, 2)}}, + {"vraddhn", "raddhn", Int(8, 8), "rounding_add_narrow", {Int(16, 8), Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vraddhn", "raddhn", UInt(8, 8), "rounding_add_narrow", {UInt(16, 8), UInt(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vraddhn", "raddhn", Int(16, 4), "rounding_add_narrow", {Int(32, 4), Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vraddhn", "raddhn", UInt(16, 4), "rounding_add_narrow", {UInt(32, 4), UInt(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vraddhn", "raddhn", Int(32, 2), "rounding_add_narrow", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vraddhn", "raddhn", UInt(32, 2), "rounding_add_narrow", {UInt(64, 2), UInt(64, 2)}, ArmIntrinsic::SveUnavailable}, // RSUBHN - Sub and narrow with rounding. 
- {"vrsubhn", "rsubhn", Int(8, 8), "rounding_sub_narrow", {Int(16, 8), Int(16, 8)}}, - {"vrsubhn", "rsubhn", UInt(8, 8), "rounding_sub_narrow", {UInt(16, 8), UInt(16, 8)}}, - {"vrsubhn", "rsubhn", Int(16, 4), "rounding_sub_narrow", {Int(32, 4), Int(32, 4)}}, - {"vrsubhn", "rsubhn", UInt(16, 4), "rounding_sub_narrow", {UInt(32, 4), UInt(32, 4)}}, - {"vrsubhn", "rsubhn", Int(32, 2), "rounding_sub_narrow", {Int(64, 2), Int(64, 2)}}, - {"vrsubhn", "rsubhn", UInt(32, 2), "rounding_sub_narrow", {UInt(64, 2), UInt(64, 2)}}, + {"vrsubhn", "rsubhn", Int(8, 8), "rounding_sub_narrow", {Int(16, 8), Int(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vrsubhn", "rsubhn", UInt(8, 8), "rounding_sub_narrow", {UInt(16, 8), UInt(16, 8)}, ArmIntrinsic::SveUnavailable}, + {"vrsubhn", "rsubhn", Int(16, 4), "rounding_sub_narrow", {Int(32, 4), Int(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vrsubhn", "rsubhn", UInt(16, 4), "rounding_sub_narrow", {UInt(32, 4), UInt(32, 4)}, ArmIntrinsic::SveUnavailable}, + {"vrsubhn", "rsubhn", Int(32, 2), "rounding_sub_narrow", {Int(64, 2), Int(64, 2)}, ArmIntrinsic::SveUnavailable}, + {"vrsubhn", "rsubhn", UInt(32, 2), "rounding_sub_narrow", {UInt(64, 2), UInt(64, 2)}, ArmIntrinsic::SveUnavailable}, // SQDMULH - Saturating doubling multiply keep high half. - {"vqdmulh", "sqdmulh", Int(16, 4), "qdmulh", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vqdmulh", "sqdmulh", Int(32, 2), "qdmulh", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, + {"vqdmulh", "sqdmulh", Int(16, 4), "qdmulh", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"vqdmulh", "sqdmulh", Int(32, 2), "qdmulh", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, // SQRDMULH - Saturating doubling multiply keep high half with rounding. - {"vqrdmulh", "sqrdmulh", Int(16, 4), "qrdmulh", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth}, - {"vqrdmulh", "sqrdmulh", Int(32, 2), "qrdmulh", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth}, + {"vqrdmulh", "sqrdmulh", Int(16, 4), "qrdmulh", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, + {"vqrdmulh", "sqrdmulh", Int(32, 2), "qrdmulh", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::SveNoPredicate}, // PADD - Pairwise add. // 32-bit only has half-width versions. 
@@ -571,47 +687,49 @@ const ArmIntrinsic intrinsic_defs[] = { {"vpadd", nullptr, Float(32, 2), "pairwise_add", {Float(32, 4)}, ArmIntrinsic::SplitArg0}, {"vpadd", nullptr, Float(16, 4), "pairwise_add", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::RequireFp16}, - {nullptr, "addp", Int(8, 8), "pairwise_add", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "addp", UInt(8, 8), "pairwise_add", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "addp", Int(16, 4), "pairwise_add", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "addp", UInt(16, 4), "pairwise_add", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "addp", Int(32, 2), "pairwise_add", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "addp", UInt(32, 2), "pairwise_add", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "faddp", Float(32, 2), "pairwise_add", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "faddp", Float(64, 2), "pairwise_add", {Float(64, 4)}, ArmIntrinsic::SplitArg0}, - {nullptr, "faddp", Float(16, 4), "pairwise_add", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, + {nullptr, "addp", Int(8, 8), "pairwise_add", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", UInt(8, 8), "pairwise_add", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", Int(16, 4), "pairwise_add", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", UInt(16, 4), "pairwise_add", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", Int(32, 2), "pairwise_add", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", UInt(32, 2), "pairwise_add", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", Int(64, 2), "pairwise_add", {Int(64, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::SveUnavailable}, + {nullptr, "addp", UInt(64, 2), "pairwise_add", {UInt(64, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::SveUnavailable}, + {nullptr, "faddp", Float(32, 2), "pairwise_add", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "faddp", Float(64, 2), "pairwise_add", {Float(64, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::SveUnavailable}, + {nullptr, "faddp", Float(16, 4), "pairwise_add", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveUnavailable}, // SADDLP, UADDLP - Pairwise add long. 
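// [Editorial sketch -- not part of the patch] The reduction shape the "pairwise_widening_add"
// entries below are aimed at: summing adjacent pairs of narrow elements into a wider
// accumulator. Whether a given pipeline actually lowers to SADDLP/UADDLP depends on the
// VectorReduce matching in codegen_vector_reduce(); names here are illustrative only.
#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam in(UInt(8), 1, "in");
    Func out("out");
    Var x("x");
    RDom r(0, 2);

    out(x) = sum(cast<uint16_t>(in(2 * x + r)));  // add adjacent u8 pairs into a u16
    out.vectorize(x, 8);

    out.compile_to_assembly("pairwise_widening_add.s", {in}, Target("arm-64-linux"));
    return 0;
}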
- {"vpaddls", "saddlp", Int(16, 4), "pairwise_widening_add", {Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddlu", "uaddlp", UInt(16, 4), "pairwise_widening_add", {UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddlu", "uaddlp", Int(16, 4), "pairwise_widening_add", {UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddls", "saddlp", Int(32, 2), "pairwise_widening_add", {Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddlu", "uaddlp", UInt(32, 2), "pairwise_widening_add", {UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddlu", "uaddlp", Int(32, 2), "pairwise_widening_add", {UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs}, - {"vpaddls", "saddlp", Int(64, 1), "pairwise_widening_add", {Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors}, - {"vpaddlu", "uaddlp", UInt(64, 1), "pairwise_widening_add", {UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors}, - {"vpaddlu", "uaddlp", Int(64, 1), "pairwise_widening_add", {UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors}, + {"vpaddls", "saddlp", Int(16, 4), "pairwise_widening_add", {Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", UInt(16, 4), "pairwise_widening_add", {UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", Int(16, 4), "pairwise_widening_add", {UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddls", "saddlp", Int(32, 2), "pairwise_widening_add", {Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", UInt(32, 2), "pairwise_widening_add", {UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", Int(32, 2), "pairwise_widening_add", {UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::SveUnavailable}, + {"vpaddls", "saddlp", Int(64, 1), "pairwise_widening_add", {Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", UInt(64, 1), "pairwise_widening_add", {UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::SveUnavailable}, + {"vpaddlu", "uaddlp", Int(64, 1), "pairwise_widening_add", {UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleRetArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::SveUnavailable}, // SPADAL, UPADAL - Pairwise add and accumulate long. 
- {"vpadals", nullptr, Int(16, 4), "pairwise_widening_add_accumulate", {Int(16, 4), Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadalu", nullptr, UInt(16, 4), "pairwise_widening_add_accumulate", {UInt(16, 4), UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadalu", nullptr, Int(16, 4), "pairwise_widening_add_accumulate", {Int(16, 4), UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadals", nullptr, Int(32, 2), "pairwise_widening_add_accumulate", {Int(32, 2), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadalu", nullptr, UInt(32, 2), "pairwise_widening_add_accumulate", {UInt(32, 2), UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadalu", nullptr, Int(32, 2), "pairwise_widening_add_accumulate", {Int(32, 2), UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs}, - {"vpadals", nullptr, Int(64, 1), "pairwise_widening_add_accumulate", {Int(64, 1), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors}, - {"vpadalu", nullptr, UInt(64, 1), "pairwise_widening_add_accumulate", {UInt(64, 1), UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors}, - {"vpadalu", nullptr, Int(64, 1), "pairwise_widening_add_accumulate", {Int(64, 1), UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors}, + {"vpadals", "sadalp", Int(16, 4), "pairwise_widening_add_accumulate", {Int(16, 4), Int(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", UInt(16, 4), "pairwise_widening_add_accumulate", {UInt(16, 4), UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", Int(16, 4), "pairwise_widening_add_accumulate", {Int(16, 4), UInt(8, 8)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadals", "sadalp", Int(32, 2), "pairwise_widening_add_accumulate", {Int(32, 2), Int(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", UInt(32, 2), "pairwise_widening_add_accumulate", {UInt(32, 2), UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", Int(32, 2), "pairwise_widening_add_accumulate", {Int(32, 2), UInt(16, 4)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::Neon64Unavailable}, + {"vpadals", "sadalp", Int(64, 1), "pairwise_widening_add_accumulate", {Int(64, 1), Int(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", UInt(64, 1), "pairwise_widening_add_accumulate", {UInt(64, 1), UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::Neon64Unavailable}, + {"vpadalu", "uadalp", Int(64, 1), "pairwise_widening_add_accumulate", {Int(64, 1), UInt(32, 2)}, ArmIntrinsic::HalfWidth | ArmIntrinsic::MangleArgs | ArmIntrinsic::ScalarsAreVectors | ArmIntrinsic::Neon64Unavailable}, // SMAXP, UMAXP, FMAXP - Pairwise max. 
- {nullptr, "smaxp", Int(8, 8), "pairwise_max", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "umaxp", UInt(8, 8), "pairwise_max", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "smaxp", Int(16, 4), "pairwise_max", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "umaxp", UInt(16, 4), "pairwise_max", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "smaxp", Int(32, 2), "pairwise_max", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "umaxp", UInt(32, 2), "pairwise_max", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "fmaxp", Float(32, 2), "pairwise_max", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "fmaxp", Float(16, 4), "pairwise_max", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, + {nullptr, "smaxp", Int(8, 8), "pairwise_max", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "umaxp", UInt(8, 8), "pairwise_max", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "smaxp", Int(16, 4), "pairwise_max", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "umaxp", UInt(16, 4), "pairwise_max", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "smaxp", Int(32, 2), "pairwise_max", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "umaxp", UInt(32, 2), "pairwise_max", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "fmaxp", Float(32, 2), "pairwise_max", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "fmaxp", Float(16, 4), "pairwise_max", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveUnavailable}, // On arm32, we only have half-width versions of these. {"vpmaxs", nullptr, Int(8, 8), "pairwise_max", {Int(8, 16)}, ArmIntrinsic::SplitArg0}, @@ -624,14 +742,14 @@ const ArmIntrinsic intrinsic_defs[] = { {"vpmaxs", nullptr, Float(16, 4), "pairwise_max", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::RequireFp16}, // SMINP, UMINP, FMINP - Pairwise min. 
- {nullptr, "sminp", Int(8, 8), "pairwise_min", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "uminp", UInt(8, 8), "pairwise_min", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "sminp", Int(16, 4), "pairwise_min", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "uminp", UInt(16, 4), "pairwise_min", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "sminp", Int(32, 2), "pairwise_min", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "uminp", UInt(32, 2), "pairwise_min", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "fminp", Float(32, 2), "pairwise_min", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth}, - {nullptr, "fminp", Float(16, 4), "pairwise_min", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16}, + {nullptr, "sminp", Int(8, 8), "pairwise_min", {Int(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "uminp", UInt(8, 8), "pairwise_min", {UInt(8, 16)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "sminp", Int(16, 4), "pairwise_min", {Int(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "uminp", UInt(16, 4), "pairwise_min", {UInt(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "sminp", Int(32, 2), "pairwise_min", {Int(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "uminp", UInt(32, 2), "pairwise_min", {UInt(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "fminp", Float(32, 2), "pairwise_min", {Float(32, 4)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::SveUnavailable}, + {nullptr, "fminp", Float(16, 4), "pairwise_min", {Float(16, 8)}, ArmIntrinsic::SplitArg0 | ArmIntrinsic::HalfWidth | ArmIntrinsic::RequireFp16 | ArmIntrinsic::SveUnavailable}, // On arm32, we only have half-width versions of these. {"vpmins", nullptr, Int(8, 8), "pairwise_min", {Int(8, 16)}, ArmIntrinsic::SplitArg0}, @@ -645,28 +763,35 @@ const ArmIntrinsic intrinsic_defs[] = { // SDOT, UDOT - Dot products. // Mangle this one manually, there aren't that many and it is a special case. 
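// [Editorial sketch -- not part of the patch] The reduction shape the "dot_product" entries
// below (NEON SDOT/UDOT plus the SVE sdot/udot.nxv* forms added by this patch) accelerate:
// a 4-wide widening multiply-accumulate into 32-bit lanes. Names and the target string are
// illustrative; per the rest of this patch, adding the SVE2 feature (and a vector_bits
// setting) to the target is what selects the SVE entries instead of the NEON ones.
#include "Halide.h"
using namespace Halide;

int main() {
    ImageParam a(UInt(8), 1, "a"), b(UInt(8), 1, "b");
    Func out("out");
    Var x("x");
    RDom r(0, 4);

    out(x) = sum(cast<uint32_t>(a(4 * x + r)) * cast<uint32_t>(b(4 * x + r)));
    out.vectorize(x, 8);

    out.compile_to_assembly("dot_product.s", {a, b}, Target("arm-64-linux"));
    return 0;
}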
- {nullptr, "sdot.v2i32.v8i8", Int(32, 2), "dot_product", {Int(32, 2), Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle}, - {nullptr, "udot.v2i32.v8i8", Int(32, 2), "dot_product", {Int(32, 2), UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle}, - {nullptr, "udot.v2i32.v8i8", UInt(32, 2), "dot_product", {UInt(32, 2), UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle}, - {nullptr, "sdot.v4i32.v16i8", Int(32, 4), "dot_product", {Int(32, 4), Int(8, 16), Int(8, 16)}, ArmIntrinsic::NoMangle}, - {nullptr, "udot.v4i32.v16i8", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle}, - {nullptr, "udot.v4i32.v16i8", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle}, + {nullptr, "sdot.v2i32.v8i8", Int(32, 2), "dot_product", {Int(32, 2), Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + {nullptr, "udot.v2i32.v8i8", Int(32, 2), "dot_product", {Int(32, 2), UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + {nullptr, "udot.v2i32.v8i8", UInt(32, 2), "dot_product", {UInt(32, 2), UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + {nullptr, "sdot.v4i32.v16i8", Int(32, 4), "dot_product", {Int(32, 4), Int(8, 16), Int(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + {nullptr, "udot.v4i32.v16i8", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + {nullptr, "udot.v4i32.v16i8", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveUnavailable}, + // SVE versions. + {nullptr, "sdot.nxv4i32", Int(32, 4), "dot_product", {Int(32, 4), Int(8, 16), Int(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::SveRequired}, + {nullptr, "udot.nxv4i32", Int(32, 4), "dot_product", {Int(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::SveRequired}, + {nullptr, "udot.nxv4i32", UInt(32, 4), "dot_product", {UInt(32, 4), UInt(8, 16), UInt(8, 16)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::SveRequired}, + {nullptr, "sdot.nxv2i64", Int(64, 2), "dot_product", {Int(64, 2), Int(16, 8), Int(16, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::Neon64Unavailable | ArmIntrinsic::SveRequired}, + {nullptr, "udot.nxv2i64", Int(64, 2), "dot_product", {Int(64, 2), UInt(16, 8), UInt(16, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::Neon64Unavailable | ArmIntrinsic::SveRequired}, + {nullptr, "udot.nxv2i64", UInt(64, 2), "dot_product", {UInt(64, 2), UInt(16, 8), UInt(16, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::SveNoPredicate | ArmIntrinsic::Neon64Unavailable | ArmIntrinsic::SveRequired}, // ABDL - Widening absolute difference // The ARM backend folds both signed and unsigned widening casts of absd to a widening_absd, so we need to handle both signed and // unsigned input and return types. 
- {"vabdl_i8x8", "vabdl_i8x8", Int(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_i8x8", "vabdl_i8x8", UInt(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u8x8", "vabdl_u8x8", Int(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u8x8", "vabdl_u8x8", UInt(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_i16x4", "vabdl_i16x4", Int(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_i16x4", "vabdl_i16x4", UInt(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u16x4", "vabdl_u16x4", Int(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u16x4", "vabdl_u16x4", UInt(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_i32x2", "vabdl_i32x2", Int(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_i32x2", "vabdl_i32x2", UInt(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u32x2", "vabdl_u32x2", Int(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, - {"vabdl_u32x2", "vabdl_u32x2", UInt(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix}, + {"vabdl_i8x8", "vabdl_i8x8", Int(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_i8x8", "vabdl_i8x8", UInt(16, 8), "widening_absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u8x8", "vabdl_u8x8", Int(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u8x8", "vabdl_u8x8", UInt(16, 8), "widening_absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_i16x4", "vabdl_i16x4", Int(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_i16x4", "vabdl_i16x4", UInt(32, 4), "widening_absd", {Int(16, 4), Int(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u16x4", "vabdl_u16x4", Int(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u16x4", "vabdl_u16x4", UInt(32, 4), "widening_absd", {UInt(16, 4), UInt(16, 4)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_i32x2", "vabdl_i32x2", Int(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_i32x2", "vabdl_i32x2", UInt(64, 2), "widening_absd", {Int(32, 2), Int(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u32x2", "vabdl_u32x2", Int(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, + {"vabdl_u32x2", "vabdl_u32x2", 
UInt(64, 2), "widening_absd", {UInt(32, 2), UInt(32, 2)}, ArmIntrinsic::NoMangle | ArmIntrinsic::NoPrefix | ArmIntrinsic::SveUnavailable}, }; // List of fp16 math functions which we can avoid "emulated" equivalent code generation. @@ -706,32 +831,103 @@ const std::map float16_transcendental_remapping = { }; // clang-format on -llvm::Function *CodeGen_ARM::define_concat_args_wrapper(llvm::Function *inner, const string &name) { - llvm::FunctionType *inner_ty = inner->getFunctionType(); +llvm::Type *CodeGen_ARM::llvm_type_with_constraint(const Type &t, bool scalars_are_vectors, + VectorTypeConstraint constraint) { + llvm::Type *ret = llvm_type_of(t.element_of()); + if (!t.is_scalar() || scalars_are_vectors) { + int lanes = t.lanes(); + if (constraint == VectorTypeConstraint::VScale) { + lanes /= target_vscale(); + } + ret = get_vector_type(ret, lanes, constraint); + } + return ret; +} + +llvm::Function *CodeGen_ARM::define_intrin_wrapper(const std::string &inner_name, + const Type &ret_type, + const std::string &mangled_name, + const std::vector &arg_types, + int intrinsic_flags, + bool sve_intrinsic) { + + auto to_llvm_type = [&](const Type &t) { + return llvm_type_with_constraint(t, (intrinsic_flags & ArmIntrinsic::ScalarsAreVectors), + !sve_intrinsic ? VectorTypeConstraint::Fixed : VectorTypeConstraint::VScale); + }; + + llvm::Type *llvm_ret_type = to_llvm_type(ret_type); + std::vector llvm_arg_types; + std::transform(arg_types.begin(), arg_types.end(), std::back_inserter(llvm_arg_types), to_llvm_type); + + const bool add_predicate = sve_intrinsic && !(intrinsic_flags & ArmIntrinsic::SveNoPredicate); + bool add_inactive_arg = sve_intrinsic && (intrinsic_flags & ArmIntrinsic::SveInactiveArg); + bool split_arg0 = intrinsic_flags & ArmIntrinsic::SplitArg0; + + if (!(add_inactive_arg || add_predicate || split_arg0)) { + // No need to wrap + return get_llvm_intrin(llvm_ret_type, mangled_name, llvm_arg_types); + } + + std::vector inner_llvm_arg_types; + std::vector inner_args; + internal_assert(!arg_types.empty()); + const int inner_lanes = split_arg0 ? arg_types[0].lanes() / 2 : arg_types[0].lanes(); + + if (add_inactive_arg) { + // The fallback value has the same type as ret value. + // We don't use this, so just pad it with 0. + inner_llvm_arg_types.push_back(llvm_ret_type); + + Value *zero = Constant::getNullValue(llvm_ret_type); + inner_args.push_back(zero); + } + if (add_predicate) { + llvm::Type *pred_type = to_llvm_type(Int(1, inner_lanes)); + inner_llvm_arg_types.push_back(pred_type); + // Halide does not have general support for predication so use + // constant true for all lanes. 
+ Value *ptrue = Constant::getAllOnesValue(pred_type); + inner_args.push_back(ptrue); + } + if (split_arg0) { + llvm::Type *split_arg_type = to_llvm_type(arg_types[0].with_lanes(inner_lanes)); + inner_llvm_arg_types.push_back(split_arg_type); + inner_llvm_arg_types.push_back(split_arg_type); + internal_assert(arg_types.size() == 1); + } else { + // Push back all argument types which Halide defines + std::copy(llvm_arg_types.begin(), llvm_arg_types.end(), std::back_inserter(inner_llvm_arg_types)); + } - internal_assert(inner_ty->getNumParams() == 2); - llvm::Type *inner_arg0_ty = inner_ty->getParamType(0); - llvm::Type *inner_arg1_ty = inner_ty->getParamType(1); - int inner_arg0_lanes = get_vector_num_elements(inner_arg0_ty); - int inner_arg1_lanes = get_vector_num_elements(inner_arg1_ty); + llvm::Function *inner = get_llvm_intrin(llvm_ret_type, mangled_name, inner_llvm_arg_types); + llvm::FunctionType *inner_ty = inner->getFunctionType(); - llvm::Type *concat_arg_ty = - get_vector_type(inner_arg0_ty->getScalarType(), inner_arg0_lanes + inner_arg1_lanes); + llvm::FunctionType *wrapper_ty = llvm::FunctionType::get(inner_ty->getReturnType(), llvm_arg_types, false); - // Make a wrapper. - llvm::FunctionType *wrapper_ty = - llvm::FunctionType::get(inner_ty->getReturnType(), {concat_arg_ty}, false); + string wrapper_name = inner_name + unique_name("_wrapper"); llvm::Function *wrapper = - llvm::Function::Create(wrapper_ty, llvm::GlobalValue::InternalLinkage, name, module.get()); + llvm::Function::Create(wrapper_ty, llvm::GlobalValue::InternalLinkage, wrapper_name, module.get()); llvm::BasicBlock *block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper); IRBuilderBase::InsertPoint here = builder->saveIP(); builder->SetInsertPoint(block); + if (split_arg0) { + // Split the wrapper's single argument into the low and high halves that the inner intrinsic expects. + Value *low = slice_vector(wrapper->getArg(0), 0, inner_lanes); + Value *high = slice_vector(wrapper->getArg(0), inner_lanes, inner_lanes); + inner_args.push_back(low); + inner_args.push_back(high); + internal_assert(inner_llvm_arg_types.size() == 2); + } else { + for (auto *itr = wrapper->arg_begin(); itr != wrapper->arg_end(); ++itr) { + inner_args.push_back(itr); + } + } + // Call the real intrinsic. - Value *low = slice_vector(wrapper->getArg(0), 0, inner_arg0_lanes); - Value *high = slice_vector(wrapper->getArg(0), inner_arg0_lanes, inner_arg1_lanes); - Value *ret = builder->CreateCall(inner, {low, high}); + Value *ret = builder->CreateCall(inner, inner_args); builder->CreateRet(ret); // Always inline these wrappers. @@ -746,15 +942,32 @@ llvm::Function *CodeGen_ARM::define_concat_args_wrapper(llvm::Function *inner, c void CodeGen_ARM::init_module() { CodeGen_Posix::init_module(); - if (neon_intrinsics_disabled()) { + const bool has_neon = !target.has_feature(Target::NoNEON); + const bool has_sve = target.has_feature(Target::SVE2); + if (!(has_neon || has_sve)) { return; } - string prefix = target.bits == 32 ? "llvm.arm.neon." : "llvm.aarch64.neon."; + enum class SIMDFlavors { + NeonWidthX1, + NeonWidthX2, + SVE, + }; + + std::vector flavors; + if (has_neon) { + flavors.push_back(SIMDFlavors::NeonWidthX1); + flavors.push_back(SIMDFlavors::NeonWidthX2); + } + if (has_sve) { + flavors.push_back(SIMDFlavors::SVE); + } + for (const ArmIntrinsic &intrin : intrinsic_defs) { if (intrin.flags & ArmIntrinsic::RequireFp16 && !target.has_feature(Target::ARMFp16)) { continue; } + // Get the name of the intrinsic with the appropriate prefix.
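+ // (For reference: 32-bit targets use the "llvm.arm.neon." prefix, while 64-bit targets use
+ // "llvm.aarch64.neon." or, for the SVE flavor, "llvm.aarch64.sve." — see the prefixing logic below.)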
const char *intrin_name = nullptr; if (target.bits == 32) { @@ -765,21 +978,66 @@ void CodeGen_ARM::init_module() { if (!intrin_name) { continue; } - string full_name = intrin_name; - if (!starts_with(full_name, "llvm.") && (intrin.flags & ArmIntrinsic::NoPrefix) == 0) { - full_name = prefix + full_name; - } - // We might have to generate versions of this intrinsic with multiple widths. - vector width_factors = {1}; - if (intrin.flags & ArmIntrinsic::HalfWidth) { - width_factors.push_back(2); - } + // This makes up to three passes defining intrinsics for 64-bit, + // 128-bit, and, if SVE is available, whatever the SVE target width + // is. Some variants will not result in a definition getting added based + // on the target and the intrinsic flags. The intrinsic width may be + // scaled and one of two opcodes may be selected by different + // iterations of this loop. + for (const auto flavor : flavors) { + const bool is_sve = (flavor == SIMDFlavors::SVE); + + // Skip intrinsics that are NEON-only or SVE-only, depending on whether we are compiling for SVE. + if (is_sve) { + if (intrin.flags & ArmIntrinsic::SveUnavailable) { + continue; + } + } else { + if (intrin.flags & ArmIntrinsic::SveRequired) { + continue; + } + } + if ((target.bits == 64) && + (intrin.flags & ArmIntrinsic::Neon64Unavailable) && + !is_sve) { + continue; + } + // Already declared in the x1 pass. + if ((flavor == SIMDFlavors::NeonWidthX2) && + !(intrin.flags & ArmIntrinsic::HalfWidth)) { + continue; + } + + string full_name = intrin_name; + const bool is_vanilla_intrinsic = starts_with(full_name, "llvm."); + if (!is_vanilla_intrinsic && (intrin.flags & ArmIntrinsic::NoPrefix) == 0) { + if (target.bits == 32) { + full_name = "llvm.arm.neon." + full_name; + } else { + full_name = (is_sve ? "llvm.aarch64.sve." : "llvm.aarch64.neon.") + full_name; + } + } + + int width_factor = 1; + if (!((intrin.ret_type.lanes <= 1) && (intrin.flags & ArmIntrinsic::NoMangle))) { + switch (flavor) { + case SIMDFlavors::NeonWidthX1: + width_factor = 1; + break; + case SIMDFlavors::NeonWidthX2: + width_factor = 2; + break; + case SIMDFlavors::SVE: + width_factor = (intrin.flags & ArmIntrinsic::HalfWidth) ? 2 : 1; + width_factor *= target_vscale(); + break; + } + } - for (int width_factor : width_factors) { Type ret_type = intrin.ret_type; ret_type = ret_type.with_lanes(ret_type.lanes() * width_factor); - internal_assert(ret_type.bits() * ret_type.lanes() <= 128) << full_name << "\n"; + internal_assert(ret_type.bits() * ret_type.lanes() <= 128 * width_factor) << full_name << "\n"; vector arg_types; arg_types.reserve(4); for (halide_type_t i : intrin.arg_types) { @@ -787,9 +1045,7 @@ break; } Type arg_type = i; - if (arg_type.is_vector()) { - arg_type = arg_type.with_lanes(arg_type.lanes() * width_factor); - } + arg_type = arg_type.with_lanes(arg_type.lanes() * width_factor); arg_types.emplace_back(arg_type); } @@ -799,7 +1055,7 @@ if (starts_with(full_name, "llvm.") && (intrin.flags & ArmIntrinsic::NoMangle) == 0) { // Append LLVM name mangling for either the return type or the arguments, or both.
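+ // Illustrative example: a fixed 4 x i32 type mangles as ".v4i32", while the SVE flavor
+ // divides the lane count by vscale and mangles as ".nxv4i32".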
vector types; - if (intrin.flags & ArmIntrinsic::MangleArgs) { + if (intrin.flags & ArmIntrinsic::MangleArgs && !is_sve) { types = arg_types; } else if (intrin.flags & ArmIntrinsic::MangleRetArgs) { types = {ret_type}; @@ -808,7 +1064,9 @@ void CodeGen_ARM::init_module() { types = {ret_type}; } for (const Type &t : types) { - mangled_name_builder << ".v" << t.lanes(); + std::string llvm_vector_prefix = is_sve ? ".nxv" : ".v"; + int mangle_lanes = t.lanes() / (is_sve ? target_vscale() : 1); + mangled_name_builder << llvm_vector_prefix << mangle_lanes; if (t.is_int() || t.is_uint()) { mangled_name_builder << "i"; } else if (t.is_float()) { @@ -819,17 +1077,9 @@ void CodeGen_ARM::init_module() { } string mangled_name = mangled_name_builder.str(); - llvm::Function *intrin_impl = nullptr; - if (intrin.flags & ArmIntrinsic::SplitArg0) { - // This intrinsic needs a wrapper to split the argument. - string wrapper_name = intrin.name + unique_name("_wrapper"); - Type split_arg_type = arg_types[0].with_lanes(arg_types[0].lanes() / 2); - llvm::Function *to_wrap = get_llvm_intrin(ret_type, mangled_name, {split_arg_type, split_arg_type}); - intrin_impl = define_concat_args_wrapper(to_wrap, wrapper_name); - } else { - bool scalars_are_vectors = intrin.flags & ArmIntrinsic::ScalarsAreVectors; - intrin_impl = get_llvm_intrin(ret_type, mangled_name, arg_types, scalars_are_vectors); - } + llvm::Function *intrin_impl = define_intrin_wrapper( + intrin.name, ret_type, mangled_name, arg_types, + intrin.flags, is_sve); function_does_not_access_memory(intrin_impl); intrin_impl->addFnAttr(llvm::Attribute::NoUnwind); @@ -862,8 +1112,31 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, CodeGen_Posix::compile_func(func, simple_name, extern_name); } +void CodeGen_ARM::begin_func(LinkageType linkage, const std::string &simple_name, + const std::string &extern_name, const std::vector &args) { + CodeGen_Posix::begin_func(linkage, simple_name, extern_name, args); + + // TODO(https://github.com/halide/Halide/issues/8092): There is likely a + // better way to ensure this is only generated for the outermost function + // that is being compiled. Avoiding the assert on inner functions is both an + // efficiency and a correctness issue as the assertion code may not compile + // in all contexts. + if (linkage != LinkageType::Internal) { + int effective_vscale = target_vscale(); + if (effective_vscale != 0 && !target.has_feature(Target::NoAsserts)) { + // Make sure run-time vscale is equal to compile-time vscale + Expr runtime_vscale = Call::make(Int(32), Call::get_runtime_vscale, {}, Call::PureIntrinsic); + Value *val_runtime_vscale = codegen(runtime_vscale); + Value *val_compiletime_vscale = ConstantInt::get(i32_t, effective_vscale); + Value *cond = builder->CreateICmpEQ(val_runtime_vscale, val_compiletime_vscale); + create_assertion(cond, Call::make(Int(32), "halide_error_vscale_invalid", + {simple_name, runtime_vscale, Expr(effective_vscale)}, Call::Extern)); + } + } +} + void CodeGen_ARM::visit(const Cast *op) { - if (!neon_intrinsics_disabled() && op->type.is_vector()) { + if (!simd_intrinsics_disabled() && op->type.is_vector()) { vector matches; for (const Pattern &pattern : casts) { if (expr_match(pattern.pattern, op, matches)) { @@ -898,14 +1171,11 @@ void CodeGen_ARM::visit(const Cast *op) { } } - // LLVM fptoui generates fcvtzs if src is fp16 scalar else fcvtzu. - // To avoid that, we use neon intrinsic explicitly. 
- if (is_float16_and_has_feature(op->value.type())) { - if (op->type.is_int_or_uint() && op->type.bits() == 16) { - value = call_overloaded_intrin(op->type, "fp_to_int", {op->value}); - if (value) { - return; - } + // LLVM fptoui generates fcvtzs or fcvtzu in inconsistent way + if (op->value.type().is_float() && op->type.is_int_or_uint()) { + if (Value *v = call_overloaded_intrin(op->type, "fp_to_int", {op->value})) { + value = v; + return; } } @@ -913,7 +1183,7 @@ void CodeGen_ARM::visit(const Cast *op) { } void CodeGen_ARM::visit(const Add *op) { - if (neon_intrinsics_disabled() || + if (simd_intrinsics_disabled() || !op->type.is_vector() || !target.has_feature(Target::ARMDotProd) || !op->type.is_int_or_uint() || @@ -997,7 +1267,7 @@ void CodeGen_ARM::visit(const Add *op) { } void CodeGen_ARM::visit(const Sub *op) { - if (neon_intrinsics_disabled()) { + if (simd_intrinsics_disabled()) { CodeGen_Posix::visit(op); return; } @@ -1012,6 +1282,46 @@ void CodeGen_ARM::visit(const Sub *op) { } } + // Peep-hole (0 - b) pattern to generate "negate" instruction + if (is_const_zero(op->a)) { + if (target_vscale() != 0) { + if ((op->type.bits() >= 8 && op->type.is_int())) { + if (Value *v = call_overloaded_intrin(op->type, "negate", {op->b})) { + value = v; + return; + } + } else if (op->type.bits() >= 16 && op->type.is_float()) { + value = builder->CreateFNeg(codegen(op->b)); + return; + } + } else { + // llvm.neon.neg/fneg intrinsic doesn't seem to exist. Instead, + // llvm will generate floating point negate instructions if we ask for (-0.0f)-x + if (op->type.is_float() && + (op->type.bits() >= 32 || is_float16_and_has_feature(op->type))) { + Constant *a; + if (op->type.bits() == 16) { + a = ConstantFP::getNegativeZero(f16_t); + } else if (op->type.bits() == 32) { + a = ConstantFP::getNegativeZero(f32_t); + } else if (op->type.bits() == 64) { + a = ConstantFP::getNegativeZero(f64_t); + } else { + a = nullptr; + internal_error << "Unknown bit width for floating point type: " << op->type << "\n"; + } + + Value *b = codegen(op->b); + + if (op->type.lanes() > 1) { + a = get_splat(op->type.lanes(), a); + } + value = builder->CreateFSub(a, b); + return; + } + } + } + // llvm will generate floating point negate instructions if we ask for (-0.0f)-x if (op->type.is_float() && (op->type.bits() >= 32 || is_float16_and_has_feature(op->type)) && @@ -1042,7 +1352,7 @@ void CodeGen_ARM::visit(const Sub *op) { void CodeGen_ARM::visit(const Min *op) { // Use a 2-wide vector for scalar floats. - if (!neon_intrinsics_disabled() && (op->type == Float(32) || op->type.is_vector())) { + if (!simd_intrinsics_disabled() && (op->type.is_float() || op->type.is_vector())) { value = call_overloaded_intrin(op->type, "min", {op->a, op->b}); if (value) { return; @@ -1054,7 +1364,7 @@ void CodeGen_ARM::visit(const Min *op) { void CodeGen_ARM::visit(const Max *op) { // Use a 2-wide vector for scalar floats. 
- if (!neon_intrinsics_disabled() && (op->type == Float(32) || op->type.is_vector())) { + if (!simd_intrinsics_disabled() && (op->type.is_float() || op->type.is_vector())) { value = call_overloaded_intrin(op->type, "max", {op->a, op->b}); if (value) { return; @@ -1066,12 +1376,13 @@ void CodeGen_ARM::visit(const Max *op) { void CodeGen_ARM::visit(const Store *op) { // Predicated store - if (!is_const_one(op->predicate)) { + const bool is_predicated_store = !is_const_one(op->predicate); + if (is_predicated_store && !target.has_feature(Target::SVE2)) { CodeGen_Posix::visit(op); return; } - if (neon_intrinsics_disabled()) { + if (simd_intrinsics_disabled()) { CodeGen_Posix::visit(op); return; } @@ -1079,8 +1390,8 @@ void CodeGen_ARM::visit(const Store *op) { // A dense store of an interleaving can be done using a vst2 intrinsic const Ramp *ramp = op->index.as(); - // We only deal with ramps here - if (!ramp) { + // We only deal with ramps here except for SVE2 + if (!ramp && !target.has_feature(Target::SVE2)) { CodeGen_Posix::visit(op); return; } @@ -1102,21 +1413,27 @@ void CodeGen_ARM::visit(const Store *op) { intrin_type = t; Type elt = t.element_of(); int vec_bits = t.bits() * t.lanes(); - if (elt == Float(32) || + if (elt == Float(32) || elt == Float(64) || is_float16_and_has_feature(elt) || - elt == Int(8) || elt == Int(16) || elt == Int(32) || - elt == UInt(8) || elt == UInt(16) || elt == UInt(32)) { + elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) || + elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) { + // TODO(zvookin): Handle vector_bits_*. if (vec_bits % 128 == 0) { type_ok_for_vst = true; - intrin_type = intrin_type.with_lanes(128 / t.bits()); + int target_vector_bits = target.vector_bits; + if (target_vector_bits == 0) { + target_vector_bits = 128; + } + intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits()); } else if (vec_bits % 64 == 0) { type_ok_for_vst = true; - intrin_type = intrin_type.with_lanes(64 / t.bits()); + auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64; + intrin_type = intrin_type.with_lanes(intrin_bits / t.bits()); } } } - if (is_const_one(ramp->stride) && + if (ramp && is_const_one(ramp->stride) && shuffle && shuffle->is_interleave() && type_ok_for_vst && 2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) { @@ -1138,11 +1455,14 @@ void CodeGen_ARM::visit(const Store *op) { for (int i = 0; i < num_vecs; ++i) { args[i] = codegen(shuffle->vectors[i]); } + Value *store_pred_val = codegen(op->predicate); + + bool is_sve = target.has_feature(Target::SVE2); // Declare the function std::ostringstream instr; vector arg_types; - llvm::Type *intrin_llvm_type = llvm_type_of(intrin_type); + llvm::Type *intrin_llvm_type = llvm_type_with_constraint(intrin_type, false, is_sve ? VectorTypeConstraint::VScale : VectorTypeConstraint::Fixed); #if LLVM_VERSION >= 170 const bool is_opaque = true; #else @@ -1160,27 +1480,38 @@ void CodeGen_ARM::visit(const Store *op) { arg_types.front() = i8_t->getPointerTo(); arg_types.back() = i32_t; } else { - instr << "llvm.aarch64.neon.st" - << num_vecs - << ".v" - << intrin_type.lanes() - << (t.is_float() ? 'f' : 'i') - << t.bits() - << ".p0"; - if (!is_opaque) { - instr << (t.is_float() ? 'f' : 'i') << t.bits(); + if (is_sve) { + instr << "llvm.aarch64.sve.st" + << num_vecs + << ".nxv" + << (intrin_type.lanes() / target_vscale()) + << (t.is_float() ? 
'f' : 'i') + << t.bits(); + arg_types = vector(num_vecs, intrin_llvm_type); + arg_types.emplace_back(get_vector_type(i1_t, intrin_type.lanes() / target_vscale(), VectorTypeConstraint::VScale)); // predicate + arg_types.emplace_back(llvm_type_of(intrin_type.element_of())->getPointerTo()); + } else { + instr << "llvm.aarch64.neon.st" + << num_vecs + << ".v" + << intrin_type.lanes() + << (t.is_float() ? 'f' : 'i') + << t.bits() + << ".p0"; + if (!is_opaque) { + instr << (t.is_float() ? 'f' : 'i') << t.bits(); + } + arg_types = vector(num_vecs + 1, intrin_llvm_type); + arg_types.back() = llvm_type_of(intrin_type.element_of())->getPointerTo(); } - arg_types = vector(num_vecs + 1, intrin_llvm_type); - arg_types.back() = llvm_type_of(intrin_type.element_of())->getPointerTo(); } llvm::FunctionType *fn_type = FunctionType::get(llvm::Type::getVoidTy(*context), arg_types, false); llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type); internal_assert(fn); - // How many vst instructions do we need to generate? - int slices = t.lanes() / intrin_type.lanes(); + // SVE2 supports predication for smaller than whole vector size. + internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes())); - internal_assert(slices >= 1); for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) { Expr slice_base = simplify(ramp->base + i * num_vecs); Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_type.lanes() * num_vecs); @@ -1190,6 +1521,7 @@ void CodeGen_ARM::visit(const Store *op) { // Take a slice of each arg for (int j = 0; j < num_vecs; j++) { slice_args[j] = slice_vector(slice_args[j], i, intrin_type.lanes()); + slice_args[j] = convert_fixed_or_scalable_vector_type(slice_args[j], get_vector_type(slice_args[j]->getType()->getScalarType(), intrin_type.lanes())); } if (target.bits == 32) { @@ -1200,10 +1532,30 @@ void CodeGen_ARM::visit(const Store *op) { // Set the alignment argument slice_args.push_back(ConstantInt::get(i32_t, alignment)); } else { + if (is_sve) { + // Set the predicate argument + auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes()); + Value *vpred_val; + if (is_predicated_store) { + vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes()); + } else { + Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes); + vpred_val = codegen(vpred); + } + slice_args.push_back(vpred_val); + } // Set the pointer argument slice_args.push_back(ptr); } + if (is_sve) { + for (auto &arg : slice_args) { + if (arg->getType()->isVectorTy()) { + arg = match_vector_type_scalable(arg, VectorTypeConstraint::VScale); + } + } + } + CallInst *store = builder->CreateCall(fn, slice_args); add_tbaa_metadata(store, op->name, slice_ramp); } @@ -1216,8 +1568,95 @@ void CodeGen_ARM::visit(const Store *op) { return; } + if (target.has_feature(Target::SVE2)) { + const IntImm *stride = ramp ? 
ramp->stride.as() : nullptr; + if (stride && stride->value == 1) { + // Basically we can deal with vanilla codegen, + // but to avoid LLVM error, process with the multiple of natural_lanes + const int natural_lanes = target.natural_vector_size(op->value.type()); + if (ramp->lanes % natural_lanes) { + int aligned_lanes = align_up(ramp->lanes, natural_lanes); + // Use predicate to prevent overrun + Expr vpred; + if (is_predicated_store) { + vpred = Shuffle::make_concat({op->predicate, const_false(aligned_lanes - ramp->lanes)}); + } else { + vpred = make_vector_predicate_1s_0s(ramp->lanes, aligned_lanes - ramp->lanes); + } + auto aligned_index = Ramp::make(ramp->base, stride, aligned_lanes); + Expr padding = make_zero(op->value.type().with_lanes(aligned_lanes - ramp->lanes)); + Expr aligned_value = Shuffle::make_concat({op->value, padding}); + codegen(Store::make(op->name, aligned_value, aligned_index, op->param, vpred, op->alignment)); + return; + } + } else if (op->index.type().is_vector()) { + // Scatter + Type elt = op->value.type().element_of(); + + // Rewrite float16 case into reinterpret and Store in uint16, as it is unsupported in LLVM + if (is_float16_and_has_feature(elt)) { + Type u16_type = op->value.type().with_code(halide_type_uint); + Expr v = reinterpret(u16_type, op->value); + codegen(Store::make(op->name, v, op->index, op->param, op->predicate, op->alignment)); + return; + } + + const int store_lanes = op->value.type().lanes(); + const int index_bits = 32; + Type type_with_max_bits = Int(std::max(elt.bits(), index_bits)); + // The number of lanes is constrained by index vector type + const int natural_lanes = target.natural_vector_size(type_with_max_bits); + const int vscale_natural_lanes = natural_lanes / target_vscale(); + + Expr base = 0; + Value *elt_ptr = codegen_buffer_pointer(op->name, elt, base); + Value *val = codegen(op->value); + Value *index = codegen(op->index); + Value *store_pred_val = codegen(op->predicate); + + llvm::Type *slice_type = get_vector_type(llvm_type_of(elt), vscale_natural_lanes, VectorTypeConstraint::VScale); + llvm::Type *slice_index_type = get_vector_type(llvm_type_of(op->index.type().element_of()), vscale_natural_lanes, VectorTypeConstraint::VScale); + llvm::Type *pred_type = get_vector_type(llvm_type_of(op->predicate.type().element_of()), vscale_natural_lanes, VectorTypeConstraint::VScale); + + std::ostringstream instr; + instr << "llvm.aarch64.sve.st1.scatter.uxtw." + << (elt.bits() != 8 ? "index." : "") // index is scaled into bytes + << "nxv" + << vscale_natural_lanes + << (elt == Float(32) || elt == Float(64) ? 
'f' : 'i') + << elt.bits(); + + vector arg_types{slice_type, pred_type, elt_ptr->getType(), slice_index_type}; + llvm::FunctionType *fn_type = FunctionType::get(void_t, arg_types, false); + FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type); + + // We need to slice the result into native vector lanes to use intrinsic + for (int i = 0; i < store_lanes; i += natural_lanes) { + Value *slice_value = slice_vector(val, i, natural_lanes); + Value *slice_index = slice_vector(index, i, natural_lanes); + const int active_lanes = std::min(store_lanes - i, natural_lanes); + + Expr vpred = make_vector_predicate_1s_0s(active_lanes, natural_lanes - active_lanes); + Value *vpred_val = codegen(vpred); + vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, pred_type); + if (is_predicated_store) { + Value *sliced_store_vpred_val = slice_vector(store_pred_val, i, natural_lanes); + vpred_val = builder->CreateAnd(vpred_val, sliced_store_vpred_val); + } + + slice_value = match_vector_type_scalable(slice_value, VectorTypeConstraint::VScale); + vpred_val = match_vector_type_scalable(vpred_val, VectorTypeConstraint::VScale); + slice_index = match_vector_type_scalable(slice_index, VectorTypeConstraint::VScale); + CallInst *store = builder->CreateCall(fn, {slice_value, vpred_val, elt_ptr, slice_index}); + add_tbaa_metadata(store, op->name, op->index); + } + + return; + } + } + // If the stride is one or minus one, we can deal with that using vanilla codegen - const IntImm *stride = ramp->stride.as(); + const IntImm *stride = ramp ? ramp->stride.as() : nullptr; if (stride && (stride->value == 1 || stride->value == -1)) { CodeGen_Posix::visit(op); return; @@ -1250,12 +1689,13 @@ void CodeGen_ARM::visit(const Store *op) { void CodeGen_ARM::visit(const Load *op) { // Predicated load - if (!is_const_one(op->predicate)) { + const bool is_predicated_load = !is_const_one(op->predicate); + if (is_predicated_load && !target.has_feature(Target::SVE2)) { CodeGen_Posix::visit(op); return; } - if (neon_intrinsics_disabled()) { + if (simd_intrinsics_disabled()) { CodeGen_Posix::visit(op); return; } @@ -1263,14 +1703,15 @@ void CodeGen_ARM::visit(const Load *op) { const Ramp *ramp = op->index.as(); // We only deal with ramps here - if (!ramp) { + if (!ramp && !target.has_feature(Target::SVE2)) { CodeGen_Posix::visit(op); return; } // If the stride is in [-1, 1], we can deal with that using vanilla codegen const IntImm *stride = ramp ? 
ramp->stride.as() : nullptr; - if (stride && (-1 <= stride->value && stride->value <= 1)) { + if (stride && (-1 <= stride->value && stride->value <= 1) && + !target.has_feature(Target::SVE2)) { CodeGen_Posix::visit(op); return; } @@ -1296,6 +1737,168 @@ void CodeGen_ARM::visit(const Load *op) { } } + if (target.has_feature(Target::SVE2)) { + if (stride && stride->value < 1) { + CodeGen_Posix::visit(op); + return; + } else if (stride && stride->value == 1) { + const int natural_lanes = target.natural_vector_size(op->type); + if (ramp->lanes % natural_lanes) { + // Load with lanes multiple of natural_lanes + int aligned_lanes = align_up(ramp->lanes, natural_lanes); + // Use predicate to prevent from overrun + Expr vpred; + if (is_predicated_load) { + vpred = Shuffle::make_concat({op->predicate, const_false(aligned_lanes - ramp->lanes)}); + } else { + vpred = make_vector_predicate_1s_0s(ramp->lanes, aligned_lanes - ramp->lanes); + } + auto aligned_index = Ramp::make(ramp->base, stride, aligned_lanes); + auto aligned_type = op->type.with_lanes(aligned_lanes); + value = codegen(Load::make(aligned_type, op->name, aligned_index, op->image, op->param, vpred, op->alignment)); + value = slice_vector(value, 0, ramp->lanes); + return; + } else { + CodeGen_Posix::visit(op); + return; + } + } else if (stride && (2 <= stride->value && stride->value <= 4)) { + // Structured load ST2/ST3/ST4 of SVE + + Expr base = ramp->base; + ModulusRemainder align = op->alignment; + + int aligned_stride = gcd(stride->value, align.modulus); + int offset = 0; + if (aligned_stride == stride->value) { + offset = mod_imp((int)align.remainder, aligned_stride); + } else { + const Add *add = base.as(); + if (const IntImm *add_c = add ? add->b.as() : base.as()) { + offset = mod_imp(add_c->value, stride->value); + } + } + + if (offset) { + base = simplify(base - offset); + } + + Value *load_pred_val = codegen(op->predicate); + + // We need to slice the result in to native vector lanes to use sve intrin. + // LLVM will optimize redundant ld instructions afterwards + const int slice_lanes = target.natural_vector_size(op->type); + vector results; + for (int i = 0; i < op->type.lanes(); i += slice_lanes) { + int load_base_i = i * stride->value; + Expr slice_base = simplify(base + load_base_i); + Expr slice_index = Ramp::make(slice_base, stride, slice_lanes); + std::ostringstream instr; + instr << "llvm.aarch64.sve.ld" + << stride->value + << ".sret.nxv" + << slice_lanes + << (op->type.is_float() ? 
'f' : 'i') + << op->type.bits(); + llvm::Type *elt = llvm_type_of(op->type.element_of()); + llvm::Type *slice_type = get_vector_type(elt, slice_lanes); + StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type)); + std::vector arg_types{get_vector_type(i1_t, slice_lanes), PointerType::get(elt, 0)}; + llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false); + FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type); + + // Set the predicate argument + int active_lanes = std::min(op->type.lanes() - i, slice_lanes); + + Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes); + Value *vpred_val = codegen(vpred); + vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes)); + if (is_predicated_load) { + Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes); + vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val); + } + + Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base); + CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr}); + add_tbaa_metadata(load_i, op->name, slice_index); + // extract one element out of returned struct + Value *extracted = builder->CreateExtractValue(load_i, offset); + results.push_back(extracted); + } + + // Retrieve original lanes + value = concat_vectors(results); + value = slice_vector(value, 0, op->type.lanes()); + return; + } else if (op->index.type().is_vector()) { + // General Gather Load + + // Rewrite float16 case into load in uint16 and reinterpret, as it is unsupported in LLVM + if (is_float16_and_has_feature(op->type)) { + Type u16_type = op->type.with_code(halide_type_uint); + Expr equiv = Load::make(u16_type, op->name, op->index, op->image, op->param, op->predicate, op->alignment); + equiv = reinterpret(op->type, equiv); + equiv = common_subexpression_elimination(equiv); + value = codegen(equiv); + return; + } + + Type elt = op->type.element_of(); + const int load_lanes = op->type.lanes(); + const int index_bits = 32; + Type type_with_max_bits = Int(std::max(elt.bits(), index_bits)); + // The number of lanes is constrained by index vector type + const int natural_lanes = target.natural_vector_size(type_with_max_bits); + const int vscale_natural_lanes = natural_lanes / target_vscale(); + + Expr base = 0; + Value *elt_ptr = codegen_buffer_pointer(op->name, elt, base); + Value *index = codegen(op->index); + Value *load_pred_val = codegen(op->predicate); + + llvm::Type *slice_type = get_vector_type(llvm_type_of(elt), vscale_natural_lanes, VectorTypeConstraint::VScale); + llvm::Type *slice_index_type = get_vector_type(llvm_type_of(op->index.type().element_of()), vscale_natural_lanes, VectorTypeConstraint::VScale); + llvm::Type *pred_type = get_vector_type(llvm_type_of(op->predicate.type().element_of()), vscale_natural_lanes, VectorTypeConstraint::VScale); + + std::ostringstream instr; + instr << "llvm.aarch64.sve.ld1.gather.uxtw." + << (elt.bits() != 8 ? "index." : "") // index is scaled into bytes + << "nxv" + << vscale_natural_lanes + << (elt == Float(32) || elt == Float(64) ? 
'f' : 'i') + << elt.bits(); + + llvm::FunctionType *fn_type = FunctionType::get(slice_type, {pred_type, elt_ptr->getType(), slice_index_type}, false); + FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type); + + // We need to slice the result in to native vector lanes to use intrinsic + vector results; + for (int i = 0; i < load_lanes; i += natural_lanes) { + Value *slice_index = slice_vector(index, i, natural_lanes); + + const int active_lanes = std::min(load_lanes - i, natural_lanes); + + Expr vpred = make_vector_predicate_1s_0s(active_lanes, natural_lanes - active_lanes); + Value *vpred_val = codegen(vpred); + if (is_predicated_load) { + Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, natural_lanes); + vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val); + } + + vpred_val = match_vector_type_scalable(vpred_val, VectorTypeConstraint::VScale); + slice_index = match_vector_type_scalable(slice_index, VectorTypeConstraint::VScale); + CallInst *gather = builder->CreateCall(fn, {vpred_val, elt_ptr, slice_index}); + add_tbaa_metadata(gather, op->name, op->index); + results.push_back(gather); + } + + // Retrieve original lanes + value = concat_vectors(results); + value = slice_vector(value, 0, load_lanes); + return; + } + } + CodeGen_Posix::visit(op); } @@ -1322,6 +1925,33 @@ void CodeGen_ARM::visit(const Shuffle *op) { } } +void CodeGen_ARM::visit(const Ramp *op) { + if (target_vscale() != 0 && op->type.is_int_or_uint()) { + if (is_const_zero(op->base) && is_const_one(op->stride)) { + codegen_func_t cg_func = [&](int lanes, const std::vector &args) { + internal_assert(args.empty()); + // Generate stepvector intrinsic for ScalableVector + return builder->CreateStepVector(llvm_type_of(op->type.with_lanes(lanes))); + }; + + // codgen with next-power-of-two lanes, because if we sliced into natural_lanes(e.g. 
4), + // it would produce {0,1,2,3,0,1,..} instead of {0,1,2,3,4,5,..} + const int ret_lanes = op->type.lanes(); + const int aligned_lanes = next_power_of_two(ret_lanes); + value = codegen_with_lanes(aligned_lanes, ret_lanes, {}, cg_func); + return; + } else { + Expr broadcast_base = Broadcast::make(op->base, op->lanes); + Expr broadcast_stride = Broadcast::make(op->stride, op->lanes); + Expr step_ramp = Ramp::make(make_zero(op->base.type()), make_one(op->base.type()), op->lanes); + value = codegen(broadcast_base + broadcast_stride * step_ramp); + return; + } + } + + CodeGen_Posix::visit(op); +} + void CodeGen_ARM::visit(const Call *op) { if (op->is_intrinsic(Call::sorted_avg)) { value = codegen(halving_add(op->args[0], op->args[1])); @@ -1407,7 +2037,6 @@ void CodeGen_ARM::visit(const Call *op) { for (const auto &i : cast_rewrites) { if (expr_match(i.first, op, matches)) { Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes())); - debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n"; value = codegen(replacement); return; } @@ -1464,14 +2093,28 @@ void CodeGen_ARM::visit(const LE *op) { } void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { - if (neon_intrinsics_disabled() || - op->op == VectorReduce::Or || - op->op == VectorReduce::And || - op->op == VectorReduce::Mul) { + if (simd_intrinsics_disabled()) { CodeGen_Posix::codegen_vector_reduce(op, init); return; } + if (codegen_dot_product_vector_reduce(op, init)) { + return; + } + if (codegen_pairwise_vector_reduce(op, init)) { + return; + } + if (codegen_across_vector_reduce(op, init)) { + return; + } + CodeGen_Posix::codegen_vector_reduce(op, init); +} + +bool CodeGen_ARM::codegen_dot_product_vector_reduce(const VectorReduce *op, const Expr &init) { + if (op->op != VectorReduce::Add) { + return false; + } + struct Pattern { VectorReduce::Operator reduce_op; int factor; @@ -1485,11 +2128,23 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_i8x_)), "dot_product", Target::ARMDotProd}, {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::ARMDotProd}, {VectorReduce::Add, 4, u32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::ARMDotProd}, + {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_i8x_)), "dot_product", Target::SVE2}, + {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::SVE2}, + {VectorReduce::Add, 4, u32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Target::SVE2}, + {VectorReduce::Add, 4, i64(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", Target::SVE2}, + {VectorReduce::Add, 4, i64(widening_mul(wild_u16x_, wild_u16x_)), "dot_product", Target::SVE2}, + {VectorReduce::Add, 4, u64(widening_mul(wild_u16x_, wild_u16x_)), "dot_product", Target::SVE2}, // A sum is the same as a dot product with a vector of ones, and this appears to // be a bit faster. 
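+ // (Sketch, not the exact lowering: a sum over groups of four i8s becomes
+ // dot_product(acc, x, splat(1)); the {1} extra operand below supplies the splatted ones.)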
{VectorReduce::Add, 4, i32(wild_i8x_), "dot_product", Target::ARMDotProd, {1}}, {VectorReduce::Add, 4, i32(wild_u8x_), "dot_product", Target::ARMDotProd, {1}}, {VectorReduce::Add, 4, u32(wild_u8x_), "dot_product", Target::ARMDotProd, {1}}, + {VectorReduce::Add, 4, i32(wild_i8x_), "dot_product", Target::SVE2, {1}}, + {VectorReduce::Add, 4, i32(wild_u8x_), "dot_product", Target::SVE2, {1}}, + {VectorReduce::Add, 4, u32(wild_u8x_), "dot_product", Target::SVE2, {1}}, + {VectorReduce::Add, 4, i64(wild_i16x_), "dot_product", Target::SVE2, {1}}, + {VectorReduce::Add, 4, i64(wild_u16x_), "dot_product", Target::SVE2, {1}}, + {VectorReduce::Add, 4, u64(wild_u16x_), "dot_product", Target::SVE2, {1}}, }; // clang-format on @@ -1507,7 +2162,7 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init Expr equiv = VectorReduce::make(op->op, op->value, op->value.type().lanes() / p.factor); equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); codegen_vector_reduce(equiv.as(), init); - return; + return true; } for (int i : p.extra_operands) { @@ -1518,6 +2173,7 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init if (!i.defined()) { i = make_zero(op->type); } + if (const Shuffle *s = matches[0].as()) { if (s->is_broadcast()) { // LLVM wants the broadcast as the second operand for the broadcasting @@ -1525,15 +2181,27 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init std::swap(matches[0], matches[1]); } } - value = call_overloaded_intrin(op->type, p.intrin, {i, matches[0], matches[1]}); - if (value) { - return; + + if (Value *v = call_overloaded_intrin(op->type, p.intrin, {i, matches[0], matches[1]})) { + value = v; + return true; } } } + return false; +} + +bool CodeGen_ARM::codegen_pairwise_vector_reduce(const VectorReduce *op, const Expr &init) { + if (op->op != VectorReduce::Add && + op->op != VectorReduce::Max && + op->op != VectorReduce::Min) { + return false; + } + // TODO: Move this to be patterns? The patterns are pretty trivial, but some // of the other logic is tricky. + int factor = op->value.type().lanes() / op->type.lanes(); const char *intrin = nullptr; vector intrin_args; Expr accumulator = init; @@ -1547,33 +2215,38 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init narrow = lossless_cast(narrow_type.with_code(Type::UInt), op->value); } if (narrow.defined()) { - if (init.defined() && target.bits == 32) { - // On 32-bit, we have an intrinsic for widening add-accumulate. + if (init.defined() && (target.bits == 32 || target.has_feature(Target::SVE2))) { + // On 32-bit or SVE2, we have an intrinsic for widening add-accumulate. // TODO: this could be written as a pattern with widen_right_add (#6951). intrin = "pairwise_widening_add_accumulate"; intrin_args = {accumulator, narrow}; accumulator = Expr(); + } else if (target.has_feature(Target::SVE2)) { + intrin = "pairwise_widening_add_accumulate"; + intrin_args = {Expr(0), narrow}; + accumulator = Expr(); } else { // On 64-bit, LLVM pattern matches widening add-accumulate if // we give it the widening add. 
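+ // (e.g. a NEON saddlp/uaddlp whose result feeds an add is typically folded into sadalp/uadalp.)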
intrin = "pairwise_widening_add"; intrin_args = {narrow}; } - } else { + } else if (!target.has_feature(Target::SVE2)) { + // Exclude SVE, as it process lanes in different order (even/odd wise) than NEON intrin = "pairwise_add"; intrin_args = {op->value}; } - } else if (op->op == VectorReduce::Min && factor == 2) { + } else if (op->op == VectorReduce::Min && factor == 2 && !target.has_feature(Target::SVE2)) { intrin = "pairwise_min"; intrin_args = {op->value}; - } else if (op->op == VectorReduce::Max && factor == 2) { + } else if (op->op == VectorReduce::Max && factor == 2 && !target.has_feature(Target::SVE2)) { intrin = "pairwise_max"; intrin_args = {op->value}; } if (intrin) { - value = call_overloaded_intrin(op->type, intrin, intrin_args); - if (value) { + if (Value *v = call_overloaded_intrin(op->type, intrin, intrin_args)) { + value = v; if (accumulator.defined()) { // We still have an initial value to take care of string n = unique_name('t'); @@ -1595,11 +2268,126 @@ void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init codegen(accumulator); sym_pop(n); } - return; + return true; } } - CodeGen_Posix::codegen_vector_reduce(op, init); + return false; +} + +bool CodeGen_ARM::codegen_across_vector_reduce(const VectorReduce *op, const Expr &init) { + if (target_vscale() == 0) { + // Leave this to vanilla codegen to emit "llvm.vector.reduce." intrinsic, + // which doesn't support scalable vector in LLVM 14 + return false; + } + + if (op->op != VectorReduce::Add && + op->op != VectorReduce::Max && + op->op != VectorReduce::Min) { + return false; + } + + Expr val = op->value; + const int output_lanes = op->type.lanes(); + const int native_lanes = target.natural_vector_size(op->type); + const int input_lanes = val.type().lanes(); + const int input_bits = op->type.bits(); + Type elt = op->type.element_of(); + + if (output_lanes != 1 || input_lanes < 2) { + return false; + } + + Expr (*binop)(Expr, Expr) = nullptr; + std::string op_name; + switch (op->op) { + case VectorReduce::Add: + binop = Add::make; + op_name = "add"; + break; + case VectorReduce::Min: + binop = Min::make; + op_name = "min"; + break; + case VectorReduce::Max: + binop = Max::make; + op_name = "max"; + break; + default: + internal_error << "unreachable"; + } + + if (input_lanes == native_lanes) { + std::stringstream name; // e.g. llvm.aarch64.sve.sminv.nxv4i32 + name << "llvm.aarch64.sve." + << (op->type.is_float() ? "f" : op->type.is_int() ? "s" : + "u") + << op_name << "v" + << ".nxv" << (native_lanes / target_vscale()) << (op->type.is_float() ? "f" : "i") << input_bits; + + // Integer add accumulation output is 64 bit only + const bool type_upgraded = op->op == VectorReduce::Add && op->type.is_int_or_uint(); + const int output_bits = type_upgraded ? 64 : input_bits; + Type intrin_ret_type = op->type.with_bits(output_bits); + + const string intrin_name = name.str(); + + Expr pred = const_true(native_lanes); + vector args{pred, op->value}; + + // Make sure the declaration exists, or the codegen for + // call will assume that the args should scalarize. 
+ if (!module->getFunction(intrin_name)) { + vector arg_types; + for (const Expr &e : args) { + arg_types.push_back(llvm_type_with_constraint(e.type(), false, VectorTypeConstraint::VScale)); + } + FunctionType *func_t = FunctionType::get(llvm_type_with_constraint(intrin_ret_type, false, VectorTypeConstraint::VScale), + arg_types, false); + llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get()); + } + + Expr equiv = Call::make(intrin_ret_type, intrin_name, args, Call::PureExtern); + if (type_upgraded) { + equiv = Cast::make(op->type, equiv); + } + if (init.defined()) { + equiv = binop(init, equiv); + } + equiv = common_subexpression_elimination(equiv); + equiv.accept(this); + return true; + + } else if (input_lanes < native_lanes) { + // Create equivalent where lanes==native_lanes by padding data which doesn't affect the result + Expr padding; + const int inactive_lanes = native_lanes - input_lanes; + + switch (op->op) { + case VectorReduce::Add: + padding = make_zero(elt.with_lanes(inactive_lanes)); + break; + case VectorReduce::Min: + padding = elt.with_lanes(inactive_lanes).min(); + break; + case VectorReduce::Max: + padding = elt.with_lanes(inactive_lanes).max(); + break; + default: + internal_error << "unreachable"; + } + + Expr equiv = VectorReduce::make(op->op, Shuffle::make_concat({val, padding}), 1); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = common_subexpression_elimination(equiv); + equiv.accept(this); + return true; + } + + return false; } Type CodeGen_ARM::upgrade_type_for_arithmetic(const Type &t) const { @@ -1623,6 +2411,39 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const { return CodeGen_Posix::upgrade_type_for_storage(t); } +Value *CodeGen_ARM::codegen_with_lanes(int slice_lanes, int total_lanes, + const std::vector &args, codegen_func_t &cg_func) { + std::vector llvm_args; + // codegen args + for (const auto &arg : args) { + llvm_args.push_back(codegen(arg)); + } + + if (slice_lanes == total_lanes) { + // codegen op + return cg_func(slice_lanes, llvm_args); + } + + std::vector results; + for (int start = 0; start < total_lanes; start += slice_lanes) { + std::vector sliced_args; + for (auto &llvm_arg : llvm_args) { + Value *v = llvm_arg; + if (get_vector_num_elements(llvm_arg->getType()) == total_lanes) { + // Except for scalar arguments, which some ops have, arguments are sliced + v = slice_vector(llvm_arg, start, slice_lanes); + } + sliced_args.push_back(v); + } + // codegen op + value = cg_func(slice_lanes, sliced_args); + results.push_back(value); + } + // Reassemble the results into a vector with total_lanes + value = concat_vectors(results); + return slice_vector(value, 0, total_lanes); +} + string CodeGen_ARM::mcpu_target() const { if (target.bits == 32) { if (target.has_feature(Target::ARMv7s)) { @@ -1635,6 +2456,8 @@ return "cyclone"; } else if (target.os == Target::OSX) { return "apple-a12"; + } else if (target.has_feature(Target::SVE2)) { + return "cortex-x1"; } else { return "generic"; } @@ -1667,6 +2490,7 @@ } } else { // TODO: Should Halide's SVE flags be 64-bit only? + // TODO: Should we add "-neon" if NoNEON is set? Does this make any sense?
if (target.has_feature(Target::SVE2)) { attrs.emplace_back("+sve2"); } else if (target.has_feature(Target::SVE)) { @@ -1689,7 +2513,21 @@ bool CodeGen_ARM::use_soft_float_abi() const { } int CodeGen_ARM::native_vector_bits() const { - return 128; + if (target.has_feature(Target::SVE) || target.has_feature(Target::SVE2)) { + return std::max(target.vector_bits, 128); + } else { + return 128; + } +} + +int CodeGen_ARM::target_vscale() const { + if (target.features_any_of({Target::SVE, Target::SVE2})) { + user_assert(target.vector_bits != 0) << "For SVE/SVE2 support, target_vector_bits= must be set in target.\n"; + user_assert((target.vector_bits % 128) == 0) << "For SVE/SVE2 support, target_vector_bits must be a multiple of 128.\n"; + return target.vector_bits / 128; + } + + return 0; } bool CodeGen_ARM::supports_call_as_float16(const Call *op) const { diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 8922461524c5..1871460569c3 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -657,7 +657,11 @@ void CodeGen_LLVM::end_func(const std::vector &args) { } } - internal_assert(!verifyFunction(*function, &llvm::errs())); + bool valid = !verifyFunction(*function, &llvm::errs()); + if (!valid) { + function->print(dbgs()); + } + internal_assert(valid) << "Generated function does not pass LLVM's verifyFunction.\n"; current_function_args.clear(); } @@ -1348,10 +1352,6 @@ bool is_power_of_two(int x) { return (x & (x - 1)) == 0; } -int next_power_of_two(int x) { - return static_cast(1) << static_cast(std::ceil(std::log2(x))); -} - } // namespace Type CodeGen_LLVM::upgrade_type_for_arithmetic(const Type &t) const { @@ -1449,16 +1449,16 @@ void CodeGen_LLVM::visit(const Cast *op) { } value = codegen(op->value); - llvm::Type *llvm_dst = llvm_type_of(dst); + llvm::Type *llvm_dst = llvm_type_of(dst.element_of()); + if (value->getType()->isVectorTy()) { + llvm_dst = VectorType::get(llvm_dst, dyn_cast(value->getType())->getElementCount()); + } if (dst.is_handle() && src.is_handle()) { value = builder->CreateBitCast(value, llvm_dst); } else if (dst.is_handle() || src.is_handle()) { internal_error << "Can't cast from " << src << " to " << dst << "\n"; } else if (!src.is_float() && !dst.is_float()) { - // Widening integer casts either zero extend or sign extend, - // depending on the source type. Narrowing integer casts - // always truncate. 
value = builder->CreateIntCast(value, llvm_dst, src.is_int()); } else if (src.is_float() && dst.is_int()) { value = builder->CreateFPToSI(value, llvm_dst); @@ -1879,6 +1879,11 @@ void CodeGen_LLVM::visit(const Select *op) { Value *a = codegen(op->true_value); Value *b = codegen(op->false_value); + if (a->getType()->isVectorTy()) { + cmp = match_vector_type_scalable(cmp, a); + b = match_vector_type_scalable(b, a); + } + if (!try_vector_predication_intrinsic("llvm.vp.select", llvm_type_of(op->type), op->type.lanes(), NoMask(), {VPArg(cmp), VPArg(a, 0), VPArg(b)})) { value = builder->CreateSelect(cmp, a, b); @@ -2266,6 +2271,7 @@ void CodeGen_LLVM::codegen_predicated_store(const Store *op) { Value *vpred = codegen(op->predicate); Halide::Type value_type = op->value.type(); Value *val = codegen(op->value); + vpred = match_vector_type_scalable(vpred, value); int alignment = value_type.bytes(); int native_bytes = native_vector_bits() / 8; @@ -2357,7 +2363,6 @@ llvm::Value *CodeGen_LLVM::codegen_vector_load(const Type &type, const std::stri llvm::Value *vpred, bool slice_to_native, llvm::Value *stride) { debug(4) << "Vectorize predicated dense vector load:\n\t" << "(" << type << ")" << name << "[ramp(base, 1, " << type.lanes() << ")]\n"; - int align_bytes = type.bytes(); // The size of a single element int native_bits = native_vector_bits(); @@ -2402,7 +2407,7 @@ llvm::Value *CodeGen_LLVM::codegen_vector_load(const Type &type, const std::stri Value *elt_ptr = codegen_buffer_pointer(name, type.element_of(), slice_base); Value *vec_ptr = builder->CreatePointerCast(elt_ptr, slice_type->getPointerTo()); - Value *slice_mask = (vpred != nullptr) ? slice_vector(vpred, i, slice_lanes) : nullptr; + Value *slice_mask = (vpred != nullptr) ? match_vector_type_scalable(slice_vector(vpred, i, slice_lanes), slice_type) : nullptr; MaskVariant vp_slice_mask = slice_mask ? MaskVariant(slice_mask) : AllEnabledMask(); Instruction *load_inst = nullptr; @@ -3304,6 +3309,8 @@ void CodeGen_LLVM::visit(const Call *op) { value = codegen(lower_extract_bits(op)); } else if (op->is_intrinsic(Call::concat_bits)) { value = codegen(lower_concat_bits(op)); + } else if (op->is_intrinsic(Call::get_runtime_vscale)) { + value = builder->CreateVScale(ConstantInt::get(i32_t, 1)); } else if (op->is_intrinsic()) { Expr lowered = lower_intrinsic(op); if (!lowered.defined()) { @@ -3478,6 +3485,11 @@ void CodeGen_LLVM::visit(const Call *op) { << halide_arg << "\n"; args[i] = builder->CreatePointerCast(args[i], t); } + } else if (args[i]->getType()->isVectorTy()) { + llvm::Type *t = func_t->getParamType(i); + if (t->isVectorTy()) { + args[i] = match_vector_type_scalable(args[i], t); + } } } } @@ -4274,14 +4286,14 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini break; case VectorReduce::Min: name = "fmin"; - // TODO(zvookin): Not correct for stricT_float. See: https://github.com/halide/Halide/issues/7118 + // TODO(zvookin): Not correct for strict_float. See: https://github.com/halide/Halide/issues/7118 if (takes_initial_value && !initial_value.defined()) { initial_value = op->type.max(); } break; case VectorReduce::Max: name = "fmax"; - // TODO(zvookin): Not correct for stricT_float. See: https://github.com/halide/Halide/issues/7118 + // TODO(zvookin): Not correct for strict_float. 
See: https://github.com/halide/Halide/issues/7118 if (takes_initial_value && !initial_value.defined()) { initial_value = op->type.min(); } @@ -4752,16 +4764,45 @@ Value *CodeGen_LLVM::call_intrin(const llvm::Type *result_type, int intrin_lanes llvm::FunctionType *intrin_type = intrin->getFunctionType(); for (int i = 0; i < (int)arg_values.size(); i++) { - if (arg_values[i]->getType() != intrin_type->getParamType(i)) { - // TODO: Change this to call convert_fixed_or_scalable_vector_type and - // remove normalize_fixed_scalable_vector_type, fixed_to_scalable_vector_type, - // and scalable_to_fixed_vector_type - arg_values[i] = normalize_fixed_scalable_vector_type(intrin_type->getParamType(i), arg_values[i]); - } - if (arg_values[i]->getType() != intrin_type->getParamType(i)) { - // There can be some mismatches in types, such as when passing scalar Halide type T - // to LLVM vector type <1 x T>. - arg_values[i] = builder->CreateBitCast(arg_values[i], intrin_type->getParamType(i)); + llvm::Type *arg_type = arg_values[i]->getType(); + llvm::Type *formal_param_type = intrin_type->getParamType(i); + if (arg_type != formal_param_type) { + bool both_vectors = isa(arg_type) && isa(formal_param_type); + bool arg_is_fixed = isa(arg_type); + bool formal_is_fixed = isa(formal_param_type); + + // Apparently the bitcast in the else branch below can + // change the scalar type and vector length together so + // long as the total bits are the same. E.g. on HVX, + // <128 x i16> to <64 x i32>. This is probably a bug, but + // it seems to be allowed so it is also supported in the + // fixed/vscale matching path. + if (both_vectors && (arg_is_fixed != formal_is_fixed) && (effective_vscale != 0)) { + bool scalar_types_match = arg_type->getScalarType() == formal_param_type->getScalarType(); + if (arg_is_fixed && !scalar_types_match) { + unsigned fixed_count = dyn_cast(formal_param_type)->getElementCount().getKnownMinValue() * effective_vscale; + llvm::Type *match_scalar_type = llvm::VectorType::get(formal_param_type->getScalarType(), fixed_count, false); + arg_values[i] = builder->CreateBitCast(arg_values[i], match_scalar_type); + } + llvm::ElementCount ec = dyn_cast(arg_values[i]->getType())->getElementCount(); + int mid_count = formal_is_fixed ? (ec.getKnownMinValue() * effective_vscale) : (ec.getFixedValue() / effective_vscale); + llvm::Type *match_vector_flavor_type = llvm::VectorType::get(arg_values[i]->getType()->getScalarType(), mid_count, !formal_is_fixed); + arg_values[i] = convert_fixed_or_scalable_vector_type(arg_values[i], match_vector_flavor_type); + if (formal_is_fixed && !scalar_types_match) { + arg_values[i] = builder->CreateBitCast(arg_values[i], formal_param_type); + } + } else { + // TODO(https://github.com/halide/Halide/issues/8117): That this + // can happen is probably a bug. It will crash in module + // validation for anything LLVM doesn't support. Better to + // regularize the Halide IR by inserting an intentional cast or + // to add extra intrinsics patterns. At the very least, some + // extra validation should be added here. + + // There can be some mismatches in types, such as when passing + // scalar Halide type T to LLVM vector type <1 x T>. 
+ arg_values[i] = builder->CreateBitCast(arg_values[i], formal_param_type); + } } } @@ -4785,16 +4826,45 @@ Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) { return builder->CreateExtractElement(vec, (uint64_t)start); } - vector indices(size); - for (int i = 0; i < size; i++) { - int idx = start + i; - if (idx >= 0 && idx < vec_lanes) { - indices[i] = idx; - } else { - indices[i] = -1; + bool is_fixed = isa(vec->getType()); + + // TODO(https://github.com/halide/Halide/issues/8118): It is likely worth + // looking into using llvm.vector.{extract,insert} for this case + // too. However that would need to be validated performance wise for all + // architectures. + if (is_fixed) { + vector indices(size); + for (int i = 0; i < size; i++) { + int idx = start + i; + if (idx >= 0 && idx < vec_lanes) { + indices[i] = idx; + } else { + indices[i] = -1; + } } + return shuffle_vectors(vec, indices); + } else { + // Extract a fixed vector with all the values in the source. + // Then insert back into a vector extended to size. This will + // be a scalable vector if size can be scalable, fixed + // otherwise. + llvm::Type *scalar_type = vec->getType()->getScalarType(); + + int intermediate_lanes = std::min(size, vec_lanes - start); + llvm::Type *intermediate_type = get_vector_type(scalar_type, intermediate_lanes, VectorTypeConstraint::Fixed); + + vec = builder->CreateExtractVector(intermediate_type, vec, ConstantInt::get(i64_t, start)); + + // Insert vector into a poison vector and return. + int effective_size = is_fixed ? size : (size / effective_vscale); + llvm::VectorType *result_type = dyn_cast(get_vector_type(scalar_type, effective_size, + is_fixed ? VectorTypeConstraint::Fixed : VectorTypeConstraint::VScale)); + Constant *poison = PoisonValue::get(scalar_type); + llvm::Value *result_vec = ConstantVector::getSplat(result_type->getElementCount(), poison); + vec = builder->CreateInsertVector(result_type, result_vec, vec, ConstantInt::get(i64_t, 0)); + + return vec; } - return shuffle_vectors(vec, indices); } Value *CodeGen_LLVM::concat_vectors(const vector &v) { @@ -4831,6 +4901,11 @@ Value *CodeGen_LLVM::concat_vectors(const vector &v) { } int w_matched = std::max(w1, w2); + if (v1->getType() != v2->getType()) { + // arbitrary decision here to convert v2 to type of v1 rather than + // target fixed or scalable. + v2 = convert_fixed_or_scalable_vector_type(v2, v1->getType()); + } internal_assert(v1->getType() == v2->getType()); vector indices(w1 + w2); @@ -4903,8 +4978,11 @@ std::pair CodeGen_LLVM::find_vector_runtime_function(cons while (l < lanes) { l *= 2; } - for (int i = l; i > 1; i /= 2) { - sizes_to_try.push_back(i); + + // This will be 1 for non-vscale architectures. + int vscale_divisor = std::max(effective_vscale, 1); + for (int i = l; i > vscale_divisor; i /= 2) { + sizes_to_try.push_back(i / vscale_divisor); } // If none of those match, we'll also try doubling @@ -4913,10 +4991,11 @@ std::pair CodeGen_LLVM::find_vector_runtime_function(cons // vector implementation). sizes_to_try.push_back(l * 2); + std::string vec_prefix = effective_vscale != 0 ? 
"nx" : "x"; for (int l : sizes_to_try) { - llvm::Function *vec_fn = module->getFunction(name + "x" + std::to_string(l)); + llvm::Function *vec_fn = module->getFunction(name + vec_prefix + std::to_string(l)); if (vec_fn) { - return {vec_fn, l}; + return {vec_fn, l * vscale_divisor}; } } @@ -4982,6 +5061,42 @@ llvm::Value *CodeGen_LLVM::normalize_fixed_scalable_vector_type(llvm::Type *desi return result; } +llvm::Value *CodeGen_LLVM::match_vector_type_scalable(llvm::Value *value, VectorTypeConstraint constraint) { + if (constraint == VectorTypeConstraint::None) { + return value; + } + + llvm::Type *value_type = value->getType(); + if (!isa(value_type)) { + return value; + } + + bool value_fixed = isa(value_type); + bool guide_fixed = (constraint == VectorTypeConstraint::Fixed); + if (value_fixed != guide_fixed) { + int value_scaled_elements = get_vector_num_elements(value_type); + if (!guide_fixed) { + value_scaled_elements /= effective_vscale; + } + llvm::Type *desired_type = get_vector_type(value_type->getScalarType(), value_scaled_elements, + guide_fixed ? VectorTypeConstraint::Fixed : VectorTypeConstraint::VScale); + value = convert_fixed_or_scalable_vector_type(value, desired_type); + } + + return value; +} + +llvm::Value *CodeGen_LLVM::match_vector_type_scalable(llvm::Value *value, llvm::Type *guide_type) { + if (!isa(guide_type)) { + return value; + } + return match_vector_type_scalable(value, isa(guide_type) ? VectorTypeConstraint::Fixed : VectorTypeConstraint::VScale); +} + +llvm::Value *CodeGen_LLVM::match_vector_type_scalable(llvm::Value *value, llvm::Value *guide) { + return match_vector_type_scalable(value, guide->getType()); +} + llvm::Value *CodeGen_LLVM::convert_fixed_or_scalable_vector_type(llvm::Value *arg, llvm::Type *desired_type) { llvm::Type *arg_type = arg->getType(); @@ -5007,13 +5122,21 @@ llvm::Value *CodeGen_LLVM::convert_fixed_or_scalable_vector_type(llvm::Value *ar if (isa(arg_type) && isa(result_type)) { use_insert = true; + if (arg_elements > result_elements) { + arg = slice_vector(arg, 0, result_elements); + } + arg_elements = result_elements; } else if (isa(result_type) && isa(arg_type)) { use_insert = false; + if (arg_elements < result_elements) { + arg = slice_vector(arg, 0, result_elements); + } + arg_elements = result_elements; } else { // Use extract to make smaller, insert to make bigger. // A somewhat arbitary decision. - use_insert = (arg_elements > result_elements); + use_insert = (arg_elements < result_elements); } std::string intrin_name = "llvm.vector."; @@ -5165,10 +5288,27 @@ llvm::Type *CodeGen_LLVM::get_vector_type(llvm::Type *t, int n, bool scalable = false; switch (type_constraint) { case VectorTypeConstraint::None: - scalable = effective_vscale != 0 && - ((n % effective_vscale) == 0); - if (scalable) { - n = n / effective_vscale; + if (effective_vscale > 0) { + bool wide_enough = true; + // TODO(https://github.com/halide/Halide/issues/8119): Architecture + // specific code should not go here. Ideally part of this can go + // away via LLVM fixes and modifying intrinsic selection to handle + // scalable vs. fixed vectors. Making this method virtual is + // possibly expensive. + if (target.arch == Target::ARM) { + if (!target.has_feature(Target::NoNEON)) { + // force booleans into bytes. TODO(https://github.com/halide/Halide/issues/8119): figure out a better way to do this. 
+ int bit_size = std::max((int)t->getScalarSizeInBits(), 8); + wide_enough = (bit_size * n) > 128; + } else { + // TODO(https://github.com/halide/Halide/issues/8119): AArch64 SVE2 support is crashy with scalable vectors of min size 1. + wide_enough = (n / effective_vscale) > 1; + } + } + scalable = wide_enough && ((n % effective_vscale) == 0); + if (scalable) { + n = n / effective_vscale; + } } break; case VectorTypeConstraint::Fixed: @@ -5190,10 +5330,12 @@ llvm::Constant *CodeGen_LLVM::get_splat(int lanes, llvm::Constant *value, bool scalable = false; switch (type_constraint) { case VectorTypeConstraint::None: - scalable = effective_vscale != 0 && - ((lanes % effective_vscale) == 0); - if (scalable) { - lanes = lanes / effective_vscale; + if (effective_vscale > 0) { + bool wide_enough = (lanes / effective_vscale) > 1; + scalable = wide_enough && ((lanes % effective_vscale) == 0); + if (scalable) { + lanes = lanes / effective_vscale; + } } break; case VectorTypeConstraint::Fixed: diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index b3e9cdabd498..908929e54373 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -579,6 +579,13 @@ class CodeGen_LLVM : public IRVisitor { llvm::Constant *get_splat(int lanes, llvm::Constant *value, VectorTypeConstraint type_constraint = VectorTypeConstraint::None) const; + /** Make sure a value type has the same scalable/fixed vector type as a guide. */ + // @{ + llvm::Value *match_vector_type_scalable(llvm::Value *value, VectorTypeConstraint constraint); + llvm::Value *match_vector_type_scalable(llvm::Value *value, llvm::Type *guide); + llvm::Value *match_vector_type_scalable(llvm::Value *value, llvm::Value *guide); + // @} + /** Support for generating LLVM vector predication intrinsics * ("@llvm.vp.*" and "@llvm.experimental.vp.*") */ diff --git a/src/Function.cpp b/src/Function.cpp index cbb4b61574d4..b72a39e1c90a 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -491,8 +491,10 @@ ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgumen } // namespace void Function::deep_copy(const FunctionPtr ©, DeepCopyMap &copied_map) const { - internal_assert(copy.defined() && contents.defined()) - << "Cannot deep-copy undefined Function\n"; + internal_assert(copy.defined()) + << "Cannot deep-copy to undefined Function\n"; + internal_assert(contents.defined()) + << "Cannot deep-copy from undefined Function\n"; // Add reference to this Function's deep-copy to the map in case of // self-reference, e.g. self-reference in an Definition. diff --git a/src/IR.cpp b/src/IR.cpp index c0bdb718291d..81cf0a0f41ff 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -690,6 +690,7 @@ const char *const intrinsic_op_names[] = { "widening_shift_left", "widening_shift_right", "widening_sub", + "get_runtime_vscale", }; static_assert(sizeof(intrinsic_op_names) / sizeof(intrinsic_op_names[0]) == Call::IntrinsicOpCount, diff --git a/src/IR.h b/src/IR.h index 252e4588db03..31aa3f195e43 100644 --- a/src/IR.h +++ b/src/IR.h @@ -629,6 +629,8 @@ struct Call : public ExprNode { widening_shift_right, widening_sub, + get_runtime_vscale, + IntrinsicOpCount // Sentinel: keep last. 
}; diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index 3e5d95d787e6..10521f82ac03 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -262,6 +262,9 @@ class IRMatch : public IRVisitor { if (result && e && types_match(op->type, e->type)) { expr = e->value; op->value.accept(this); + } else if (op->lanes == 0 && types_match(op->value.type(), expr.type())) { + // zero lanes means any number of lanes, so match scalars too. + op->value.accept(this); } else { result = false; } diff --git a/src/LLVM_Output.cpp b/src/LLVM_Output.cpp index 6b54aeef0e97..e40441b388f0 100644 --- a/src/LLVM_Output.cpp +++ b/src/LLVM_Output.cpp @@ -331,6 +331,12 @@ std::unique_ptr clone_module(const llvm::Module &module_in) { // Read it back in. llvm::MemoryBufferRef buffer_ref(llvm::StringRef(clone_buffer.data(), clone_buffer.size()), "clone_buffer"); auto cloned_module = llvm::parseBitcodeFile(buffer_ref, module_in.getContext()); + + // TODO(): Add support for returning the error. + if (!cloned_module) { + llvm::dbgs() << cloned_module.takeError(); + module_in.print(llvm::dbgs(), nullptr, false, true); + } internal_assert(cloned_module); return std::move(cloned_module.get()); diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index fd7a12d66995..a207b3ce63f5 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -10,6 +10,7 @@ #include "Monotonic.h" #include "Simplify.h" #include "Substitute.h" +#include "Util.h" #include namespace Halide { @@ -17,10 +18,6 @@ namespace Internal { namespace { -int64_t next_power_of_two(int64_t x) { - return static_cast(1) << static_cast(std::ceil(std::log2(x))); -} - using std::map; using std::string; using std::vector; diff --git a/src/Util.h b/src/Util.h index 15c297796911..bce0a7f1d015 100644 --- a/src/Util.h +++ b/src/Util.h @@ -13,6 +13,7 @@ /** \file * Various utility functions used internally Halide. */ +#include #include #include #include @@ -532,6 +533,16 @@ int clz64(uint64_t x); int ctz64(uint64_t x); // @} +/** Return an integer 2^n, for some n, which is >= x. Argument x must be > 0. 
*/ +inline int64_t next_power_of_two(int64_t x) { + return static_cast(1) << static_cast(std::ceil(std::log2(x))); +} + +template +inline T align_up(T x, int n) { + return (x + n - 1) / n * n; +} + } // namespace Internal } // namespace Halide diff --git a/src/WasmExecutor.cpp b/src/WasmExecutor.cpp index b99efdc6d67e..bfe66213f44f 100644 --- a/src/WasmExecutor.cpp +++ b/src/WasmExecutor.cpp @@ -101,11 +101,6 @@ struct debug_sink { // BDMalloc // --------------------- -template -inline T align_up(T p, int alignment = 32) { - return (p + alignment - 1) & ~(alignment - 1); -} - // Debugging our Malloc is extremely noisy and usually undesired #define BDMALLOC_DEBUG_LEVEL 0 @@ -318,7 +313,7 @@ std::vector compile_to_wasm(const Module &module, const std::string &fn_na stack_size += cg->get_requested_alloca_total(); } - stack_size = align_up(stack_size); + stack_size = align_up(stack_size, 32); wdebug(1) << "Requesting stack size of " << stack_size << "\n"; std::unique_ptr llvm_module = @@ -708,7 +703,7 @@ wasm32_ptr_t hostbuf_to_wasmbuf(WabtContext &wabt_context, const halide_buffer_t const size_t dims_size_in_bytes = sizeof(halide_dimension_t) * src->dimensions; const size_t dims_offset = sizeof(wasm_halide_buffer_t); const size_t mem_needed_base = sizeof(wasm_halide_buffer_t) + dims_size_in_bytes; - const size_t host_offset = align_up(mem_needed_base); + const size_t host_offset = align_up(mem_needed_base, 32); const size_t host_size_in_bytes = src->size_in_bytes(); const size_t mem_needed = host_offset + host_size_in_bytes; @@ -1613,7 +1608,7 @@ wasm32_ptr_t hostbuf_to_wasmbuf(const Local &context, const halide_buff const size_t dims_size_in_bytes = sizeof(halide_dimension_t) * src->dimensions; const size_t dims_offset = sizeof(wasm_halide_buffer_t); const size_t mem_needed_base = sizeof(wasm_halide_buffer_t) + dims_size_in_bytes; - const size_t host_offset = align_up(mem_needed_base); + const size_t host_offset = align_up(mem_needed_base, 32); const size_t host_size_in_bytes = src->size_in_bytes(); const size_t mem_needed = host_offset + host_size_in_bytes; diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 1a19202745bb..1d0843be0329 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -1246,6 +1246,10 @@ enum halide_error_code_t { /** A factor used to split a loop was discovered to be zero or negative at * runtime. */ halide_error_code_split_factor_not_positive = -46, + + /** "vscale" value of Scalable Vector detected in runtime does not match + * the vscale value used in compilation. */ + halide_error_code_vscale_invalid = -47, }; /** Halide calls the functions below on various error conditions. The @@ -1321,7 +1325,7 @@ extern int halide_error_storage_bound_too_small(void *user_context, const char * int provided_size, int required_size); extern int halide_error_device_crop_failed(void *user_context); extern int halide_error_split_factor_not_positive(void *user_context, const char *func_name, const char *orig, const char *outer, const char *inner, const char *factor_str, int factor); - +extern int halide_error_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale); // @} /** Optional features a compilation Target can have. 
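The new error code pairs with the Call::get_runtime_vscale intrinsic added above (lowered via builder->CreateVScale): a pipeline compiled for a fixed vector_bits can compare the vscale it assumed at compile time (target.vector_bits / 128) with the value the hardware reports and fail cleanly through the new handler. A minimal sketch of that guard, assuming the hypothetical names check_vscale and "pipeline_name"; the real check is presumably emitted by the code generator rather than written by hand:

    // Sketch only: runtime_vscale is assumed to come from the
    // get_runtime_vscale intrinsic; compiletime_vscale is target.vector_bits / 128.
    int check_vscale(void *user_context, int runtime_vscale, int compiletime_vscale) {
        if (compiletime_vscale != 0 && runtime_vscale != compiletime_vscale) {
            // halide_error_vscale_invalid reports both values and returns
            // halide_error_code_vscale_invalid (-47).
            return halide_error_vscale_invalid(user_context, "pipeline_name",
                                               runtime_vscale, compiletime_vscale);
        }
        return halide_error_code_success;
    }
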
diff --git a/src/runtime/aarch64.ll b/src/runtime/aarch64.ll index 9ae3b8e46ac2..c68a4f05fb42 100644 --- a/src/runtime/aarch64.ll +++ b/src/runtime/aarch64.ll @@ -48,25 +48,34 @@ define weak_odr <2 x i64> @vabdl_u32x2(<2 x i32> %a, <2 x i32> %b) nounwind alwa declare <4 x float> @llvm.aarch64.neon.frecpe.v4f32(<4 x float> %x) nounwind readnone; declare <2 x float> @llvm.aarch64.neon.frecpe.v2f32(<2 x float> %x) nounwind readnone; +declare float @llvm.aarch64.neon.frecpe.f32(float) declare <4 x float> @llvm.aarch64.neon.frsqrte.v4f32(<4 x float> %x) nounwind readnone; declare <2 x float> @llvm.aarch64.neon.frsqrte.v2f32(<2 x float> %x) nounwind readnone; +declare float @llvm.aarch64.neon.frsqrte.f32(float) declare <4 x float> @llvm.aarch64.neon.frecps.v4f32(<4 x float> %x, <4 x float> %y) nounwind readnone; declare <2 x float> @llvm.aarch64.neon.frecps.v2f32(<2 x float> %x, <2 x float> %y) nounwind readnone; +declare float @llvm.aarch64.neon.frecps.f32(float, float) declare <4 x float> @llvm.aarch64.neon.frsqrts.v4f32(<4 x float> %x, <4 x float> %y) nounwind readnone; declare <2 x float> @llvm.aarch64.neon.frsqrts.v2f32(<2 x float> %x, <2 x float> %y) nounwind readnone; +declare float @llvm.aarch64.neon.frsqrts.f32(float, float) + declare <8 x half> @llvm.aarch64.neon.frecpe.v8f16(<8 x half> %x) nounwind readnone; declare <4 x half> @llvm.aarch64.neon.frecpe.v4f16(<4 x half> %x) nounwind readnone; +declare half @llvm.aarch64.neon.frecpe.f16(half) declare <8 x half> @llvm.aarch64.neon.frsqrte.v8f16(<8 x half> %x) nounwind readnone; declare <4 x half> @llvm.aarch64.neon.frsqrte.v4f16(<4 x half> %x) nounwind readnone; +declare half @llvm.aarch64.neon.frsqrte.f16(half) declare <8 x half> @llvm.aarch64.neon.frecps.v8f16(<8 x half> %x, <8 x half> %y) nounwind readnone; declare <4 x half> @llvm.aarch64.neon.frecps.v4f16(<4 x half> %x, <4 x half> %y) nounwind readnone; +declare half @llvm.aarch64.neon.frecps.f16(half, half) declare <8 x half> @llvm.aarch64.neon.frsqrts.v8f16(<8 x half> %x, <8 x half> %y) nounwind readnone; declare <4 x half> @llvm.aarch64.neon.frsqrts.v4f16(<4 x half> %x, <4 x half> %y) nounwind readnone; +declare half @llvm.aarch64.neon.frsqrts.f16(half, half) define weak_odr float @fast_inverse_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> poison, float %x, i32 0 - %approx = tail call <2 x float> @fast_inverse_f32x2(<2 x float> %vec) - %result = extractelement <2 x float> %approx, i32 0 + %approx = tail call float @llvm.aarch64.neon.frecpe.f32(float %x) + %correction = tail call float @llvm.aarch64.neon.frecps.f32(float %approx, float %x) + %result = fmul float %approx, %correction ret float %result } @@ -85,9 +94,9 @@ define weak_odr <4 x float> @fast_inverse_f32x4(<4 x float> %x) nounwind alwaysi } define weak_odr half @fast_inverse_f16(half %x) nounwind alwaysinline { - %vec = insertelement <4 x half> poison, half %x, i32 0 - %approx = tail call <4 x half> @fast_inverse_f16x4(<4 x half> %vec) - %result = extractelement <4 x half> %approx, i32 0 + %approx = tail call half @llvm.aarch64.neon.frecpe.f16(half %x) + %correction = tail call half @llvm.aarch64.neon.frecps.f16(half %approx, half %x) + %result = fmul half %approx, %correction ret half %result } @@ -106,9 +115,10 @@ define weak_odr <8 x half> @fast_inverse_f16x8(<8 x half> %x) nounwind alwaysinl } define weak_odr float @fast_inverse_sqrt_f32(float %x) nounwind alwaysinline { - %vec = insertelement <2 x float> poison, float %x, i32 0 - %approx = tail call <2 x float> @fast_inverse_sqrt_f32x2(<2 
x float> %vec) - %result = extractelement <2 x float> %approx, i32 0 + %approx = tail call float @llvm.aarch64.neon.frsqrte.f32(float %x) + %approx2 = fmul float %approx, %approx + %correction = tail call float @llvm.aarch64.neon.frsqrts.f32(float %approx2, float %x) + %result = fmul float %approx, %correction ret float %result } @@ -129,9 +139,10 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al } define weak_odr half @fast_inverse_sqrt_f16(half %x) nounwind alwaysinline { - %vec = insertelement <4 x half> poison, half %x, i32 0 - %approx = tail call <4 x half> @fast_inverse_sqrt_f16x4(<4 x half> %vec) - %result = extractelement <4 x half> %approx, i32 0 + %approx = tail call half @llvm.aarch64.neon.frsqrte.f16(half %x) + %approx2 = fmul half %approx, %approx + %correction = tail call half @llvm.aarch64.neon.frsqrts.f16(half %approx2, half %x) + %result = fmul half %approx, %correction ret half %result } @@ -149,4 +160,43 @@ define weak_odr <8 x half> @fast_inverse_sqrt_f16x8(<8 x half> %x) nounwind alwa %correction = tail call <8 x half> @llvm.aarch64.neon.frsqrts.v8f16(<8 x half> %approx2, <8 x half> %x) %result = fmul <8 x half> %approx, %correction ret <8 x half> %result -} \ No newline at end of file +} + +declare @llvm.aarch64.sve.frecpe.x.nxv4f32( %x) nounwind readnone; +declare @llvm.aarch64.sve.frsqrte.x.nxv4f32( %x) nounwind readnone; +declare @llvm.aarch64.sve.frecps.x.nxv4f32( %x, %y) nounwind readnone; +declare @llvm.aarch64.sve.frsqrts.x.nxv4f32( %x, %y) nounwind readnone; +declare @llvm.aarch64.sve.frecpe.x.nxv8f16( %x) nounwind readnone; +declare @llvm.aarch64.sve.frsqrte.x.nxv8f16( %x) nounwind readnone; +declare @llvm.aarch64.sve.frecps.x.nxv8f16( %x, %y) nounwind readnone; +declare @llvm.aarch64.sve.frsqrts.x.nxv8f16( %x, %y) nounwind readnone; + +define weak_odr @fast_inverse_f32nx4( %x) nounwind alwaysinline { + %approx = tail call @llvm.aarch64.sve.frecpe.x.nxv4f32( %x) + %correction = tail call @llvm.aarch64.sve.frecps.x.nxv4f32( %approx, %x) + %result = fmul %approx, %correction + ret %result +} + +define weak_odr @fast_inverse_f16nx8( %x) nounwind alwaysinline { + %approx = tail call @llvm.aarch64.sve.frecpe.x.nxv8f16( %x) + %correction = tail call @llvm.aarch64.sve.frecps.x.nxv8f16( %approx, %x) + %result = fmul %approx, %correction + ret %result +} + +define weak_odr @fast_inverse_sqrt_f32nx4( %x) nounwind alwaysinline { + %approx = tail call @llvm.aarch64.sve.frsqrte.x.nxv4f32( %x) + %approx2 = fmul %approx, %approx + %correction = tail call @llvm.aarch64.sve.frsqrts.x.nxv4f32( %approx2, %x) + %result = fmul %approx, %correction + ret %result +} + +define weak_odr @fast_inverse_sqrt_f16nx8( %x) nounwind alwaysinline { + %approx = tail call @llvm.aarch64.sve.frsqrte.x.nxv8f16( %x) + %approx2 = fmul %approx, %approx + %correction = tail call @llvm.aarch64.sve.frsqrts.x.nxv8f16( %approx2, %x) + %result = fmul %approx, %correction + ret %result +} diff --git a/src/runtime/errors.cpp b/src/runtime/errors.cpp index 0879cc4a7c60..acb640c44b52 100644 --- a/src/runtime/errors.cpp +++ b/src/runtime/errors.cpp @@ -300,4 +300,12 @@ WEAK int halide_error_split_factor_not_positive(void *user_context, const char * return halide_error_code_split_factor_not_positive; } +WEAK int halide_error_vscale_invalid(void *user_context, const char *func_name, int runtime_vscale, int compiletime_vscale) { + error(user_context) + << "The function " << func_name + << " is compiled with the assumption that vscale of Scalable Vector is " << compiletime_vscale 
+ << ". However, the detected runtime vscale is " << runtime_vscale << "."; + return halide_error_code_vscale_invalid; +} + } // extern "C" diff --git a/src/runtime/posix_math.ll b/src/runtime/posix_math.ll index 236652279615..ee6c2571f4eb 100644 --- a/src/runtime/posix_math.ll +++ b/src/runtime/posix_math.ll @@ -322,4 +322,30 @@ define weak_odr double @neg_inf_f64() nounwind uwtable readnone alwaysinline { define weak_odr double @nan_f64() nounwind uwtable readnone alwaysinline { ret double 0x7FF8000000000000 -} \ No newline at end of file +} + +; In case scalable vector with un-natural vector size, LLVM doesn't auto-vectorize the above scalar version +define weak_odr @inf_f32nx4() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, float 0x7FF0000000000000, i32 0), undef, zeroinitializer) +} + +define weak_odr @neg_inf_f32nx4() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, float 0xFFF0000000000000, i32 0), undef, zeroinitializer) +} + +define weak_odr @nan_f32nx4() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, float 0x7FF8000000000000, i32 0), undef, zeroinitializer) +} + + +define weak_odr @inf_f64nx2() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, double 0x7FF0000000000000, i32 0), undef, zeroinitializer) +} + +define weak_odr @neg_inf_f64nx2() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, double 0xFFF0000000000000, i32 0), undef, zeroinitializer) +} + +define weak_odr @nan_f64nx2() nounwind uwtable readnone alwaysinline { + ret shufflevector ( insertelement ( undef, double 0x7FF8000000000000, i32 0), undef, zeroinitializer) +} diff --git a/src/runtime/runtime_api.cpp b/src/runtime/runtime_api.cpp index db8ada2f4b8e..7955e8749df7 100644 --- a/src/runtime/runtime_api.cpp +++ b/src/runtime/runtime_api.cpp @@ -89,6 +89,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = { (void *)&halide_error_unaligned_host_ptr, (void *)&halide_error_storage_bound_too_small, (void *)&halide_error_device_crop_failed, + (void *)&halide_error_vscale_invalid, (void *)&halide_float16_bits_to_double, (void *)&halide_float16_bits_to_float, (void *)&halide_free, diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 9b934b768cdd..604ceda468f5 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -277,6 +277,7 @@ tests(GROUPS correctness simd_op_check_hvx.cpp simd_op_check_powerpc.cpp simd_op_check_riscv.cpp + simd_op_check_sve2.cpp simd_op_check_wasm.cpp simd_op_check_x86.cpp simplified_away_embedded_image.cpp diff --git a/test/correctness/simd_op_check_arm.cpp b/test/correctness/simd_op_check_arm.cpp index e8762a6ea2d8..3ebf5071569e 100644 --- a/test/correctness/simd_op_check_arm.cpp +++ b/test/correctness/simd_op_check_arm.cpp @@ -230,6 +230,13 @@ class SimdOpCheckARM : public SimdOpCheckTest { check(arm32 ? "vcvt.s32.f32" : "fcvtzs", 2 * w, i32(f32_1)); // skip the fixed point conversions for now + if (!arm32) { + check("fcvtmu *v", 2 * w, u32(floor(f32_1))); + check("fcvtpu *v", 2 * w, u32(ceil(f32_1))); + check("fcvtms *v", 2 * w, i32(floor(f32_1))); + check("fcvtps *v", 2 * w, i32(ceil(f32_1))); + } + // VDIV - F, D Divide // This doesn't actually get vectorized in 32-bit. Not sure cortex processors can do vectorized division. check(arm32 ? 
"vdiv.f32" : "fdiv", 2 * w, f32_1 / f32_2); diff --git a/test/correctness/simd_op_check_sve2.cpp b/test/correctness/simd_op_check_sve2.cpp new file mode 100644 index 000000000000..1a176dbccecd --- /dev/null +++ b/test/correctness/simd_op_check_sve2.cpp @@ -0,0 +1,1387 @@ +#include "simd_op_check.h" + +#include "Halide.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Halide; +using namespace Halide::ConciseCasts; +using namespace std; + +namespace { + +using CastFuncTy = function; + +class SimdOpCheckArmSve : public SimdOpCheckTest { +public: + SimdOpCheckArmSve(Target t, int w = 384, int h = 32) + : SimdOpCheckTest(t, w, h), debug_mode(Internal::get_env_variable("HL_DEBUG_SIMDOPCHECK")) { + + // Determine and hold can_run_the_code + // TODO: Since features of Arm CPU cannot be obtained automatically from get_host_target(), + // it is necessary to set some feature (e.g. "arm_fp16") explicitly to HL_JIT_TARGET. + // Halide throws error if there is unacceptable mismatch between jit_target and host_target. + + Target host = get_host_target(); + Target jit_target = get_jit_target_from_environment(); + cout << "host is: " << host.to_string() << endl; + cout << "HL_TARGET is: " << target.to_string() << endl; + cout << "HL_JIT_TARGET is: " << jit_target.to_string() << endl; + + auto is_same_triple = [](const Target &t1, const Target &t2) -> bool { + return t1.arch == t2.arch && t1.bits == t2.bits && t1.os == t2.os && t1.vector_bits == t2.vector_bits; + }; + + can_run_the_code = is_same_triple(host, target) && is_same_triple(jit_target, target); + + // A bunch of feature flags also need to match between the + // compiled code and the host in order to run the code. + for (Target::Feature f : {Target::ARMv7s, Target::ARMFp16, Target::NoNEON, Target::SVE2}) { + if (target.has_feature(f) != jit_target.has_feature(f)) { + can_run_the_code = false; + } + } + if (!can_run_the_code) { + cout << "[WARN] To perform verification of realization, " + << R"(the target triple "arm--" and key feature "arm_fp16")" + << " must be the same between HL_TARGET and HL_JIT_TARGET" << endl; + } + } + + bool can_run_code() const override { + // If we can meet the condition about target, run the error checking Halide::Func. 
+ return can_run_the_code; + } + + void add_tests() override { + check_arm_integer(); + check_arm_float(); + check_arm_load_store(); + check_arm_pairwise(); + } + +private: + void check_arm_integer() { + // clang-format off + vector> test_params{ + {8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat}, + {16, in_i16, in_u16, in_f16, in_i32, in_u32, i16, i16_sat, i32, i8, i8_sat, u16, u16_sat, u32, u8, u8_sat}, + {32, in_i32, in_u32, in_f32, in_i64, in_u64, i32, i32_sat, i64, i16, i16_sat, u32, u32_sat, u64, u16, u16_sat}, + {64, in_i64, in_u64, in_f64, in_i64, in_u64, i64, i64_sat, i64, i32, i32_sat, u64, u64_sat, u64, u32, u32_sat}, + }; + // clang-format on + + for (const auto &[bits, in_i, in_u, in_f, in_i_wide, in_u_wide, + cast_i, satcast_i, widen_i, narrow_i, satnarrow_i, + cast_u, satcast_u, widen_u, narrow_u, satnarrow_u] : test_params) { + + Expr i_1 = in_i(x), i_2 = in_i(x + 16), i_3 = in_i(x + 32); + Expr u_1 = in_u(x), u_2 = in_u(x + 16), u_3 = in_u(x + 32); + Expr i_wide_1 = in_i_wide(x), i_wide_2 = in_i_wide(x + 16); + Expr u_wide_1 = in_u_wide(x), u_wide_2 = in_u_wide(x + 16); + Expr f_1 = in_f(x); + + // TODO: reconcile this comment and logic and figure out + // whether we're testing 192 and 256 for NEON and which bit + // widths other than the target one for SVE2. + // + // In general neon ops have the 64-bit version, the 128-bit + // version (ending in q), and the widening version that takes + // 64-bit args and produces a 128-bit result (ending in l). We try + // to peephole match any with vector, so we just try 64-bits, 128 + // bits, 192 bits, and 256 bits for everything. + std::vector simd_bit_widths; + if (has_neon()) { + simd_bit_widths.push_back(64); + simd_bit_widths.push_back(128); + } + if (has_sve() && ((target.vector_bits > 128) || !has_neon())) { + simd_bit_widths.push_back(target.vector_bits); + } + for (auto &total_bits : simd_bit_widths) { + const int vf = total_bits / bits; + + // Due to a workaround for SVE LLVM issues, in the case of a vector of half the length of natural_lanes, + // there is some inconsistency in the generated SVE instructions regarding the number of lanes. + // So the verification of lanes is skipped for this specific case. + const int instr_lanes = (total_bits == 64 && has_sve()) ? 
+ Instruction::ANY_LANES : + Instruction::get_instr_lanes(bits, vf, target); + const int widen_lanes = Instruction::get_instr_lanes(bits * 2, vf, target); + const int narrow_lanes = Instruction::get_instr_lanes(bits, vf * 2, target); + + AddTestFunctor add_all(*this, bits, instr_lanes, vf); + AddTestFunctor add_all_vec(*this, bits, instr_lanes, vf, vf != 1); + AddTestFunctor add_8_16_32(*this, bits, instr_lanes, vf, bits != 64); + AddTestFunctor add_16_32_64(*this, bits, instr_lanes, vf, bits != 8); + AddTestFunctor add_16_32(*this, bits, instr_lanes, vf, bits == 16 || bits == 32); + AddTestFunctor add_32(*this, bits, instr_lanes, vf, bits == 32); + + AddTestFunctor add_8_16_32_widen(*this, bits, widen_lanes, vf, bits != 64 && !has_sve()); + + AddTestFunctor add_16_32_64_narrow(*this, bits, narrow_lanes, vf * 2, bits != 8 && !has_sve()); + AddTestFunctor add_16_32_narrow(*this, bits, narrow_lanes, vf * 2, (bits == 16 || bits == 32) && !has_sve()); + AddTestFunctor add_16_narrow(*this, bits, narrow_lanes, vf * 2, bits == 16 && !has_sve()); + + // VABA I - Absolute Difference and Accumulate + if (!has_sve()) { + // Relying on LLVM to detect accumulation + add_8_16_32(sel_op("vaba.s", "saba"), i_1 + absd(i_2, i_3)); + add_8_16_32(sel_op("vaba.u", "uaba"), u_1 + absd(u_2, u_3)); + } + + // VABAL I - Absolute Difference and Accumulate Long + add_8_16_32_widen(sel_op("vabal.s", "sabal"), i_wide_1 + absd(i_2, i_3)); + add_8_16_32_widen(sel_op("vabal.u", "uabal"), u_wide_1 + absd(u_2, u_3)); + + // VABD I, F - Absolute Difference + add_8_16_32(sel_op("vabd.s", "sabd"), absd(i_2, i_3)); + add_8_16_32(sel_op("vabd.u", "uabd"), absd(u_2, u_3)); + + // Via widening, taking abs, then narrowing + add_8_16_32(sel_op("vabd.s", "sabd"), cast_u(abs(widen_i(i_2) - i_3))); + add_8_16_32(sel_op("vabd.u", "uabd"), cast_u(abs(widen_i(u_2) - u_3))); + + // VABDL I - Absolute Difference Long + add_8_16_32_widen(sel_op("vabdl.s", "sabdl"), widen_i(absd(i_2, i_3))); + add_8_16_32_widen(sel_op("vabdl.u", "uabdl"), widen_u(absd(u_2, u_3))); + + // Via widening then taking an abs + add_8_16_32_widen(sel_op("vabdl.s", "sabdl"), abs(widen_i(i_2) - widen_i(i_3))); + add_8_16_32_widen(sel_op("vabdl.u", "uabdl"), abs(widen_i(u_2) - widen_i(u_3))); + + // VABS I, F F, D Absolute + add_8_16_32(sel_op("vabs.s", "abs"), abs(i_1)); + + // VADD I, F F, D Add + add_all_vec(sel_op("vadd.i", "add"), i_1 + i_2); + add_all_vec(sel_op("vadd.i", "add"), u_1 + u_2); + + // VADDHN I - Add and Narrow Returning High Half + add_16_32_64_narrow(sel_op("vaddhn.i", "addhn"), narrow_i((i_1 + i_2) >> (bits / 2))); + add_16_32_64_narrow(sel_op("vaddhn.i", "addhn"), narrow_u((u_1 + u_2) >> (bits / 2))); + + // VADDL I - Add Long + add_8_16_32_widen(sel_op("vaddl.s", "saddl"), widen_i(i_1) + widen_i(i_2)); + add_8_16_32_widen(sel_op("vaddl.u", "uaddl"), widen_u(u_1) + widen_u(u_2)); + + // VADDW I - Add Wide + add_8_16_32_widen(sel_op("vaddw.s", "saddw"), i_1 + i_wide_1); + add_8_16_32_widen(sel_op("vaddw.u", "uaddw"), u_1 + u_wide_1); + + // VAND X - Bitwise AND + // Not implemented in front-end yet + // VBIC I - Bitwise Clear + // VBIF X - Bitwise Insert if False + // VBIT X - Bitwise Insert if True + // skip these ones + + // VCEQ I, F - Compare Equal + add_8_16_32(sel_op("vceq.i", "cmeq", "cmpeq"), select(i_1 == i_2, cast_i(1), cast_i(2))); + add_8_16_32(sel_op("vceq.i", "cmeq", "cmpeq"), select(u_1 == u_2, cast_u(1), cast_u(2))); +#if 0 + // VCGE I, F - Compare Greater Than or Equal + // Halide flips these to less than instead + 
check("vcge.s8", 16, select(i8_1 >= i8_2, i8(1), i8(2))); + check("vcge.u8", 16, select(u8_1 >= u8_2, u8(1), u8(2))); + check("vcge.s16", 8, select(i16_1 >= i16_2, i16(1), i16(2))); + check("vcge.u16", 8, select(u16_1 >= u16_2, u16(1), u16(2))); + check("vcge.s32", 4, select(i32_1 >= i32_2, i32(1), i32(2))); + check("vcge.u32", 4, select(u32_1 >= u32_2, u32(1), u32(2))); + check("vcge.f32", 4, select(f32_1 >= f32_2, 1.0f, 2.0f)); + check("vcge.s8", 8, select(i8_1 >= i8_2, i8(1), i8(2))); + check("vcge.u8", 8, select(u8_1 >= u8_2, u8(1), u8(2))); + check("vcge.s16", 4, select(i16_1 >= i16_2, i16(1), i16(2))); + check("vcge.u16", 4, select(u16_1 >= u16_2, u16(1), u16(2))); + check("vcge.s32", 2, select(i32_1 >= i32_2, i32(1), i32(2))); + check("vcge.u32", 2, select(u32_1 >= u32_2, u32(1), u32(2))); + check("vcge.f32", 2, select(f32_1 >= f32_2, 1.0f, 2.0f)); +#endif + // VCGT I, F - Compare Greater Than + add_8_16_32(sel_op("vcgt.s", "cmgt", "cmpgt"), select(i_1 > i_2, cast_i(1), cast_i(2))); + add_8_16_32(sel_op("vcgt.u", "cmhi", "cmphi"), select(u_1 > u_2, cast_u(1), cast_u(2))); +#if 0 + // VCLS I - Count Leading Sign Bits + // We don't currently match these, but it wouldn't be hard to do. + check(arm32 ? "vcls.s8" : "cls", 8 * w, max(count_leading_zeros(i8_1), count_leading_zeros(~i8_1))); + check(arm32 ? "vcls.s16" : "cls", 8 * w, max(count_leading_zeros(i16_1), count_leading_zeros(~i16_1))); + check(arm32 ? "vcls.s32" : "cls", 8 * w, max(count_leading_zeros(i32_1), count_leading_zeros(~i32_1))); +#endif + // VCLZ I - Count Leading Zeros + add_8_16_32(sel_op("vclz.i", "clz"), count_leading_zeros(i_1)); + add_8_16_32(sel_op("vclz.i", "clz"), count_leading_zeros(u_1)); + + // VCMP - F, D Compare Setting Flags + // We skip this + + // VCNT I - Count Number of Set Bits + if (!has_sve()) { + // In NEON, there is only cnt for bytes, and then horizontal adds. + add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(i_1)); + add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(u_1)); + } else { + add_8_16_32("cnt", popcount(i_1)); + add_8_16_32("cnt", popcount(u_1)); + } + + // VDUP X - Duplicate + add_8_16_32(sel_op("vdup.", "dup", "mov"), cast_i(y)); + add_8_16_32(sel_op("vdup.", "dup", "mov"), cast_u(y)); + + // VEOR X - Bitwise Exclusive OR + // check("veor", 4, bool1 ^ bool2); + + // VEXT I - Extract Elements and Concatenate + // unaligned loads with known offsets should use vext +#if 0 + // We currently don't do this. + check("vext.8", 16, in_i8(x+1)); + check("vext.16", 8, in_i16(x+1)); + check("vext.32", 4, in_i32(x+1)); +#endif + // VHADD I - Halving Add + add_8_16_32(sel_op("vhadd.s", "shadd"), cast_i((widen_i(i_1) + widen_i(i_2)) / 2)); + add_8_16_32(sel_op("vhadd.u", "uhadd"), cast_u((widen_u(u_1) + widen_u(u_2)) / 2)); + + // Halide doesn't define overflow behavior for i32 so we + // can use vhadd instruction. We can't use it for unsigned u8,i16,u16,u32. 
+ add_32(sel_op("vhadd.s", "shadd"), (i_1 + i_2) / 2); + + // VHSUB I - Halving Subtract + add_8_16_32(sel_op("vhsub.s", "shsub"), cast_i((widen_i(i_1) - widen_i(i_2)) / 2)); + add_8_16_32(sel_op("vhsub.u", "uhsub"), cast_u((widen_u(u_1) - widen_u(u_2)) / 2)); + + add_32(sel_op("vhsub.s", "shsub"), (i_1 - i_2) / 2); + + // VMAX I, F - Maximum + add_8_16_32(sel_op("vmax.s", "smax"), max(i_1, i_2)); + add_8_16_32(sel_op("vmax.u", "umax"), max(u_1, u_2)); + + // VMIN I, F - Minimum + add_8_16_32(sel_op("vmin.s", "smin"), min(i_1, i_2)); + add_8_16_32(sel_op("vmin.u", "umin"), min(u_1, u_2)); + + // VMLA I, F F, D Multiply Accumulate + add_8_16_32("mla signed", sel_op("vmla.i", "mla", "(mad|mla)"), i_1 + i_2 * i_3); + add_8_16_32("mla unsigned", sel_op("vmla.i", "mla", "(mad|mla)"), u_1 + u_2 * u_3); + // VMLS I, F F, D Multiply Subtract + add_8_16_32("mls signed", sel_op("vmls.i", "mls", "(mls|msb)"), i_1 - i_2 * i_3); + add_8_16_32("mls unsigned", sel_op("vmls.i", "mls", "(mls|msb)"), u_1 - u_2 * u_3); + + // VMLAL I - Multiply Accumulate Long + // Try to trick LLVM into generating a zext instead of a sext by making + // LLVM think the operand never has a leading 1 bit. zext breaks LLVM's + // pattern matching of mlal. + add_8_16_32_widen(sel_op("vmlal.s", "smlal"), i_wide_1 + widen_i(i_2 & 0x3) * i_3); + add_8_16_32_widen(sel_op("vmlal.u", "umlal"), u_wide_1 + widen_u(u_2) * u_3); + + // VMLSL I - Multiply Subtract Long + add_8_16_32_widen(sel_op("vmlsl.s", "smlsl"), i_wide_1 - widen_i(i_2 & 0x3) * i_3); + add_8_16_32_widen(sel_op("vmlsl.u", "umlsl"), u_wide_1 - widen_u(u_2) * u_3); + + // VMOV X F, D Move Register or Immediate + // This is for loading immediates, which we won't do in the inner loop anyway + + // VMOVL I - Move Long + // For aarch64, llvm does a widening shift by 0 instead of using the sxtl instruction. + add_8_16_32_widen(sel_op("vmovl.s", "sshll"), widen_i(i_1)); + add_8_16_32_widen(sel_op("vmovl.u", "ushll"), widen_u(u_1)); + add_8_16_32_widen(sel_op("vmovl.u", "ushll"), widen_i(u_1)); + + // VMOVN I - Move and Narrow + if (Halide::Internal::get_llvm_version() >= 140 && total_bits >= 128) { + if (is_arm32()) { + add_16_32_64_narrow("vmovn.i", narrow_i(i_1)); + add_16_32_64_narrow("vmovn.i", narrow_u(u_1)); + } else { + add_16_32_64({{"uzp1", bits / 2, narrow_lanes * 2}}, vf * 2, narrow_i(i_1)); + add_16_32_64({{"uzp1", bits / 2, narrow_lanes * 2}}, vf * 2, narrow_u(u_1)); + } + } else { + add_16_32_64_narrow(sel_op("vmovn.i", "xtn"), narrow_i(i_1)); + add_16_32_64_narrow(sel_op("vmovn.i", "xtn"), narrow_u(u_1)); + } + + // VMRS X F, D Move Advanced SIMD or VFP Register to ARM compute Engine + // VMSR X F, D Move ARM Core Register to Advanced SIMD or VFP + // trust llvm to use this correctly + + // VMUL I, F, P F, D Multiply + add_8_16_32(sel_op("vmul.i", "mul"), i_2 * i_1); + add_8_16_32(sel_op("vmul.i", "mul"), u_2 * u_1); + + // VMULL I, F, P - Multiply Long + add_8_16_32_widen(sel_op("vmull.s", "smull"), widen_i(i_1) * i_2); + add_8_16_32_widen(sel_op("vmull.u", "umull"), widen_u(u_1) * u_2); + + // integer division by a constant should use fixed point unsigned + // multiplication, which is done by using a widening multiply + // followed by a narrowing + add_8_16_32_widen(sel_op("vmull.u", "umull"), i_1 / 37); + add_8_16_32_widen(sel_op("vmull.u", "umull"), u_1 / 37); + + // VMVN X - Bitwise NOT + // check("vmvn", ~bool1); + + // VNEG I, F F, D Negate + add_8_16_32(sel_op("vneg.s", "neg"), -i_1); + +#if 0 + // These are vfp, not neon. 
They only work on scalars + check("vnmla.f32", 4, -(f32_1 + f32_2*f32_3)); + check("vnmla.f64", 2, -(f64_1 + f64_2*f64_3)); + check("vnmls.f32", 4, -(f32_1 - f32_2*f32_3)); + check("vnmls.f64", 2, -(f64_1 - f64_2*f64_3)); + check("vnmul.f32", 4, -(f32_1*f32_2)); + check("vnmul.f64", 2, -(f64_1*f64_2)); + + // Of questionable value. Catching abs calls is annoying, and the + // slow path is only one more op (for the max). + check("vqabs.s8", 16, abs(max(i8_1, -max_i8))); + check("vqabs.s8", 8, abs(max(i8_1, -max_i8))); + check("vqabs.s16", 8, abs(max(i16_1, -max_i16))); + check("vqabs.s16", 4, abs(max(i16_1, -max_i16))); + check("vqabs.s32", 4, abs(max(i32_1, -max_i32))); + check("vqabs.s32", 2, abs(max(i32_1, -max_i32))); +#endif + // VQADD I - Saturating Add + add_8_16_32(sel_op("vqadd.s", "sqadd"), satcast_i(widen_i(i_1) + widen_i(i_2))); + const Expr max_u = UInt(bits).max(); + add_8_16_32(sel_op("vqadd.u", "uqadd"), cast_u(min(widen_u(u_1) + widen_u(u_2), max_u))); + + // Check the case where we add a constant that could be narrowed + add_8_16_32(sel_op("vqadd.u", "uqadd"), cast_u(min(widen_u(u_1) + 17, max_u))); + + // Can't do larger ones because we can't represent the intermediate 128-bit wide ops. + + // VQDMLAL I - Saturating Double Multiply Accumulate Long + // VQDMLSL I - Saturating Double Multiply Subtract Long + // We don't do these, but it would be possible. + + // VQDMULH I - Saturating Doubling Multiply Returning High Half + // VQDMULL I - Saturating Doubling Multiply Long + add_16_32(sel_op("vqdmulh.s", "sqdmulh"), satcast_i((widen_i(i_1) * widen_i(i_2)) >> (bits - 1))); + + // VQMOVN I - Saturating Move and Narrow + // VQMOVUN I - Saturating Move and Unsigned Narrow + add_16_32_64_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_1)); + add_16_32_64_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(i_1)); + const Expr max_u_narrow = UInt(bits / 2).max(); + add_16_32_64_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u_1, max_u_narrow))); + // Double saturating narrow + add_16_32_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_wide_1)); + add_16_32_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u_wide_1, max_u_narrow))); + add_16_32_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_wide_1)); + add_16_32_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(i_wide_1)); + // Triple saturating narrow + Expr i64_1 = in_i64(x), u64_1 = in_u64(x), f32_1 = in_f32(x), f64_1 = in_f64(x); + add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i64_1)); + add_16_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u64_1, max_u_narrow))); + add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(f32_1)); + add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(f64_1)); + add_16_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(f32_1)); + add_16_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(f64_1)); + + // VQNEG I - Saturating Negate + const Expr max_i = Int(bits).max(); + add_8_16_32(sel_op("vqneg.s", "sqneg"), -max(i_1, -max_i)); + + // VQRDMULH I - Saturating Rounding Doubling Multiply Returning High Half + // Note: division in Halide always rounds down (not towards + // zero). Otherwise these patterns would be more complicated. 
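+ // e.g. for bits == 16 the pattern below reads sat((i32(a) * i32(b) + (1 << 14)) / (1 << 15)); because Halide division rounds down, that division is an arithmetic shift right by 15, i.e. exactly sqrdmulh's (2*a*b + (1 << 15)) >> 16 rounding.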
+ add_16_32(sel_op("vqrdmulh.s", "sqrdmulh"), satcast_i((widen_i(i_1) * widen_i(i_2) + (1 << (bits - 2))) / (widen_i(1) << (bits - 1)))); + + // VQRSHRN I - Saturating Rounding Shift Right Narrow + // VQRSHRUN I - Saturating Rounding Shift Right Unsigned Narrow + add_16_32_64_narrow(sel_op("vqrshrn.s", "sqrshrn"), satnarrow_i((widen_i(i_1) + 8) / 16)); + add_16_32_64_narrow(sel_op("vqrshrun.s", "sqrshrun"), satnarrow_u((widen_i(i_1) + 8) / 16)); + add_16_32_narrow(sel_op("vqrshrn.u", "uqrshrn"), narrow_u(min((widen_u(u_1) + 8) / 16, max_u_narrow))); + + // VQSHL I - Saturating Shift Left + add_8_16_32(sel_op("vqshl.s", "sqshl"), satcast_i(widen_i(i_1) * 16)); + add_8_16_32(sel_op("vqshl.u", "uqshl"), cast_u(min(widen_u(u_1) * 16, max_u))); + + // VQSHLU I - Saturating Shift Left Unsigned + if (!has_sve()) { + add_8_16_32(sel_op("vqshlu.s", "sqshlu"), satcast_u(widen_i(i_1) * 16)); + } + + // VQSHRN I - Saturating Shift Right Narrow + // VQSHRUN I - Saturating Shift Right Unsigned Narrow + add_16_32_64_narrow(sel_op("vqshrn.s", "sqshrn"), satnarrow_i(i_1 / 16)); + add_16_32_64_narrow(sel_op("vqshrun.s", "sqshrun"), satnarrow_u(i_1 / 16)); + add_16_32_narrow(sel_op("vqshrn.u", "uqshrn"), narrow_u(min(u_1 / 16, max_u_narrow))); + + // VQSUB I - Saturating Subtract + add_8_16_32(sel_op("vqsub.s", "sqsub"), satcast_i(widen_i(i_1) - widen_i(i_2))); + + // N.B. Saturating subtracts are expressed by widening to a igned* type + add_8_16_32(sel_op("vqsub.u", "uqsub"), satcast_u(widen_i(u_1) - widen_i(u_2))); + + // VRADDHN I - Rounding Add and Narrow Returning High Half + add_16_32_64_narrow(sel_op("vraddhn.i", "raddhn"), narrow_i((widen_i(i_1 + i_2) + (Expr(cast_i(1)) << (bits / 2 - 1))) >> (bits / 2))); + add_16_32_narrow(sel_op("vraddhn.i", "raddhn"), narrow_u((widen_u(u_1 + u_2) + (Expr(cast_u(1)) << (bits / 2 - 1))) >> (bits / 2))); + + // VREV16 X - Reverse in Halfwords + // VREV32 X - Reverse in Words + // VREV64 X - Reverse in Doublewords + + // These reverse within each halfword, word, and doubleword + // respectively. Sometimes llvm generates them, and sometimes + // it generates vtbl instructions. 
+ + // VRHADD I - Rounding Halving Add + add_8_16_32(sel_op("vrhadd.s", "srhadd"), cast_i((widen_i(i_1) + widen_i(i_2) + 1) / 2)); + add_8_16_32(sel_op("vrhadd.u", "urhadd"), cast_u((widen_u(u_1) + widen_u(u_2) + 1) / 2)); + + // VRSHL I - Rounding Shift Left + Expr shift = (i_2 % bits) - (bits / 2); + Expr round_s = (cast_i(1) >> min(shift, 0)) / 2; + Expr round_u = (cast_u(1) >> min(shift, 0)) / 2; + add_8_16_32(sel_op("vrshl.s", "srshl", "srshlr"), cast_i((widen_i(i_1) + round_s) << shift)); + add_8_16_32(sel_op("vrshl.u", "urshl", "urshlr"), cast_u((widen_u(u_1) + round_u) << shift)); + + round_s = (cast_i(1) << max(shift, 0)) / 2; + round_u = (cast_u(1) << max(shift, 0)) / 2; + add_8_16_32(sel_op("vrshl.s", "srshl", "srshlr"), cast_i((widen_i(i_1) + round_s) >> shift)); + add_8_16_32(sel_op("vrshl.u", "urshl", "urshlr"), cast_u((widen_u(u_1) + round_u) >> shift)); + + // VRSHR I - Rounding Shift Right + add_8_16_32(sel_op("vrshr.s", "srshr", "srshl"), cast_i((widen_i(i_1) + 1) >> 1)); + add_8_16_32(sel_op("vrshr.u", "urshr", "urshl"), cast_u((widen_u(u_1) + 1) >> 1)); + + // VRSHRN I - Rounding Shift Right Narrow + if (Halide::Internal::get_llvm_version() >= 140) { + // LLVM14 converts RSHRN/RSHRN2 to RADDHN/RADDHN2 when the shift amount is half the width of the vector element + // See https://reviews.llvm.org/D116166 + add_16_32_narrow(sel_op("vrshrn.i", "raddhn"), narrow_i((widen_i(i_1) + (cast_i(1) << (bits / 2 - 1))) >> (bits / 2))); + add_16_32_narrow(sel_op("vrshrn.i", "raddhn"), narrow_u((widen_u(u_1) + (cast_u(1) << (bits / 2 - 1))) >> (bits / 2))); + } + add_16_32_64_narrow(sel_op("vrshrn.i", "rshrn"), narrow_i((widen_i(i_1) + (1 << (bits / 4))) >> (bits / 4 + 1))); + add_16_32_narrow(sel_op("vrshrn.i", "rshrn"), narrow_u((widen_u(u_1) + (1 << (bits / 4))) >> (bits / 4 + 1))); + + // VRSRA I - Rounding Shift Right and Accumulate + if (!has_sve()) { + // Relying on LLVM to detect accumulation + add_8_16_32(sel_op("vrsra.s", "srsra"), i_2 + cast_i((widen_i(i_1) + 1) >> 1)); + add_8_16_32(sel_op("vrsra.u", "ursra"), i_2 + cast_u((widen_u(u_1) + 1) >> 1)); + } + + // VRSUBHN I - Rounding Subtract and Narrow Returning High Half + add_16_32_64_narrow(sel_op("vrsubhn.i", "rsubhn"), narrow_i((widen_i(i_1 - i_2) + (Expr(cast_i(1)) << (bits / 2 - 1))) >> (bits / 2))); + add_16_32_narrow(sel_op("vrsubhn.i", "rsubhn"), narrow_u((widen_u(u_1 - u_2) + (Expr(cast_u(1)) << (bits / 2 - 1))) >> (bits / 2))); + + // VSHL I - Shift Left + add_all_vec(sel_op("vshl.i", "shl", "lsl"), i_1 * 16); + add_all_vec(sel_op("vshl.i", "shl", "lsl"), u_1 * 16); + + if (!has_sve()) { // No equivalent instruction in SVE. 
+ add_all_vec(sel_op("vshl.s", "sshl"), i_1 << shift); + add_all_vec(sel_op("vshl.s", "sshl"), i_1 >> shift); + add_all_vec(sel_op("vshl.u", "ushl"), u_1 << shift); + add_all_vec(sel_op("vshl.u", "ushl"), u_1 >> shift); + } + + // VSHLL I - Shift Left Long + add_8_16_32_widen(sel_op("vshll.s", "sshll"), widen_i(i_1) * 16); + add_8_16_32_widen(sel_op("vshll.u", "ushll"), widen_u(u_1) * 16); + + // VSHR I - Shift Right + add_all_vec(sel_op("vshr.s", "sshr", "asr"), i_1 / 16); + add_all_vec(sel_op("vshr.u", "ushr", "lsr"), u_1 / 16); + + // VSHRN I - Shift Right Narrow + add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_i(i_1 >> (bits / 2))); + add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_u(u_1 >> (bits / 2))); + + add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_i(i_1 / 16)); + add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_u(u_1 / 16)); + + // VSLI X - Shift Left and Insert + // I guess this could be used for (x*256) | (y & 255)? We don't do bitwise ops on integers, so skip it. + + // VSRA I - Shift Right and Accumulate + if (!has_sve()) { + // Relying on LLVM to detect accumulation + add_all_vec(sel_op("vsra.s", "ssra"), i_2 + i_1 / 16); + add_all_vec(sel_op("vsra.u", "usra"), u_2 + u_1 / 16); + } + + // VSRI X - Shift Right and Insert + // See VSLI + + // VSUB I, F F, D Subtract + add_all_vec(sel_op("vsub.i", "sub"), i_1 - i_2); + add_all_vec(sel_op("vsub.i", "sub"), u_1 - u_2); + + // VSUBHN I - Subtract and Narrow + add_16_32_64_narrow(sel_op("vsubhn.i", "subhn"), narrow_i((i_1 - i_2) >> (bits / 2))); + add_16_32_64_narrow(sel_op("vsubhn.i", "subhn"), narrow_u((u_1 - u_2) >> (bits / 2))); + + // VSUBL I - Subtract Long + add_8_16_32_widen(sel_op("vsubl.s", "ssubl"), widen_i(i_1) - widen_i(i_2)); + add_8_16_32_widen(sel_op("vsubl.u", "usubl"), widen_u(u_1) - widen_u(u_2)); + + add_8_16_32_widen(sel_op("vsubl.s", "ssubl"), widen_i(i_1) - widen_i(in_i(0))); + add_8_16_32_widen(sel_op("vsubl.u", "usubl"), widen_u(u_1) - widen_u(in_u(0))); + + // VSUBW I - Subtract Wide + add_8_16_32_widen(sel_op("vsubw.s", "ssubw"), i_wide_1 - i_1); + add_8_16_32_widen(sel_op("vsubw.u", "usubw"), u_wide_1 - u_1); + } + } + } + + void check_arm_float() { + vector> test_params{ + {16, in_f16, in_u16, in_i16, f16}, + {32, in_f32, in_u32, in_i32, f32}, + {64, in_f64, in_u64, in_i64, f64}, + }; + + for (const auto &[bits, in_f, in_u, in_i, cast_f] : test_params) { + Expr f_1 = in_f(x), f_2 = in_f(x + 16), f_3 = in_f(x + 32); + Expr u_1 = in_u(x); + Expr i_1 = in_i(x); + + // Arithmetic which could throw FP exception could return NaN, which results in output mismatch. 
+ // To avoid that, we need a positive value within certain range + Func in_f_clamped; + in_f_clamped(x) = clamp(in_f(x), cast_f(1e-3f), cast_f(1.0f)); + in_f_clamped.compute_root(); // To prevent LLVM optimization which results in a different instruction + Expr f_1_clamped = in_f_clamped(x); + Expr f_2_clamped = in_f_clamped(x + 16); + + if (bits == 16 && !is_float16_supported()) { + continue; + } + + vector total_bits_params = {256}; // {64, 128, 192, 256}; + if (bits != 64) { + // Add scalar case to verify float16 native operation + total_bits_params.push_back(bits); + } + + for (auto total_bits : total_bits_params) { + const int vf = total_bits / bits; + const bool is_vector = vf > 1; + + const int instr_lanes = Instruction::get_instr_lanes(bits, vf, target); + const int force_vectorized_lanes = Instruction::get_force_vectorized_instr_lanes(bits, vf, target); + + AddTestFunctor add(*this, bits, instr_lanes, vf); + AddTestFunctor add_arm32_f32(*this, bits, vf, is_arm32() && bits == 32); + AddTestFunctor add_arm64(*this, bits, instr_lanes, vf, !is_arm32()); + + add({{sel_op("vabs.f", "fabs"), bits, force_vectorized_lanes}}, vf, abs(f_1)); + add(sel_op("vadd.f", "fadd"), f_1 + f_2); + add(sel_op("vsub.f", "fsub"), f_1 - f_2); + add(sel_op("vmul.f", "fmul"), f_1 * f_2); + add("fdiv", sel_op("vdiv.f", "fdiv", "(fdiv|fdivr)"), f_1 / f_2_clamped); + auto fneg_lanes = has_sve() ? force_vectorized_lanes : instr_lanes; + add({{sel_op("vneg.f", "fneg"), bits, fneg_lanes}}, vf, -f_1); + add({{sel_op("vsqrt.f", "fsqrt"), bits, force_vectorized_lanes}}, vf, sqrt(f_1_clamped)); + + add_arm32_f32(is_vector ? "vceq.f" : "vcmp.f", select(f_1 == f_2, cast_f(1.0f), cast_f(2.0f))); + add_arm32_f32(is_vector ? "vcgt.f" : "vcmp.f", select(f_1 > f_2, cast_f(1.0f), cast_f(2.0f))); + add_arm64(is_vector ? "fcmeq" : "fcmp", select(f_1 == f_2, cast_f(1.0f), cast_f(2.0f))); + add_arm64(is_vector ? "fcmgt" : "fcmp", select(f_1 > f_2, cast_f(1.0f), cast_f(2.0f))); + + add_arm32_f32("vcvt.f32.u", cast_f(u_1)); + add_arm32_f32("vcvt.f32.s", cast_f(i_1)); + add_arm32_f32("vcvt.u32.f", cast(UInt(bits), f_1)); + add_arm32_f32("vcvt.s32.f", cast(Int(bits), f_1)); + // The max of Float(16) is less than that of UInt(16), which generates "nan" in emulator + Expr float_max = Float(bits).max(); + add_arm64("ucvtf", cast_f(min(float_max, u_1))); + add_arm64("scvtf", cast_f(i_1)); + add_arm64({{"fcvtzu", bits, force_vectorized_lanes}}, vf, cast(UInt(bits), f_1)); + add_arm64({{"fcvtzs", bits, force_vectorized_lanes}}, vf, cast(Int(bits), f_1)); + add_arm64({{"frintn", bits, force_vectorized_lanes}}, vf, round(f_1)); + add_arm64({{"frintm", bits, force_vectorized_lanes}}, vf, floor(f_1)); + add_arm64({{"frintp", bits, force_vectorized_lanes}}, vf, ceil(f_1)); + add_arm64({{"frintz", bits, force_vectorized_lanes}}, vf, trunc(f_1)); + + add_arm32_f32({{"vmax.f", bits, force_vectorized_lanes}}, vf, max(f_1, f_2)); + add_arm32_f32({{"vmin.f", bits, force_vectorized_lanes}}, vf, min(f_1, f_2)); + + add_arm64({{"fmax", bits, force_vectorized_lanes}}, vf, max(f_1, f_2)); + add_arm64({{"fmin", bits, force_vectorized_lanes}}, vf, min(f_1, f_2)); + if (bits != 64 && total_bits != 192) { + // Halide relies on LLVM optimization for this pattern, and in some case it doesn't work + add_arm64("fmla", is_vector ? (has_sve() ? "(fmla|fmad)" : "fmla") : "fmadd", f_1 + f_2 * f_3); + add_arm64("fmls", is_vector ? (has_sve() ? 
"(fmls|fmsb)" : "fmls") : "fmsub", f_1 - f_2 * f_3); + } + if (bits != 64) { + add_arm64(vector{"frecpe", "frecps"}, fast_inverse(f_1_clamped)); + add_arm64(vector{"frsqrte", "frsqrts"}, fast_inverse_sqrt(f_1_clamped)); + } + + if (bits == 16) { + // Some of the math ops (exp,log,pow) for fp16 are converted into "xxx_fp32" call + // and then lowered to Internal::halide_xxx() function. + // In case the target has FP16 feature, native type conversion between fp16 and fp32 should be generated + // instead of emulated equivalent code with other types. + if (is_vector && !has_sve()) { + add_arm64("exp", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, exp(f_1_clamped)); + add_arm64("log", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, log(f_1_clamped)); + add_arm64("pow", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, pow(f_1_clamped, f_2_clamped)); + } else { + add_arm64("exp", "fcvt", exp(f_1_clamped)); + add_arm64("log", "fcvt", log(f_1_clamped)); + add_arm64("pow", "fcvt", pow(f_1_clamped, f_2_clamped)); + } + } + + // No corresponding instructions exists for is_nan, is_inf, is_finite. + // The instructions expected to be generated depends on CodeGen_LLVM::visit(const Call *op) + add_arm64("nan", is_vector ? sel_op("", "fcmge", "fcmuo") : "fcmp", is_nan(f_1)); + add_arm64("inf", {{"fabs", bits, force_vectorized_lanes}}, vf, is_inf(f_1)); + add_arm64("finite", {{"fabs", bits, force_vectorized_lanes}}, vf, is_inf(f_1)); + } + + if (bits == 16) { + // Actually, the following ops are not vectorized because SIMD instruction is unavailable. + // The purpose of the test is just to confirm no error. + // In case the target has FP16 feature, native type conversion between fp16 and fp32 should be generated + // instead of emulated equivalent code with other types. + AddTestFunctor add_f16(*this, 16, 1); + + add_f16("sinf", {{"bl", "sinf"}, {"fcvt", 16, 1}}, 1, sin(f_1_clamped)); + add_f16("asinf", {{"bl", "asinf"}, {"fcvt", 16, 1}}, 1, asin(f_1_clamped)); + add_f16("cosf", {{"bl", "cosf"}, {"fcvt", 16, 1}}, 1, cos(f_1_clamped)); + add_f16("acosf", {{"bl", "acosf"}, {"fcvt", 16, 1}}, 1, acos(f_1_clamped)); + add_f16("tanf", {{"bl", "tanf"}, {"fcvt", 16, 1}}, 1, tan(f_1_clamped)); + add_f16("atanf", {{"bl", "atanf"}, {"fcvt", 16, 1}}, 1, atan(f_1_clamped)); + add_f16("atan2f", {{"bl", "atan2f"}, {"fcvt", 16, 1}}, 1, atan2(f_1_clamped, f_2_clamped)); + add_f16("sinhf", {{"bl", "sinhf"}, {"fcvt", 16, 1}}, 1, sinh(f_1_clamped)); + add_f16("asinhf", {{"bl", "asinhf"}, {"fcvt", 16, 1}}, 1, asinh(f_1_clamped)); + add_f16("coshf", {{"bl", "coshf"}, {"fcvt", 16, 1}}, 1, cosh(f_1_clamped)); + add_f16("acoshf", {{"bl", "acoshf"}, {"fcvt", 16, 1}}, 1, acosh(max(f_1, cast_f(1.0f)))); + add_f16("tanhf", {{"bl", "tanhf"}, {"fcvt", 16, 1}}, 1, tanh(f_1_clamped)); + add_f16("atanhf", {{"bl", "atanhf"}, {"fcvt", 16, 1}}, 1, atanh(clamp(f_1, cast_f(-0.5f), cast_f(0.5f)))); + } + } + } + + void check_arm_load_store() { + vector> test_params = { + {Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}}; + + for (const auto &[elt, in_im] : test_params) { + const int bits = elt.bits(); + if ((elt == Float(16) && !is_float16_supported()) || + (is_arm32() && bits == 64)) { + continue; + } + + // LD/ST - Load/Store + for (int width = 64; width <= 64 * 4; width *= 2) { + const int total_lanes = width / bits; + const int instr_lanes = min(total_lanes, 128 / bits); + if (instr_lanes < 2) 
continue; // bail out scalar op + + // On arm32, instruction selection looks inconsistent due to LLVM optimizations + AddTestFunctor add(*this, bits, total_lanes, target.bits == 64); + // NOTE: if the expr is too simple, LLVM might generate "bl memcpy" + Expr load_store_1 = in_im(x) * 3; + + if (has_sve()) { + // At the native width, ld1b/st1b is used regardless of the data type + const bool allow_byte_ls = (width == target.vector_bits); + add({get_sve_ls_instr("ld1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1); + add({get_sve_ls_instr("st1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1); + } else { + // A vector register is not used for a simple load/store + string reg_prefix = (width <= 64) ? "d" : "q"; + add({{"st[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1); + add({{"ld[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1); + } + } + + // LD2/ST2 - Load/Store two-element structures + int base_vec_bits = has_sve() ? target.vector_bits : 128; + for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) { + const int total_lanes = width / bits; + const int vector_lanes = total_lanes / 2; + const int instr_lanes = min(vector_lanes, base_vec_bits / bits); + if (instr_lanes < 2) continue; // bail out scalar op + + AddTestFunctor add_ldn(*this, bits, vector_lanes); + AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes); + + Func tmp1, tmp2; + tmp1(x) = cast(elt, x); + tmp1.compute_root(); + tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16)); + tmp2.compute_root().vectorize(x, total_lanes); + Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1); + Expr store_2 = tmp2(0, 0) + tmp2(0, 127); + + if (has_sve()) { + // TODO(issue needed): Add strided load support. +#if 0 + add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2); +#endif + add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2); + } else { + add_ldn(sel_op("vld2.", "ld2"), load_2); + add_stn(sel_op("vst2.", "st2"), store_2); + } + } + + // Also check when the two expressions interleaved have a common + // subexpression, which results in a vector var being lifted out.
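For reference, the ST2 checks above build an interleaved store out of a select on x % 2 across a vectorized dimension. A reduced sketch of that pattern (names and the target string are placeholders, not part of the patch):

    #include "Halide.h"
    using namespace Halide;

    int main() {
        Var x("x");
        Func a("a"), interleaved("interleaved");
        a(x) = cast<int16_t>(x);
        a.compute_root();

        // Even lanes come from one region of 'a', odd lanes from another;
        // vectorizing x turns this into an interleaving (st2-style) store.
        interleaved(x) = select(x % 2 == 0, a(x / 2), a(x / 2 + 16));
        interleaved.vectorize(x, 16);

        interleaved.compile_to_assembly("interleaved.s", {}, "interleaved_fn",
                                        Target("arm-64-linux"));
        return 0;
    }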
+ for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) { + const int total_lanes = width / bits; + const int vector_lanes = total_lanes / 2; + const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target); + if (instr_lanes < 2) continue; // bail out scalar op + + AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes); + + Func tmp1, tmp2; + tmp1(x) = cast(elt, x); + tmp1.compute_root(); + Expr e = (tmp1(x / 2) * 2 + 7) / 4; + tmp2(x, y) = select(x % 2 == 0, e * 3, e + 17); + tmp2.compute_root().vectorize(x, total_lanes); + Expr store_2 = tmp2(0, 0) + tmp2(0, 127); + + if (has_sve()) { + add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2); + } else { + add_stn(sel_op("vst2.", "st2"), store_2); + } + } + + // LD3/ST3 - Load/Store three-element structures + for (int width = 192; width <= 192 * 4; width *= 2) { + const int total_lanes = width / bits; + const int vector_lanes = total_lanes / 3; + const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target); + if (instr_lanes < 2) continue; // bail out scalar op + + AddTestFunctor add_ldn(*this, bits, vector_lanes); + AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes); + + Func tmp1, tmp2; + tmp1(x) = cast(elt, x); + tmp1.compute_root(); + tmp2(x, y) = select(x % 3 == 0, tmp1(x / 3), + x % 3 == 1, tmp1(x / 3 + 16), + tmp1(x / 3 + 32)); + tmp2.compute_root().vectorize(x, total_lanes); + Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2); + Expr store_3 = tmp2(0, 0) + tmp2(0, 127); + + if (has_sve()) { + // TODO(issue needed): Add strided load support. +#if 0 + add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3); + add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3); +#endif + } else { + add_ldn(sel_op("vld3.", "ld3"), load_3); + add_stn(sel_op("vst3.", "st3"), store_3); + } + } + + // LD4/ST4 - Load/Store four-element structures + for (int width = 256; width <= 256 * 4; width *= 2) { + const int total_lanes = width / bits; + const int vector_lanes = total_lanes / 4; + const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target); + if (instr_lanes < 2) continue; // bail out scalar op + + AddTestFunctor add_ldn(*this, bits, vector_lanes); + AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes); + + Func tmp1, tmp2; + tmp1(x) = cast(elt, x); + tmp1.compute_root(); + tmp2(x, y) = select(x % 4 == 0, tmp1(x / 4), + x % 4 == 1, tmp1(x / 4 + 16), + x % 4 == 2, tmp1(x / 4 + 32), + tmp1(x / 4 + 48)); + tmp2.compute_root().vectorize(x, total_lanes); + Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3); + Expr store_4 = tmp2(0, 0) + tmp2(0, 127); + + if (has_sve()) { + // TODO(issue needed): Add strided load support.
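The TODO above concerns strided (de-interleaving) loads. A reduced sketch of the load shape that selects ld4 on NEON, and that the SVE2 path does not yet pattern-match (names and the target string are placeholders, not part of the patch):

    #include "Halide.h"
    using namespace Halide;

    int main() {
        ImageParam in(UInt(8), 1, "in");
        Var x("x");

        // Four loads from the same base with stride 4 and offsets 0..3,
        // i.e. a de-interleave of groups of four elements.
        Func deinterleaved("deinterleaved");
        deinterleaved(x) = in(4 * x) + in(4 * x + 1) + in(4 * x + 2) + in(4 * x + 3);
        deinterleaved.vectorize(x, 16);

        deinterleaved.compile_to_assembly("deinterleaved.s", {in}, "deinterleaved_fn",
                                          Target("arm-64-linux"));
        return 0;
    }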
+#if 0 + add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4); + add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4); +#endif + } else { + add_ldn(sel_op("vld4.", "ld4"), load_4); + add_stn(sel_op("vst4.", "st4"), store_4); + } + } + + // SVE Gather/Scatter + if (has_sve()) { + for (int width = 64; width <= 64 * 4; width *= 2) { + const int total_lanes = width / bits; + const int instr_lanes = min(total_lanes, 128 / bits); + if (instr_lanes < 2) continue; // bail out scalar op + + AddTestFunctor add(*this, bits, total_lanes); + Expr index = clamp(cast(in_im(x)), 0, W - 1); + Func tmp; + tmp(x, y) = cast(elt, y); + tmp(x, index) = cast(elt, 1); + tmp.compute_root().update().vectorize(x, total_lanes); + Expr gather = in_im(index); + Expr scatter = tmp(0, 0) + tmp(0, 127); + + const int index_bits = std::max(32, bits); + add({get_sve_ls_instr("ld1", bits, index_bits, "uxtw")}, total_lanes, gather); + add({get_sve_ls_instr("st1", bits, index_bits, "uxtw")}, total_lanes, scatter); + } + } + } + } + + void check_arm_pairwise() { + // A summation reduction that starts at something + // non-trivial, to avoid llvm simplifying accumulating + // widening summations into just widening summations. + auto sum_ = [&](Expr e) { + Func f; + f(x) = cast(e.type(), 123); + f(x) += e; + return f(x); + }; + + // Tests for integer type + { + vector> test_params{ + {8, in_i8, in_u8, i16, i32, u16, u32}, + {16, in_i16, in_u16, i32, i64, u32, u64}, + {32, in_i32, in_u32, i64, i64, u64, u64}, + {64, in_i64, in_u64, i64, i64, u64, u64}, + }; + // clang-format on + + for (const auto &[bits, in_i, in_u, widen_i, widenx4_i, widen_u, widenx4_u] : test_params) { + + for (auto &total_bits : {64, 128}) { + const int vf = total_bits / bits; + const int instr_lanes = Instruction::get_force_vectorized_instr_lanes(bits, vf, target); + AddTestFunctor add(*this, bits, instr_lanes, vf, !(is_arm32() && bits == 64)); // 64 bit is unavailable in neon 32 bit + AddTestFunctor add_8_16_32(*this, bits, instr_lanes, vf, bits != 64); + const int widen_lanes = Instruction::get_instr_lanes(bits, vf * 2, target); + AddTestFunctor add_widen(*this, bits, widen_lanes, vf, bits != 64); + + if (!has_sve()) { + // VPADD I, F - Pairwise Add + // VPMAX I, F - Pairwise Maximum + // VPMIN I, F - Pairwise Minimum + for (int f : {2, 4}) { + RDom r(0, f); + + add(sel_op("vpadd.i", "addp"), sum_(in_i(f * x + r))); + add(sel_op("vpadd.i", "addp"), sum_(in_u(f * x + r))); + add_8_16_32(sel_op("vpmax.s", "smaxp"), maximum(in_i(f * x + r))); + add_8_16_32(sel_op("vpmax.u", "umaxp"), maximum(in_u(f * x + r))); + add_8_16_32(sel_op("vpmin.s", "sminp"), minimum(in_i(f * x + r))); + add_8_16_32(sel_op("vpmin.u", "uminp"), minimum(in_u(f * x + r))); + } + } + + // VPADAL I - Pairwise Add and Accumulate Long + // VPADDL I - Pairwise Add Long + { + int f = 2; + RDom r(0, f); + + // If we're reducing by a factor of two, we can + // use the forms with an accumulator + add_widen(sel_op("vpadal.s", "sadalp"), sum_(widen_i(in_i(f * x + r)))); + add_widen(sel_op("vpadal.u", "uadalp"), sum_(widen_i(in_u(f * x + r)))); + add_widen(sel_op("vpadal.u", "uadalp"), sum_(widen_u(in_u(f * x + r)))); + } + { + int f = 4; + RDom r(0, f); + + // If we're reducing by more than that, that's not + // possible. + // In case of SVE, addlp is unavailable, so adalp is used with accumulator=0 instead. 
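The comment above refers to the pairwise widening reduction pattern: summing adjacent pairs into a wider accumulator, which maps to saddlp/sadalp (or, under SVE2, to sadalp with a zero accumulator). A reduced sketch of that pattern (names and the target string are placeholders, not part of the patch):

    #include "Halide.h"
    using namespace Halide;

    int main() {
        ImageParam in(Int(8), 1, "in");
        Var x("x");
        RDom r(0, 2);

        Func pairwise("pairwise");
        pairwise(x) = cast<int16_t>(123);  // non-trivial init, as in the test's sum_()
        pairwise(x) += cast<int16_t>(in(2 * x + r));
        pairwise.update().vectorize(x, 8);

        pairwise.compile_to_assembly("pairwise.s", {in}, "pairwise_fn",
                                     Target("arm-64-linux"));
        return 0;
    }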
+ add_widen(sel_op("vpaddl.s", "saddlp", "sadalp"), sum_(widen_i(in_i(f * x + r)))); + add_widen(sel_op("vpaddl.u", "uaddlp", "uadalp"), sum_(widen_i(in_u(f * x + r)))); + add_widen(sel_op("vpaddl.u", "uaddlp", "uadalp"), sum_(widen_u(in_u(f * x + r)))); + } + + const bool is_arm_dot_prod_available = (!is_arm32() && target.has_feature(Target::ARMDotProd) && bits == 8) || + (has_sve() && (bits == 8 || bits == 16)); + if ((bits == 8 || bits == 16) && !is_arm_dot_prod_available) { // udot/sdot is applied if available + int f = 4; + RDom r(0, f); + // If we're widening the type by a factor of four + // as well as reducing by a factor of four, we + // expect vpaddl followed by vpadal + // Note that when going from u8 to i32 like this, + // the vpaddl is unsigned and the vpadal is a + // signed, because the intermediate type is u16 + const int widenx4_lanes = Instruction::get_instr_lanes(bits * 2, vf, target); + string op_addl, op_adal; + op_addl = sel_op("vpaddl.s", "saddlp"); + op_adal = sel_op("vpadal.s", "sadalp"); + add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_i(in_i(f * x + r)))); + op_addl = sel_op("vpaddl.u", "uaddlp"); + op_adal = sel_op("vpadal.u", "uadalp"); + add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_i(in_u(f * x + r)))); + add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_u(in_u(f * x + r)))); + } + + // UDOT/SDOT + if (is_arm_dot_prod_available) { + const int factor_32bit = vf / 4; + for (int f : {4, 8}) { + // checks vector register for narrow src data type (i.e. 8 or 16 bit) + const int lanes_src = Instruction::get_instr_lanes(bits, f * factor_32bit, target); + AddTestFunctor add_dot(*this, bits, lanes_src, factor_32bit); + RDom r(0, f); + + add_dot("udot", sum(widenx4_u(in_u(f * x + r)) * in_u(f * x + r + 32))); + add_dot("sdot", sum(widenx4_i(in_i(f * x + r)) * in_i(f * x + r + 32))); + if (f == 4) { + // This doesn't generate for higher reduction factors because the + // intermediate is 16-bit instead of 32-bit. It seems like it would + // be slower to fix this (because the intermediate sum would be + // 32-bit instead of 16-bit). 
+ add_dot("udot", sum(widenx4_u(in_u(f * x + r)))); + add_dot("sdot", sum(widenx4_i(in_i(f * x + r)))); + } + } + } + } + } + } + + // Tests for Float type + { + // clang-format off + vector> test_params{ + {16, in_f16}, + {32, in_f32}, + {64, in_f64}, + }; + // clang-format on + if (!has_sve()) { + for (const auto &[bits, in_f] : test_params) { + for (auto &total_bits : {64, 128}) { + const int vf = total_bits / bits; + if (vf < 2) continue; + AddTestFunctor add(*this, bits, vf); + AddTestFunctor add_16_32(*this, bits, vf, bits != 64); + + if (bits == 16 && !is_float16_supported()) { + continue; + } + + for (int f : {2, 4}) { + RDom r(0, f); + + add(sel_op("vadd.f", "faddp"), sum_(in_f(f * x + r))); + add_16_32(sel_op("vmax.f", "fmaxp"), maximum(in_f(f * x + r))); + add_16_32(sel_op("vmin.f", "fminp"), minimum(in_f(f * x + r))); + } + } + } + } + } + } + + struct ArmTask { + vector instrs; + }; + + struct Instruction { + string opcode; + optional operand; + optional bits; + optional pattern_lanes; + static inline const int ANY_LANES = -1; + + // matching pattern for opcode/operand is directly set + Instruction(const string &opcode, const string &operand) + : opcode(opcode), operand(operand), bits(nullopt), pattern_lanes(nullopt) { + } + + // matching pattern for opcode/operand is generated from bits/lanes + Instruction(const string &opcode, int bits, int lanes) + : opcode(opcode), operand(nullopt), bits(bits), pattern_lanes(lanes) { + } + + string generate_pattern(const Target &target) const { + bool is_arm32 = target.bits == 32; + bool has_sve = target.has_feature(Target::SVE2); + + string opcode_pattern; + string operand_pattern; + if (bits && pattern_lanes) { + if (is_arm32) { + opcode_pattern = get_opcode_neon32(); + operand_pattern = get_reg_neon32(); + } else if (!has_sve) { + opcode_pattern = opcode; + operand_pattern = get_reg_neon64(); + } else { + opcode_pattern = opcode; + operand_pattern = get_reg_sve(); + } + } else { + opcode_pattern = opcode; + operand_pattern = operand.value_or(""); + } + // e.g "add v15.h " -> "\s*add\s.*\bv\d\d?\.h\b.*" + return opcode_pattern + R"(\s.*\b)" + operand_pattern + R"(\b.*)"; + } + + // TODO Fix this for SVE2 + static int natural_lanes(int bits) { + return 128 / bits; + } + + static int get_instr_lanes(int bits, int vec_factor, const Target &target) { + return min(natural_lanes(bits), vec_factor); + } + + static int get_force_vectorized_instr_lanes(int bits, int vec_factor, const Target &target) { + // For some cases, where scalar operation is forced to vectorize + if (target.has_feature(Target::SVE2)) { + if (vec_factor == 1) { + return 1; + } else { + return natural_lanes(bits); + } + } else { + int min_lanes = std::max(2, natural_lanes(bits) / 2); // 64 bit wide VL + return max(min_lanes, get_instr_lanes(bits, vec_factor, target)); + } + } + + string get_opcode_neon32() const { + return opcode + to_string(bits.value()); + } + + const char *get_bits_designator() const { + static const map designators{ + // NOTE: vector or float only + {8, "b"}, + {16, "h"}, + {32, "s"}, + {64, "d"}, + }; + auto iter = designators.find(bits.value()); + assert(iter != designators.end()); + return iter->second; + } + + string get_reg_sve() const { + if (pattern_lanes == ANY_LANES) { + return R"((z\d\d?\.[bhsd])|(s\d\d?))"; + } else { + const char *bits_designator = get_bits_designator(); + // TODO(need issue): This should only match the scalar register, and likely a NEON instruction opcode. 
+ // Generating a full SVE vector instruction for a scalar operation is inefficient. However this is + // happening and fixing it involves changing intrinsic selection. Likely to use NEON intrinsics where + // applicable. For now, accept both a scalar operation and a vector one. + std::string scalar_reg_pattern = (pattern_lanes > 1) ? "" : std::string("|(") + bits_designator + R"(\d\d?))"; // e.g. "h15" + + return std::string(R"(((z\d\d?\.)") + bits_designator + ")|(" + + R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator + ")" + scalar_reg_pattern + ")"; + } + } + + string get_reg_neon32() const { + return ""; + } + + string get_reg_neon64() const { + const char *bits_designator = get_bits_designator(); + if (pattern_lanes == 1) { + return std::string(bits_designator) + R"(\d\d?)"; // e.g. "h15" + } else if (pattern_lanes == ANY_LANES) { + return R"(v\d\d?\.[bhsd])"; + } else { + return R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator; // e.g. "v15.4h" + } + } + }; + + Instruction get_sve_ls_instr(const string &base_opcode, int opcode_bits, int operand_bits, const string &additional = "", const string &optional_type = "") { + static const map opcode_suffix_map = {{8, "b"}, {16, "h"}, {32, "w"}, {64, "d"}}; + static const map operand_suffix_map = {{8, "b"}, {16, "h"}, {32, "s"}, {64, "d"}}; + string opcode_size_specifier; + string operand_size_specifier; + if (!optional_type.empty()) { + opcode_size_specifier = "["; + operand_size_specifier = "["; + } + opcode_size_specifier += opcode_suffix_map.at(opcode_bits); + operand_size_specifier += operand_suffix_map.at(operand_bits); + if (!optional_type.empty()) { + opcode_size_specifier += optional_type; + opcode_size_specifier += "]"; + operand_size_specifier += optional_type; + operand_size_specifier += "]"; + } + const string opcode = base_opcode + opcode_size_specifier; + string operand = R"(z\d\d?\.)" + operand_size_specifier; + if (!additional.empty()) { + operand += ", " + additional; + } + return Instruction(opcode, operand); + } + + Instruction get_sve_ls_instr(const string &base_opcode, int bits) { + return get_sve_ls_instr(base_opcode, bits, bits, ""); + } + + // Helper functor to add test case + class AddTestFunctor { + public: + AddTestFunctor(SimdOpCheckArmSve &p, + int default_bits, + int default_instr_lanes, + int default_vec_factor, + bool is_enabled = true /* false to skip testing */) + : parent(p), default_bits(default_bits), default_instr_lanes(default_instr_lanes), + default_vec_factor(default_vec_factor), is_enabled(is_enabled){}; + + AddTestFunctor(SimdOpCheckArmSve &p, + int default_bits, + // default_instr_lanes is inferred from bits and vec_factor + int default_vec_factor, + bool is_enabled = true /* false to skip testing */) + : parent(p), default_bits(default_bits), + default_instr_lanes(Instruction::get_instr_lanes(default_bits, default_vec_factor, p.target)), + default_vec_factor(default_vec_factor), is_enabled(is_enabled){}; + + // Constructs single Instruction with default parameters + void operator()(const string &opcode, Expr e) { + // Use opcode for name + (*this)(opcode, opcode, e); + } + + // Constructs single Instruction with default parameters except for custom name + void operator()(const string &op_name, const string &opcode, Expr e) { + create_and_register(op_name, {Instruction{opcode, default_bits, default_instr_lanes}}, default_vec_factor, e); + } + + // Constructs multiple Instruction with default parameters + void operator()(const vector &opcodes, Expr e) { + 
assert(!opcodes.empty()); + (*this)(opcodes[0], opcodes, e); + } + + // Constructs multiple Instruction with default parameters except for custom name + void operator()(const string &op_name, const vector &opcodes, Expr e) { + vector instrs; + for (const auto &opcode : opcodes) { + instrs.emplace_back(opcode, default_bits, default_instr_lanes); + } + create_and_register(op_name, instrs, default_vec_factor, e); + } + + // Set single or multiple Instructions of custom parameters + void operator()(const vector &instructions, int vec_factor, Expr e) { + // Use the 1st opcode for name + assert(!instructions.empty()); + string op_name = instructions[0].opcode; + (*this)(op_name, instructions, vec_factor, e); + } + + // Set single or multiple Instructions of custom parameters, with custom name + void operator()(const string &op_name, const vector &instructions, int vec_factor, Expr e) { + create_and_register(op_name, instructions, vec_factor, e); + } + + private: + void create_and_register(const string &op_name, const vector &instructions, int vec_factor, Expr e) { + if (!is_enabled) return; + + // Generate regular expression for the instruction we check + vector instr_patterns; + transform(instructions.begin(), instructions.end(), back_inserter(instr_patterns), + [t = parent.target](const Instruction &instr) { return instr.generate_pattern(t); }); + + std::stringstream type_name_stream; + type_name_stream << e.type(); + std::string decorated_op_name = op_name + "_" + type_name_stream.str() + "_x" + std::to_string(vec_factor); + auto unique_name = "op_" + decorated_op_name + "_" + std::to_string(parent.tasks.size()); + + // Bail out after generating the unique_name, so that names are + // unique across different processes and don't depend on filter + // settings. + if (!parent.wildcard_match(parent.filter, decorated_op_name)) return; + + // Create a deep copy of the expr and all Funcs referenced by it, so + // that no IR is shared between tests. This is required by the base + // class, and is why we can parallelize. + { + using namespace Halide::Internal; + class FindOutputs : public IRVisitor { + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->func.defined()) { + outputs.insert(op->func); + } + IRVisitor::visit(op); + } + + public: + std::set outputs; + } finder; + e.accept(&finder); + std::vector outputs(finder.outputs.begin(), finder.outputs.end()); + auto env = deep_copy(outputs, build_environment(outputs)).second; + class DeepCopy : public IRMutator { + std::map copied; + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->func.defined()) { + auto it = env.find(op->name); + if (it != env.end()) { + return Func(it->second)(mutate(op->args)); + } + } + return IRMutator::visit(op); + } + const std::map &env; + + public: + DeepCopy(const std::map &env) + : env(env) { + } + } copier(env); + e = copier.mutate(e); + } + + // Create Task and register + parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e}); + parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)}); + } + + SimdOpCheckArmSve &parent; + int default_bits; + int default_instr_lanes; + int default_vec_factor; + bool is_enabled; + }; + + void compile_and_check(Func error, const string &op, const string &name, int vector_width, const std::vector &arg_types, ostringstream &error_msg) override { + // This is necessary as LLVM validation errors, crashes, etc. don't tell which op crashed. 
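compile_and_check below greps the generated .s file using the regexes produced by Instruction::generate_pattern. A minimal standalone sketch of that matching step (the opcode, lane count, and assembly line are invented for illustration, not part of the patch):

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // Shape of a generated pattern for a 16-bit, 4-lane NEON op:
        // opcode + R"(\s.*\b)" + operand + R"(\b.*)"
        const std::string pattern = R"(add\s.*\bv\d\d?\.4h\b.*)";
        const std::string asm_line = "        add     v15.4h, v3.4h, v7.4h";

        std::smatch m;
        if (std::regex_search(asm_line, m, std::regex(pattern))) {
            std::cout << "matched: " << m[0] << "\n";
        }
        return 0;
    }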
+ cout << "Starting op " << op << "\n"; + string fn_name = "test_" + name; + string file_name = output_directory + fn_name; + + auto ext = Internal::get_output_info(target); + std::map outputs = { + {OutputFileType::llvm_assembly, file_name + ext.at(OutputFileType::llvm_assembly).extension}, + {OutputFileType::c_header, file_name + ext.at(OutputFileType::c_header).extension}, + {OutputFileType::object, file_name + ext.at(OutputFileType::object).extension}, + {OutputFileType::assembly, file_name + ".s"}, + }; + + error.compile_to(outputs, arg_types, fn_name, target); + + std::ifstream asm_file; + asm_file.open(file_name + ".s"); + + auto arm_task = arm_tasks.find(name); + assert(arm_task != arm_tasks.end()); + + std::ostringstream msg; + msg << op << " did not generate for target=" << target.to_string() + << " vector_width=" << vector_width << ". Instead we got:\n"; + + string line; + vector matched_lines; + vector &patterns = arm_task->second.instrs; + while (getline(asm_file, line) && !patterns.empty()) { + msg << line << "\n"; + auto pattern = patterns.begin(); + while (pattern != patterns.end()) { + smatch match; + if (regex_search(line, match, regex(*pattern))) { + pattern = patterns.erase(pattern); + matched_lines.emplace_back(match[0]); + } else { + ++pattern; + } + } + } + + if (!patterns.empty()) { + error_msg << "Failed: " << msg.str() << "\n"; + error_msg << "The following instruction patterns were not found:\n"; + for (auto &p : patterns) { + error_msg << p << "\n"; + } + } else if (debug_mode == "1") { + for (auto &l : matched_lines) { + error_msg << " " << setw(20) << name << ", vf=" << setw(2) << vector_width << ", "; + error_msg << l << endl; + } + } + } + + inline const string &sel_op(const string &neon32, const string &neon64) { + return is_arm32() ? neon32 : neon64; + } + + inline const string &sel_op(const string &neon32, const string &neon64, const string &sve) { + return is_arm32() ? neon32 : + target.has_feature(Target::SVE) || target.has_feature(Target::SVE2) ? sve : + neon64; + } + + inline bool is_arm32() const { + return target.bits == 32; + }; + inline bool has_neon() const { + return !target.has_feature(Target::NoNEON); + }; + inline bool has_sve() const { + return target.has_feature(Target::SVE2); + }; + + bool is_float16_supported() const { + return (target.bits == 64) && target.has_feature(Target::ARMFp16); + } + + bool can_run_the_code; + string debug_mode; + std::unordered_map arm_tasks; + const Var x{"x"}, y{"y"}; +}; +} // namespace + +int main(int argc, char **argv) { + if (Halide::Internal::get_llvm_version() < 190) { + std::cout << "[SKIP] simd_op_check_sve2 requires LLVM 19 or later.\n"; + return 0; + } + + return SimdOpCheckTest::main( + argc, argv, + { + Target("arm-64-linux-sve2-no_neon-vector_bits_128"), + Target("arm-64-linux-sve2-no_neon-vector_bits_256"), + }); +}