diff --git a/cinn/backends/llvm/codegen_llvm.cc b/cinn/backends/llvm/codegen_llvm.cc
index 7e5f6d5546d99..6a1233e8327d8 100644
--- a/cinn/backends/llvm/codegen_llvm.cc
+++ b/cinn/backends/llvm/codegen_llvm.cc
@@ -1374,7 +1374,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::UnaryIntrin *op) {
     }
   }
   CHECK(!op->args.empty());
-  llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_);
+  llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_, true);
   llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type);
   CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {});
   return b_->CreateCall(fn, arg_value);
diff --git a/cinn/backends/llvm/llvm_intrin_rule.h b/cinn/backends/llvm/llvm_intrin_rule.h
index 86a6fed20cb0e..b16c840f29421 100644
--- a/cinn/backends/llvm/llvm_intrin_rule.h
+++ b/cinn/backends/llvm/llvm_intrin_rule.h
@@ -24,8 +24,8 @@ inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
   CHECK(node);
   CHECK_GE(node->read_args.size(), arg_nums);
   if (add_float_suffix) {
-    CHECK_EQ(node->type(), Float(32));
-    *rv = ir::intrinsics::UnaryIntrin::Make(node->name + "f", node->read_args, id, arg_nums, Float(32));
+    CHECK(node->type().is_float());
+    *rv = ir::intrinsics::UnaryIntrin::Make(node->name + "f", node->read_args, id, arg_nums, node->type());
   } else {
     *rv = ir::intrinsics::UnaryIntrin::Make(node->name, node->read_args, id, arg_nums, node->type());
   }
diff --git a/cinn/backends/llvm/llvm_util.cc b/cinn/backends/llvm/llvm_util.cc
index c77281a8d15e5..9423a72f72426 100644
--- a/cinn/backends/llvm/llvm_util.cc
+++ b/cinn/backends/llvm/llvm_util.cc
@@ -9,7 +9,7 @@
 namespace cinn {
 namespace backends {
 
-llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m) {
+llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m, bool is_vec) {
   llvm::Type *ir_type = nullptr;
   if (type.is_cpp_const()) {
     // TODO(fc500110) support it latter.
@@ -51,9 +51,13 @@ llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m) {
   }
   CHECK(ir_type) << "LLVM can't convert type: " << type;
 
-  // C array.
+  // C array / vector.
   if (type.lanes() > 1) {
-    ir_type = llvm::ArrayType::get(ir_type, type.lanes());
+    if (is_vec) {
+      ir_type = llvm::FixedVectorType::get(ir_type, type.lanes());
+    } else {
+      ir_type = llvm::ArrayType::get(ir_type, type.lanes());
+    }
   }
 
   if (type.is_cpp_handle()) {
diff --git a/cinn/backends/llvm/llvm_util.h b/cinn/backends/llvm/llvm_util.h
index 1d092b622ed30..2145585cc0919 100644
--- a/cinn/backends/llvm/llvm_util.h
+++ b/cinn/backends/llvm/llvm_util.h
@@ -32,7 +32,7 @@ std::string DumpToString(const T &entity) {
 
 inline llvm::StringRef AsStringRef(std::string_view str) { return llvm::StringRef(str.data(), str.size()); }
 
-llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m);
+llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m, bool is_vec = false);
 
 template <typename T>
 llvm::Type *llvm_type_of(llvm::Module *m);
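Aside (not part of the diff): a minimal standalone sketch of what the new `is_vec` flag selects between. Multi-lane CINN types used to lower uniformly to an LLVM aggregate `[N x T]`, but the intrinsic path in the codegen_llvm.cc hunk above now passes `true`, since declarations such as `llvm.sqrt.v4f32` require a first-class SIMD type `<N x T>`. Assumes LLVM 11+, where `FixedVectorType` was split out of `VectorType`; `LowerLanes` is a made-up name for illustration.

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"

// Toy counterpart of CinnTypeToLLVMType's lanes handling: an aggregate
// [lanes x float] by default, a SIMD <lanes x float> when is_vec is set.
llvm::Type *LowerLanes(llvm::LLVMContext &ctx, int lanes, bool is_vec) {
  llvm::Type *elem = llvm::Type::getFloatTy(ctx);
  if (lanes <= 1) return elem;
  if (is_vec) return llvm::FixedVectorType::get(elem, lanes);
  return llvm::ArrayType::get(elem, lanes);
}
```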
diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc
index 7190ec374e54f..f0754ab3166a8 100644
--- a/cinn/hlir/op/elementwise.cc
+++ b/cinn/hlir/op/elementwise.cc
@@ -62,6 +62,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(const framework::NodeAttr &at
       stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
       stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
     }
+  } else if (target.arch == Target::Arch::X86) {
+    Expr Out              = arg_pack[0];
+    poly::StageMap stages = arg_pack[1];
+    CHECK(Out.as_tensor());
+    pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.back(), target);
   }
   *ret = arg_pack;
 });
diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc
index e853de437333d..59a36f9bc18af 100755
--- a/cinn/ir/ir.cc
+++ b/cinn/ir/ir.cc
@@ -464,6 +464,7 @@ Expr Load::Make(Expr tensor, const std::vector<Expr> &indices) {
   auto node     = make_shared<Load>();
   node->tensor  = tensor;
   node->indices = indices;
+  node->set_type(node->type());
   return Expr(node);
 }
 
 Type Load::type() const {
@@ -662,7 +663,7 @@ void Select::Verify() const {
   CHECK(condition.defined());
   CHECK(true_value.defined());
   CHECK(false_value.defined());
-  CHECK_EQ(condition.type(), type_of<bool>()) << "Select Node's condition should be a boolean";
+  CHECK(condition.type().is_bool()) << "Select Node's condition should be a boolean";
   CHECK_EQ(true_value.type(), false_value.type())
       << "Select Node's true_value and false_value should have the same type";
 }
diff --git a/cinn/lang/builtin.cc b/cinn/lang/builtin.cc
index 946394538f9ff..663cc17aba590 100644
--- a/cinn/lang/builtin.cc
+++ b/cinn/lang/builtin.cc
@@ -32,10 +32,16 @@ Expr logic_or(const std::vector<Expr>& conds) {
 
 //! extern call op
 #define EXTERN_CALL_IMP(name__, target__) \
-  Expr name__(Expr e) { return CallExtern(#target__, {e}); }
+  Expr name__(Expr e) { return ir::Call::Make(e->type(), #target__, {e}, {}, ir::CallType::Extern); }
+
+#define EXTERN_CALL_IMP_NO_VEC(name__, target__)                                                               \
+  Expr name__(Expr e) {                                                                                        \
+    return ir::Call::Make(                                                                                     \
+        e->type(), #target__, {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}}); \
+  }
 
 EXTERN_CALL_IMP(Exp, exp);
-EXTERN_CALL_IMP(Erf, erf);
+EXTERN_CALL_IMP_NO_VEC(Erf, erf);
 EXTERN_CALL_IMP(Sqrt, sqrt);
 EXTERN_CALL_IMP(Log, log);
 EXTERN_CALL_IMP(Log2, log2);
@@ -45,17 +51,17 @@ EXTERN_CALL_IMP(Ceil, ceil);
 EXTERN_CALL_IMP(Round, round);
 EXTERN_CALL_IMP(Trunc, trunc);
 EXTERN_CALL_IMP(Cos, cos);
+EXTERN_CALL_IMP(Sin, sin);
 EXTERN_CALL_IMP(Cosh, cosh);
 EXTERN_CALL_IMP(Tan, tan);
-EXTERN_CALL_IMP(Sin, sin);
-EXTERN_CALL_IMP(Sinh, sinh);
-EXTERN_CALL_IMP(Acos, acos);
-EXTERN_CALL_IMP(Acosh, acosh);
-EXTERN_CALL_IMP(Asin, asin);
-EXTERN_CALL_IMP(Asinh, asinh);
-EXTERN_CALL_IMP(Atan, atan);
-EXTERN_CALL_IMP(Atanh, atanh);
 EXTERN_CALL_IMP(Tanh, tanh);
+EXTERN_CALL_IMP(Sinh, sinh);
+EXTERN_CALL_IMP_NO_VEC(Acos, acos);
+EXTERN_CALL_IMP_NO_VEC(Acosh, acosh);
+EXTERN_CALL_IMP_NO_VEC(Asin, asin);
+EXTERN_CALL_IMP_NO_VEC(Asinh, asinh);
+EXTERN_CALL_IMP_NO_VEC(Atan, atan);
+EXTERN_CALL_IMP_NO_VEC(Atanh, atanh);
 
 Expr min_value(const Type& type) {
   CHECK_EQ(type.lanes(), 1);
@@ -114,7 +120,6 @@ Expr Abs(Expr e) {
 
 Expr IsNan(Expr e) {
   Type type = e->type();
-  // Type bool_type = Bool(type.lanes());
   if (type.is_int() || type.is_uint()) {
     return common::make_bool(false, type.lanes());
   } else if (type.is_float()) {
@@ -126,7 +131,7 @@ Expr IsNan(Expr e) {
     if (type.bits() == 16) {
      arg = ir::Cast::Make(Float(32), std::move(e));
     }
-    return CallExtern("isnan", {arg});
+    return CallExtern("isnan", {arg}, {{"vectorizable", false}});
   } else {
     LOG(FATAL) << type << "is not supported for isnan op.";
     return e;
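Aside (not part of the diff): hand-expanding `EXTERN_CALL_IMP_NO_VEC(Erf, erf)` from the builtin.cc hunk above makes the macro's effect concrete; the `{{"vectorizable", false}}` attribute it attaches is presumably what the vectorize pass checks so these extern calls are not widened. The expansion assumes the surrounding cinn::lang context (`Expr`, `ir::Call`, `ir::CallType`, `ir::FunctionRef`) from that file.

```cpp
// What the preprocessor emits for EXTERN_CALL_IMP_NO_VEC(Erf, erf):
Expr Erf(Expr e) {
  return ir::Call::Make(e->type(),          // return type follows the argument
                        "erf",              // #target__ stringized: extern symbol
                        {e},                // read_args
                        {},                 // write_args
                        ir::CallType::Extern,
                        ir::FunctionRef(),  // no associated function body
                        0,                  // value_index
                        {{"vectorizable", false}});
}
```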
diff --git a/cinn/optim/optimize.cc b/cinn/optim/optimize.cc
index 19546eb312d3d..258ce45adb0df 100644
--- a/cinn/optim/optimize.cc
+++ b/cinn/optim/optimize.cc
@@ -36,7 +36,6 @@ Expr Optimize(Expr e, Target target, bool runtime_debug_info) {
   CastSimplify(&copied);
   Simplify(&copied);
   VectorizeLoops(&copied, Target());
-  EliminateBroadcastInForloop(&copied);
   UnrollLoop(&copied);
 #ifdef CINN_WITH_CUDA
   RemoveGpuForloopsAxis(&copied);
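Aside (not part of the diff): in the vectorize_loops.cc hunks below, the Load and Call visitors both follow the same rule — scan the operands for the widest lane count, then `Widen` (broadcast) every operand up to it so the rebuilt node is uniformly typed. A self-contained toy of that rule, with a bare lane count standing in for an `ir::Expr` index:

```cpp
#include <algorithm>
#include <vector>

// Stand-in for ir::Expr: only the lane count of each index matters here.
struct Index { int lanes; };

// Stand-in for Vectorizer::Widen: broadcast a narrower index up to `lanes`.
Index Widen(Index idx, int lanes) { return Index{std::max(idx.lanes, lanes)}; }

// Mirrors the new Load visitor: find the widest lane count, widen everything.
std::vector<Index> WidenIndices(const std::vector<Index> &indices) {
  int lanes = 0;
  for (const auto &idx : indices) lanes = std::max(idx.lanes, lanes);
  std::vector<Index> widened;
  for (const auto &idx : indices) widened.push_back(Widen(idx, lanes));
  return widened;  // every entry now has the same, widest lane count
}
```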
diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc
index 4298657ba63c4..91f4cfa344f79 100644
--- a/cinn/optim/vectorize_loops.cc
+++ b/cinn/optim/vectorize_loops.cc
@@ -137,8 +137,15 @@ class Vectorizer : public IRMutator<Expr *> {
       }
     }
     if (!need_visit) return;
-
-    *expr = Load::Make(node->tensor, node->indices);
+    int lanes = 0;
+    for (auto &idx : node->indices) {
+      lanes = std::max(idx.type().lanes(), lanes);
+    }
+    std::vector<Expr> new_indices;
+    for (auto &idx : node->indices) {
+      new_indices.push_back(Widen(idx, lanes));
+    }
+    *expr = Load::Make(node->tensor, new_indices);
   }
 
   void Visit(const Store *op, Expr *expr) override {
@@ -173,7 +180,45 @@ class Vectorizer : public IRMutator<Expr *> {
     *expr = Store::Make(node->tensor, node->value, new_indices);
   }
 
-  void Visit(const Call *op, Expr *expr) override { LOG(ERROR) << "Ignore widen Call node"; }
+  void Visit(const Call *op, Expr *expr) override {
+    std::vector<Expr> read_args  = op->read_args;
+    std::vector<Expr> write_args = op->write_args;
+    auto *node = expr->As<Call>();
+    ir::IRMutator<>::Visit(op, expr);
+    bool is_changed = false;
+    int lanes       = 0;
+    for (int i = 0; i < node->read_args.size(); i++) {
+      lanes = std::max(node->read_args[i].type().lanes(), lanes);
+      if (!node->read_args[i].same_as(read_args[i])) {
+        is_changed = true;
+      }
+    }
+    for (int i = 0; i < node->write_args.size(); i++) {
+      lanes = std::max(node->write_args[i].type().lanes(), lanes);
+      if (!node->write_args[i].same_as(write_args[i])) {
+        is_changed = true;
+      }
+    }
+    if (!is_changed) return;
+
+    for (int i = 0; i < read_args.size(); i++) {
+      node->read_args[i] = Widen(node->read_args[i], lanes);
+    }
+    for (int i = 0; i < write_args.size(); i++) {
+      node->write_args[i] = Widen(node->write_args[i], lanes);
+    }
+
+    CHECK(!read_args.empty());
+    Type type = op->type().with_lanes(lanes);
+    *expr     = Call::Make(type,
+                           node->name,
+                           node->read_args,
+                           node->write_args,
+                           node->call_type,
+                           node->func,
+                           node->value_index,
+                           node->attrs);
+  }
 
   void Visit(const Let *op, Expr *expr) override {
     auto *node = expr->As<Let>();
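Aside (not part of the diff): the Call visitor above adds one refinement over the Load case — it only rebuilds the node when some argument was actually mutated, and it retypes the call's result with `with_lanes(lanes)`. A runnable toy of that change-detection step:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for an ir::Expr argument: its lane count plus a "was mutated" flag.
struct Arg { int lanes; bool mutated; };

int main() {
  std::vector<Arg> read_args = {{4, true}, {1, false}};  // one widened index
  int lanes = 0;
  bool is_changed = false;
  for (const Arg &a : read_args) {
    lanes = std::max(a.lanes, lanes);
    is_changed = is_changed || a.mutated;
  }
  if (!is_changed) return 0;  // untouched calls are left exactly as they were
  // Otherwise each argument is broadcast to `lanes` and the node is rebuilt
  // with its return type retyped to op->type().with_lanes(lanes).
  std::printf("rebuild call at %d lanes\n", lanes);
  return 0;
}
```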