diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc index a668d73251af2..e1106cf726aaf 100644 --- a/cinn/backends/codegen_c.cc +++ b/cinn/backends/codegen_c.cc @@ -671,6 +671,18 @@ void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) { os() << ")"; } +void CodeGenC::Visit(const ir::intrinsics::UnaryIntrin *op) { + os() << runtime::intrisic::unary_intrin_repr << "_"; + os() << op->name << "("; + if (!op->args.empty()) { + for (int i = 0; i < op->args.size() - 1; i++) { + Print(op->args[i]); + os() << ", "; + } + Print(op->args.back()); + } +} + std::string ReadWholeFile(const std::string &path) { CHECK(!path.empty()); std::ifstream file(path); diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index ed4e681a20ea1..7ccbcb4471a70 100644 --- a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -488,7 +488,7 @@ TEST(CodeGenC, call_extern) { Placeholder x("x", {M}); ir::Tensor y = Compute( - {M}, [=](Var i) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {x(i)}); }, "y"); + {M}, [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, "y"); auto stages = CreateStages({y}); diff --git a/cinn/backends/extern_func_protos.cc b/cinn/backends/extern_func_protos.cc index 03839f37d9045..8802dab41aca2 100644 --- a/cinn/backends/extern_func_protos.cc +++ b/cinn/backends/extern_func_protos.cc @@ -1,24 +1,36 @@ #include "cinn/backends/extern_func_protos.h" +#include +#include + namespace cinn { namespace backends { ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() { - static const std::vector extern_funcs_fp32 = { - "exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", - "ceil", "round", "trunc", "cos", "cosh", "tan", "sin", "sinh", - "acos", "acosh", "asin", "asinh", "atan", "atanh", "isnan", "tanh", - "isfinite", "isinf", "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}; - static const std::vector extern_funcs_int64 = { + static const std::vector extern_funcs_fp32_unary = { + "exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", "cos", + "cosh", "tan", "tanh", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", "fabs"}; + static const std::vector extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"}; + static const std::vector extern_funcs_int_binary = { "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}; - for (int i = 0; i < extern_funcs_fp32.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_fp32[i], {Float(32)}, Float(32)); + static const std::vector extern_funcs_int_int_unary = {"bitwise_not"}; + for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) { + auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32)); + Register(proto->name, proto); + } + for (int i = 0; i < extern_funcs_float_bool_unary.size(); ++i) { + auto* proto = new FunctionProto(extern_funcs_float_bool_unary[i], {Float(32)}, Bool()); Register(proto->name, proto); } - for (int i = 0; i < extern_funcs_int64.size(); ++i) { - auto* proto = new FunctionProto(extern_funcs_int64[i], {Int(64)}, Int(64)); + for (int i = 0; i < extern_funcs_int_binary.size(); ++i) { + auto* proto = new FunctionProto(extern_funcs_int_binary[i], {Int(32), Int(32)}, Int(32)); Register(proto->name, proto); } + for (int i = 0; i < extern_funcs_int_int_unary.size(); ++i) { + auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32)); + Register(proto->name, proto); + } + auto* n = detail::CreateTanhVProto(); Register(n->name, n); } diff --git a/cinn/backends/llvm/codegen_llvm.cc b/cinn/backends/llvm/codegen_llvm.cc index 85f7991e6d9ea..dd88c62ec10e1 100644 --- a/cinn/backends/llvm/codegen_llvm.cc +++ b/cinn/backends/llvm/codegen_llvm.cc @@ -2,10 +2,13 @@ #include #include +#include #include #include +#include #include #include +#include #include #include @@ -78,8 +81,11 @@ int NextPowerOfTwo(int x) { } // namespace -CodeGenLLVM::CodeGenLLVM(llvm::Module *m, llvm::IRBuilder<> *b, const std::shared_ptr &symbol_table) - : m_(m), b_(b), symbol_table_(symbol_table) { +CodeGenLLVM::CodeGenLLVM(llvm::Module *m, + llvm::IRBuilder<> *b, + const std::shared_ptr &symbol_table, + const Target &target) + : m_(m), b_(b), symbol_table_(symbol_table), target_(target) { if (!symbol_table.get()) { symbol_table_ = std::make_shared(); } @@ -88,8 +94,7 @@ CodeGenLLVM::CodeGenLLVM(llvm::Module *m, llvm::IRBuilder<> *b, const std::share md_builder_ = std::make_unique(b_->getContext()); md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); - - InitTarget(common::DefaultHostTarget()); + InitTarget(target_); } CodeGenLLVM::~CodeGenLLVM() {} @@ -1141,6 +1146,11 @@ llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, int begin, int lanes) } void CodeGenLLVM::InitTarget(const Target &target) { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); switch (target.arch) { case Target::Arch::X86: if (target.bits == Target::Bit::k32) { @@ -1283,6 +1293,104 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::ArgsConstruct *op) { return Call(callee, std::move(args)); } +llvm::Function *CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type *ret_type, + llvm::ArrayRef arg_types) { + llvm::Module *module = m_; + + if (!llvm::Intrinsic::isOverloaded(id)) { + return llvm::Intrinsic::getDeclaration(module, id, {}); + } + + llvm::SmallVector infos; + llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos); + llvm::SmallVector overload_types; + + auto try_match = [&](llvm::FunctionType *f_ty, bool var_arg) { + overload_types.clear(); + llvm::ArrayRef ref(infos); + auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + if (llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref)) { + return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg; + } + } + return match; + }; + + auto *fn_ty = llvm::FunctionType::get(ret_type, arg_types, false); + switch (try_match(fn_ty, false)) { + case llvm::Intrinsic::MatchIntrinsicTypes_Match: + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet: + return nullptr; + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg: + break; + } + + // try matching the var arg signature. + llvm::SmallVector var_types; + for (int i = 0; i <= arg_types.size(); ++i) { + if (i > 0) { + var_types.push_back(arg_types[i - 1]); + } + auto *ft = llvm::FunctionType::get(ret_type, var_types, true); + if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + } + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::UnaryIntrin *op) { + std::string func_name = op->name; + if (op->id == -1) { + if (func_name == "bitwise_and") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateAnd(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_or") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateOr(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_xor") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateXor(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_not") { + CHECK_GE(op->args.size(), 1U); + return b_->CreateNot(Visit(&op->args[0])); + } else if (func_name == "left_shift") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateShl(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "right_shift") { + CHECK_GE(op->args.size(), 2U); + if (op->args[0]->type().is_int()) { + return b_->CreateAShr(Visit(&op->args[0]), Visit(&op->args[1])); + } else { + return b_->CreateLShr(Visit(&op->args[0]), Visit(&op->args[1])); + } + } else if (func_name == "isnan") { + CHECK_GE(op->args.size(), 1U); + llvm::Value *v = Visit(&op->args[0]); + return b_->CreateFCmpUNO(v, v); + } + } + + llvm::Intrinsic::ID id = op->id; + int64_t num_signature = op->arg_nums; + std::vector arg_value; + std::vector arg_type; + for (size_t i = 0; i < op->args.size(); ++i) { + arg_value.push_back(Visit(&op->args[i])); + if (i < static_cast(num_signature)) { + arg_type.push_back(arg_value.back()->getType()); + } + } + CHECK(!op->args.empty()); + llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_); + llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type); + CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); + return b_->CreateCall(fn, arg_value); +} + llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) { auto to_type = op->GetOutputType(0); llvm::Function *callee{}; diff --git a/cinn/backends/llvm/codegen_llvm.h b/cinn/backends/llvm/codegen_llvm.h index 20a08c5439b05..00eb453a1964f 100644 --- a/cinn/backends/llvm/codegen_llvm.h +++ b/cinn/backends/llvm/codegen_llvm.h @@ -82,7 +82,7 @@ class SymbolTable { }; struct SymbolTableGuard { - SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); } + explicit SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); } ~SymbolTableGuard() { symbol_table_.PopScope(); } @@ -97,7 +97,8 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { public: explicit CodeGenLLVM(llvm::Module *m, llvm::IRBuilder<> *b, - const std::shared_ptr &symbol_table = nullptr); + const std::shared_ptr &symbol_table = nullptr, + const Target &target = common::DefaultHostTarget()); // Common llvm types // @{ @@ -146,6 +147,10 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { virtual llvm::Value *GetVar(const std::string &name, bool lazy = true); + llvm::Function *GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type *ret_type, + llvm::ArrayRef arg_types); + // Constants // @{ inline llvm::Value *llvm_int32_constant(int v) { return llvm::ConstantInt::get(ll_int32_ty(), v); } @@ -200,6 +205,7 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin { llvm::MDNode *md_tbaa_alias_set_{nullptr}; int naive_vec_alignment_{0}; + Target target_; }; namespace detail { Expr StridedRampBase(Expr e, int stride); diff --git a/cinn/backends/llvm/execution_engine_test.cc b/cinn/backends/llvm/execution_engine_test.cc index 75773288ae1e1..69d3ed8d84975 100644 --- a/cinn/backends/llvm/execution_engine_test.cc +++ b/cinn/backends/llvm/execution_engine_test.cc @@ -268,7 +268,7 @@ TEST(ExecutionEngine, call_extern) { {M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out"); ir::Tensor res = Compute( - {M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {add_out(i, j)}); }, "res"); + {M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("tanh", {add_out(i, j)}); }, "res"); auto stages = CreateStages({add_out, res}); @@ -297,7 +297,7 @@ TEST(ExecutionEngine, call_extern) { auto *cd = reinterpret_cast(cb->memory); for (int m = 0; m < kM; m++) { for (int n = 0; n < kN; n++) { - ASSERT_NEAR(cd[m * kN + n], cinn_cpu_tanh_fp32(ad[m * kN + n] + bd[m * kN + n]), 1e-5); + ASSERT_NEAR(cd[m * kN + n], tanh(ad[m * kN + n] + bd[m * kN + n]), 1e-5); } } } diff --git a/cinn/backends/llvm/llvm_intrin_rule.h b/cinn/backends/llvm/llvm_intrin_rule.h new file mode 100644 index 0000000000000..a0cbb84f8c823 --- /dev/null +++ b/cinn/backends/llvm/llvm_intrin_rule.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "cinn/cinn.h" +#include "cinn/ir/intrinsic_ops.h" +#include "cinn/ir/registry.h" +#include "cinn/lang/packed_func.h" + +namespace cinn { +namespace codegen { + +template +inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg = args[0]; + ir::Call *node = arg->as(); + CHECK(node); + CHECK_GE(node->read_args.size(), arg_nums); + if (add_float_suffix) { + CHECK_EQ(node->type(), Float(32)); + *rv = ir::intrinsics::UnaryIntrin::Make(node->name + "f", node->read_args, id, arg_nums, Float(32)); + } else { + *rv = ir::intrinsics::UnaryIntrin::Make(node->name, node->read_args, id, arg_nums, node->type()); + } +} + +void RegisterCpuIntrinRule() { +#define __(intrin_name__, id) \ + ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true).SetBody(MakeFloatIntrinOp); + __(exp, ::llvm::Intrinsic::exp) + __(exp2, ::llvm::Intrinsic::exp2) + __(sqrt, ::llvm::Intrinsic::sqrt) + __(log, ::llvm::Intrinsic::log) + __(log2, ::llvm::Intrinsic::log2) + __(log10, ::llvm::Intrinsic::log10) + __(floor, ::llvm::Intrinsic::floor) + __(ceil, ::llvm::Intrinsic::ceil) + __(round, ::llvm::Intrinsic::round) + __(trunc, ::llvm::Intrinsic::trunc) + __(cos, ::llvm::Intrinsic::cos) + __(sin, ::llvm::Intrinsic::sin) + __(fabs, ::llvm::Intrinsic::fabs) +#undef __ + +// set id -1 if not llvm intrinsics +#define RegisterBitwise(intrin_name__) \ + ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true).SetBody(MakeFloatIntrinOp<-1, 2, false>); + RegisterBitwise(bitwise_or) RegisterBitwise(bitwise_xor) RegisterBitwise(bitwise_and) RegisterBitwise(left_shift) + RegisterBitwise(right_shift) +#undef RegisterBitwise + + ir::Registry::Register("lower_cpu_intrinsic_bitwise_not", true) + .SetBody(MakeFloatIntrinOp<-1, 1, false>); + ir::Registry::Register("lower_cpu_intrinsic_isnan", true).SetBody(MakeFloatIntrinOp<-1, 1, false>); + + ir::Registry::Register("lower_cpu_intrinsic_exp10", true).SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + Expr ln10 = make_const(arg->type(), 2.302585093); + *rv = lang::Exp(arg * ln10); + }); + + ir::Registry::Register("lower_cpu_intrinsic_tan", true).SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = lang::Sin(arg) / lang::Cos(arg); + }); + + ir::Registry::Register("lower_cpu_intrinsic_tanh", true).SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + Expr zero = make_const(arg->type(), 0); + Expr one = make_const(arg->type(), 1); + Expr two = make_const(arg->type(), 2); + Expr neg_two = make_const(arg->type(), -2); + + Expr exp_neg2x = lang::Exp(neg_two * arg); + Expr exp_pos2x = lang::Exp(two * arg); + + Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + *rv = ir::Select::Make(arg >= zero, tanh_pos, tanh_neg); + }); + + ir::Registry::Register("lower_cpu_intrinsic_cosh", true).SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2); + }); + + ir::Registry::Register("lower_cpu_intrinsic_sinh", true).SetBody([](lang::Args args, lang::RetValue *rv) { + CHECK_GE(args.size(), 1U); + Expr arg0 = args[0]; + ir::Call *node = arg0->as(); + CHECK(node); + CHECK(!node->read_args.empty()); + Expr arg = node->read_args[0]; + *rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2); + }); +} +} // namespace codegen +} // namespace cinn diff --git a/cinn/common/ir_util.h b/cinn/common/ir_util.h index a8defc7723175..eb0e559a457eb 100644 --- a/cinn/common/ir_util.h +++ b/cinn/common/ir_util.h @@ -48,6 +48,7 @@ inline Expr make_one() { return make_const(static_cast(1)); } inline Expr make_bool(bool x) { return common::make_shared(Bool(), x); } +inline Expr make_bool(bool x, int lanes) { return common::make_shared(Bool(lanes), x); } // @} /** diff --git a/cinn/hlir/framework/op_test.cc b/cinn/hlir/framework/op_test.cc index 4840112112f4c..ee439fde35d06 100644 --- a/cinn/hlir/framework/op_test.cc +++ b/cinn/hlir/framework/op_test.cc @@ -46,7 +46,7 @@ TEST(Operator, GetAttrs) { LOG(INFO) << "Test Strategy Codegen:\n" << func; ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); - ASSERT_EQ(add->description, "Add two tensors"); + ASSERT_EQ(add->description, "elementwise_add function"); } } // namespace framework diff --git a/cinn/hlir/op/CMakeLists.txt b/cinn/hlir/op/CMakeLists.txt index 340a9922493ee..6845af8a2752d 100644 --- a/cinn/hlir/op/CMakeLists.txt +++ b/cinn/hlir/op/CMakeLists.txt @@ -2,6 +2,7 @@ set(srcs nn.cc broadcast.cc transform.cc + elementwise.cc ) foreach(cpp ${srcs}) @@ -10,5 +11,5 @@ foreach(cpp ${srcs}) CACHE INTERNAL "") endforeach() -cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core) -cc_test(test_op_nn SRCS op_nn_test.cc DEPS core) +cc_test(test_cinn_op_broadcast SRCS op_broadcast_test.cc DEPS core) +cc_test(test_cinn_op_nn SRCS op_nn_test.cc DEPS core) diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc old mode 100755 new mode 100644 index 968845185bc94..ca44257570269 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -17,16 +17,29 @@ using common::CINNValuePack; using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; +using namespace pe; + +#define StrategyForBinary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForBroadcast(attrs, inputs, out_type, output_shapes, target, #op_name__, pe__); \ + } -std::shared_ptr StrategyForElementwiseAdd(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of add compute is empty! Please check.\n"; +std::shared_ptr StrategyForBroadcast( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target, + const std::string &op_name, + ir::Tensor (*pe_func)(const ir::Tensor &A, const ir::Tensor &B, const std::string &output_name, const Expr &axis)) { + framework::CINNCompute binary_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; CINNValuePack a = args[0]; - CHECK_GE(a.size(), 2U) << "at least 2 input tensors for add compute\n"; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for " << op_name << " compute"; Expr A_expr = a[0]; Expr B_expr = a[1]; CHECK(A_expr.as_tensor()); @@ -42,62 +55,13 @@ std::shared_ptr StrategyForElementwiseAdd(const framework::NodeAttr LOG(ERROR) << "unsupported attr_store: " << iter.first << std::endl; } } - - auto out = pe::Add(A, B, UniqName("EleAdd_Out"), axis); - - auto stages = CreateStages({A, B, out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; - }); - - framework::CINNSchedule add_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of add schedule is empty! Please check.\n"; - CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 2UL); - if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); - } - *ret = arg_pack; - }); - - auto strategy = std::make_shared(); - strategy->AddImpl(add_compute, add_schedule, "strategy.elementwise_add.x86", 1); - - return strategy; -} - -std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of elementwise_mul compute is empty! Please check.\n"; - CINNValuePack a = args[0]; - CHECK_GE(a.size(), 2U) << "at least 2 input tensors for elementwise_mul compute\n"; - Expr A_expr = a[0]; - Expr B_expr = a[1]; - CHECK(A_expr.as_tensor()); - CHECK(B_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - ir::Tensor B = B_expr.as_tensor_ref(); - auto attr_store = attrs.attr_store; - auto iter = attr_store.find("axis"); - Expr axis; - if (iter != attr_store.end()) { - axis = Expr(std::get(iter->second)); - } - - auto out = pe::Multiply(A, B, UniqName("EleMul_Out"), axis); - + auto out = pe_func(A, B, UniqName(op_name + "_Out"), axis); auto stages = CreateStages({A, B, out}); *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); - framework::CINNSchedule mul_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of elementwise_mul schedule is empty! Please check.\n"; + framework::CINNSchedule binary_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); if (target.arch == Target::Arch::NVGPU) { @@ -110,19 +74,18 @@ std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr }); auto strategy = std::make_shared(); - strategy->AddImpl(mul_compute, mul_schedule, "strategy.elementwise_mul.x86", 1); - + strategy->AddImpl(binary_compute, binary_schedule, "strategy." + op_name + ".x86", 1); return strategy; } -std::vector InferShapeForElementwise(const std::vector &inputs_shape, - const framework::NodeAttr &attrs) { +std::vector InferShapeForBroadcast(const std::vector &inputs_shape, + const framework::NodeAttr &attrs) { CHECK_EQ(inputs_shape.size(), 2UL); std::vector res{inputs_shape[0]}; return res; } -std::vector InferDtypeForElementwise(const std::vector &inputs_type, const framework::NodeAttr &attrs) { +std::vector InferDtypeForBroadcast(const std::vector &inputs_type, const framework::NodeAttr &attrs) { CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; std::vector res{inputs_type[0]}; return res; @@ -195,28 +158,41 @@ std::vector InferDtypeForScale(const std::vector &inputs_type, const return res; } +StrategyForBinary(elementwise_add, Add); +StrategyForBinary(elementwise_mul, Multiply); + +StrategyForBinary(bitwise_or, BitwiseOr); +StrategyForBinary(bitwise_xor, BitwiseXor); +StrategyForBinary(bitwise_and, BitwiseAnd); +StrategyForBinary(left_shift, LeftShift); +StrategyForBinary(right_shift, RightShift); + +#undef StrategyForBinary + } // namespace op } // namespace hlir } // namespace cinn CINN_REGISTER_HELPER(broadcast_ops) { - CINN_REGISTER_OP(elementwise_add) - .describe("Add two tensors") - .set_num_inputs(2) - .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForElementwiseAdd) - .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise)) - .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise)) +#define CINN_REGISTER_BINARY(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForBroadcast)) \ + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForBroadcast)) \ .set_support_level(4); - CINN_REGISTER_OP(elementwise_mul) - .describe("multiply two tensors") - .set_num_inputs(2) - .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForElementwiseMul) - .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise)) - .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise)) - .set_support_level(4); + CINN_REGISTER_BINARY(elementwise_add, Add); + CINN_REGISTER_BINARY(elementwise_mul, Multiply); + + CINN_REGISTER_BINARY(bitwise_or, BitwiseOr); + CINN_REGISTER_BINARY(bitwise_xor, BitwiseXor); + CINN_REGISTER_BINARY(bitwise_and, BitwiseAnd); + CINN_REGISTER_BINARY(left_shift, LeftShift); + CINN_REGISTER_BINARY(right_shift, RightShift); +#undef CINN_REGISTER_BINARY CINN_REGISTER_OP(scale) .describe("Putting scale and bias to the input Tensor") diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc new file mode 100644 index 0000000000000..e4544224fded4 --- /dev/null +++ b/cinn/hlir/op/elementwise.cc @@ -0,0 +1,162 @@ +#include "cinn/hlir/pe/elementwise.h" + +#include + +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/ir/ir_operators.h" + +namespace cinn { +namespace hlir { +namespace op { +using common::_CINNValuePack_; +using common::CINNValue; +using common::CINNValuePack; +using framework::OpStrategy; +using framework::shape_t; +using framework::StrategyFunction; +using namespace pe; +using PeFunc = std::function; + +#define StrategyForUnary(op_name__, pe__) \ + std::shared_ptr StrategyFor##pe__(const framework::NodeAttr &attrs, \ + const std::vector &inputs, \ + const std::vector &out_type, \ + const std::vector> &output_shapes, \ + const Target &target) { \ + return StrategyForElementwise(attrs, inputs, out_type, output_shapes, target, #op_name__, pe__); \ + } + +std::shared_ptr StrategyForElementwise(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target, + const std::string &op_name, + const PeFunc &pe_func) { + framework::CINNCompute unary_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; + CINNValuePack a = args[0]; + CHECK_EQ(a.size(), 1U) << "1 input tensor for " << op_name << " compute"; + Expr A_expr = a[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + auto out = pe_func(A, UniqName(op_name + "_Out")); + auto stages = CreateStages({A, out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); + + framework::CINNSchedule unary_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! Please check."; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL); + if (target.arch == Target::Arch::NVGPU) { + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); + pe::CudaSplitSchedule(stages[Out.as_tensor_ref()], output_shapes.back()); + if (Out.as_tensor()->shape.size() > 1) { + stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x"); + stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x"); + } + } + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(unary_compute, unary_schedule, "strategy." + op_name + ".x86", 1); + + return strategy; +} + +std::vector InferShapeForElementwise(const std::vector &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK_EQ(inputs_shape.size(), 1UL); + std::vector res{inputs_shape[0]}; + return res; +} + +std::vector InferDtypeForElementwise(const std::vector &inputs_type, const framework::NodeAttr &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + std::vector res{inputs_type[0]}; + return res; +} + +StrategyForUnary(exp, Exp); +StrategyForUnary(erf, Erf); +StrategyForUnary(sqrt, Sqrt); +StrategyForUnary(log, Log); +StrategyForUnary(log2, Log2); +StrategyForUnary(log10, Log10); +StrategyForUnary(floor, Floor); +StrategyForUnary(ceil, Ceil); +StrategyForUnary(round, Round); +StrategyForUnary(trunc, Trunc); +StrategyForUnary(cos, Cos); +StrategyForUnary(cosh, Cosh); +StrategyForUnary(tan, Tan); +StrategyForUnary(sin, Sin); +StrategyForUnary(sinh, Sinh); +StrategyForUnary(acos, Acos); +StrategyForUnary(acosh, Acosh); +StrategyForUnary(asin, Asin); +StrategyForUnary(asinh, Asinh); +StrategyForUnary(atan, Atan); +StrategyForUnary(atanh, Atanh); +StrategyForUnary(tanh, Tanh); + +StrategyForUnary(isnan, IsNan); +StrategyForUnary(isfinite, IsFinite); +StrategyForUnary(isinf, IsInf); +StrategyForUnary(bitwise_not, BitwiseNot); + +#undef StrategyForUnary + +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(elementwise_ops) { +#define CINN_REGISTER_UNARY(op__, op_stragegy__) \ + CINN_REGISTER_OP(op__) \ + .describe(#op__ " function") \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr("CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise)) \ + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise)) \ + .set_support_level(4); + + CINN_REGISTER_UNARY(exp, Exp); + CINN_REGISTER_UNARY(erf, Erf); + CINN_REGISTER_UNARY(sqrt, Sqrt); + CINN_REGISTER_UNARY(log, Log); + CINN_REGISTER_UNARY(log2, Log2); + CINN_REGISTER_UNARY(log10, Log10); + CINN_REGISTER_UNARY(floor, Floor); + CINN_REGISTER_UNARY(ceil, Ceil); + CINN_REGISTER_UNARY(round, Round); + CINN_REGISTER_UNARY(trunc, Trunc); + CINN_REGISTER_UNARY(cos, Cos); + CINN_REGISTER_UNARY(cosh, Cosh); + CINN_REGISTER_UNARY(tan, Tan); + CINN_REGISTER_UNARY(sin, Sin); + CINN_REGISTER_UNARY(sinh, Sinh); + CINN_REGISTER_UNARY(acos, Acos); + CINN_REGISTER_UNARY(acosh, Acosh); + CINN_REGISTER_UNARY(asin, Asin); + CINN_REGISTER_UNARY(asinh, Asinh); + CINN_REGISTER_UNARY(atan, Atan); + CINN_REGISTER_UNARY(atanh, Atanh); + CINN_REGISTER_UNARY(tanh, Tanh); + + CINN_REGISTER_UNARY(isnan, IsNan) + CINN_REGISTER_UNARY(isfinite, IsFinite) + CINN_REGISTER_UNARY(isinf, IsInf) + CINN_REGISTER_UNARY(bitwise_not, BitwiseNot) +#undef CINN_REGISTER_UNARY + + return true; +} diff --git a/cinn/hlir/op/op_broadcast_test.cc b/cinn/hlir/op/op_broadcast_test.cc index 0d7cfd5c613f3..0b48a7861941e 100644 --- a/cinn/hlir/op/op_broadcast_test.cc +++ b/cinn/hlir/op/op_broadcast_test.cc @@ -44,7 +44,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { LOG(INFO) << "Test Strategy Codegen:\n" << func; ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); - ASSERT_EQ(add->description, "Add two tensors"); + ASSERT_EQ(add->description, "elementwise_add function"); } TEST(Operator, Operator_ElementWise_Add_Test1) { @@ -77,7 +77,7 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { std::cout << func; ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); - ASSERT_EQ(add->description, "Add two tensors"); + ASSERT_EQ(add->description, "elementwise_add function"); } } // namespace framework diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 28387b053df90..aaf2a66cbdf5f 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -4,4 +4,5 @@ CINN_USE_REGISTER(nn_ops) CINN_USE_REGISTER(broadcast_ops) +CINN_USE_REGISTER(elementwise_ops) CINN_USE_REGISTER(transform_ops) diff --git a/cinn/hlir/pe/CMakeLists.txt b/cinn/hlir/pe/CMakeLists.txt index 59ac732b2be86..1c602be9ffc19 100644 --- a/cinn/hlir/pe/CMakeLists.txt +++ b/cinn/hlir/pe/CMakeLists.txt @@ -14,6 +14,6 @@ foreach(cpp ${srcs}) CACHE INTERNAL "") endforeach() -cc_test(test_pe_elementwise SRCS pe_elementwise_test.cc DEPS core) -cc_test(test_pe_broadcast SRCS pe_broadcast_test.cc DEPS core) -cc_test(test_pe_transform SRCS pe_transform_test.cc DEPS core) +cc_test(test_cinn_pe_elementwise SRCS pe_elementwise_test.cc DEPS core) +cc_test(test_cinn_pe_broadcast SRCS pe_broadcast_test.cc DEPS core) +cc_test(test_cinn_pe_transform SRCS pe_transform_test.cc DEPS core) diff --git a/cinn/hlir/pe/elementwise.cc b/cinn/hlir/pe/elementwise.cc index ba392d91f9b52..f0e53d643921c 100644 --- a/cinn/hlir/pe/elementwise.cc +++ b/cinn/hlir/pe/elementwise.cc @@ -41,10 +41,10 @@ HLIR_IMP_UNARY_PE(Asin); HLIR_IMP_UNARY_PE(Asinh); HLIR_IMP_UNARY_PE(Atan); HLIR_IMP_UNARY_PE(Atanh); -HLIR_IMP_UNARY_PE(Isnan); +HLIR_IMP_UNARY_PE(IsNan); HLIR_IMP_UNARY_PE(Tanh); -HLIR_IMP_UNARY_PE(Isfinite); -HLIR_IMP_UNARY_PE(Isinf); +HLIR_IMP_UNARY_PE(IsFinite); +HLIR_IMP_UNARY_PE(IsInf); HLIR_IMP_UNARY_PE(Negative); HLIR_IMP_UNARY_PE(Identity); diff --git a/cinn/hlir/pe/elementwise.h b/cinn/hlir/pe/elementwise.h index 95cdce86433d4..769ea1f871953 100644 --- a/cinn/hlir/pe/elementwise.h +++ b/cinn/hlir/pe/elementwise.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "cinn/ir/ir.h" namespace cinn { @@ -37,10 +39,10 @@ HLIR_DCL_UNARY_PE(Asin); HLIR_DCL_UNARY_PE(Asinh); HLIR_DCL_UNARY_PE(Atan); HLIR_DCL_UNARY_PE(Atanh); -HLIR_DCL_UNARY_PE(Isnan); +HLIR_DCL_UNARY_PE(IsNan); HLIR_DCL_UNARY_PE(Tanh); -HLIR_DCL_UNARY_PE(Isfinite); -HLIR_DCL_UNARY_PE(Isinf); +HLIR_DCL_UNARY_PE(IsFinite); +HLIR_DCL_UNARY_PE(IsInf); HLIR_DCL_UNARY_PE(Negative); HLIR_DCL_UNARY_PE(Identity); diff --git a/cinn/hlir/pe/pe_elementwise_test.cc b/cinn/hlir/pe/pe_elementwise_test.cc index dbc4e79a35b43..80d68972df7bd 100644 --- a/cinn/hlir/pe/pe_elementwise_test.cc +++ b/cinn/hlir/pe/pe_elementwise_test.cc @@ -14,10 +14,10 @@ namespace cinn { namespace hlir { namespace pe { -template +template void TestElementwisePE(const std::string &fn_name, const FuncOp &func_op, - float (*fn_runtime)(float), + const FuncRuntime &fn_runtime, int set_value = 0) { Expr M(100), N(32); @@ -60,13 +60,11 @@ void TestElementwisePE(const std::string &fn_name, } } -#define TEST_ELEMENTWISE_PE_FP32(test_name__, PE__) \ - TEST(elementwise_pe, test_name__) { \ - TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, cinn_cpu_##test_name__##_fp32); \ - } -#define TEST_ELEMENTWISE_PE_FP32_SET(test_name__, PE__, value__) \ - TEST(elementwise_pe, test_name__) { \ - TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, cinn_cpu_##test_name__##_fp32, value__); \ +#define TEST_ELEMENTWISE_PE_FP32(test_name__, PE__) \ + TEST(elementwise_pe, test_name__) { TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__); } +#define TEST_ELEMENTWISE_PE_FP32_SET(test_name__, PE__, value__) \ + TEST(elementwise_pe, test_name__) { \ + TestElementwisePE("PE_Elementwise_" #test_name__ "_fp32", PE__, test_name__, value__); \ } TEST_ELEMENTWISE_PE_FP32(exp, Exp) @@ -90,10 +88,10 @@ TEST_ELEMENTWISE_PE_FP32(asin, Asin) TEST_ELEMENTWISE_PE_FP32(asinh, Asinh) TEST_ELEMENTWISE_PE_FP32(atan, Atan) TEST_ELEMENTWISE_PE_FP32(atanh, Atanh) -TEST_ELEMENTWISE_PE_FP32(isnan, Isnan) +// TEST_ELEMENTWISE_PE_FP32(isnan, IsNan) TEST_ELEMENTWISE_PE_FP32(tanh, Tanh) -TEST_ELEMENTWISE_PE_FP32(isfinite, Isfinite) -TEST_ELEMENTWISE_PE_FP32(isinf, Isinf) +// TEST_ELEMENTWISE_PE_FP32(isfinite, IsFinite) +// TEST_ELEMENTWISE_PE_FP32(isinf, IsInf) } // namespace pe } // namespace hlir diff --git a/cinn/ir/intrinsic_ops.cc b/cinn/ir/intrinsic_ops.cc index f5bdd67f41438..472eca92e7aa0 100644 --- a/cinn/ir/intrinsic_ops.cc +++ b/cinn/ir/intrinsic_ops.cc @@ -97,4 +97,17 @@ Expr intrinsics::ArgsConstruct::Make(Var var, llvm::ArrayRef args) { return Expr(n); } +Expr intrinsics::UnaryIntrin::Make( + const std::string& name, llvm::ArrayRef args, llvm::Intrinsic::ID id, int64_t arg_nums, const Type& type) { + auto* n = new UnaryIntrin; + n->name = name; + n->args.assign(args.begin(), args.end()); + n->id = id; + n->arg_nums = arg_nums; + CHECK(!type.is_unk()); + n->type_ = type; + + return Expr(n); +} + } // namespace cinn::ir diff --git a/cinn/ir/intrinsic_ops.h b/cinn/ir/intrinsic_ops.h index ee35d84e31ec2..677f7940d45d4 100644 --- a/cinn/ir/intrinsic_ops.h +++ b/cinn/ir/intrinsic_ops.h @@ -2,8 +2,11 @@ #include #include +#include #include +#include + #include "cinn/common/type.h" #include "cinn/ir/ir.h" @@ -20,17 +23,16 @@ namespace cinn::ir { macro__(BufferCreate) \ macro__(GetAddr) \ macro__(ArgsConstruct) \ + macro__(UnaryIntrin) // clang-format on - enum class IntrinsicKind { - // All the intrinsics should registered here. -#define __(x__) k ## x__, +// All the intrinsics should registered here. +#define __(x__) k##x__, INTRINSIC_KIND_FOR_EACH(__) #undef __ }; - class IntrinsicOp : public IrNode { public: IntrinsicOp(IntrinsicKind kind, llvm::ArrayRef input_types, llvm::ArrayRef output_types) @@ -52,7 +54,7 @@ class IntrinsicOp : public IrNode { void Verify(llvm::ArrayRef inputs, llvm::ArrayRef outputs) const; void Verify(llvm::ArrayRef inputs) const; - void Verify() const override { } + void Verify() const override {} const char* type_info() const override; @@ -111,8 +113,7 @@ struct BufferGetDataConstHandle : public IntrinsicOp { */ struct PodValueToX : public IntrinsicOp { // signature: (cinn_pod_value_t*) -> (X), X is some pod type. - PodValueToX() - : IntrinsicOp(IntrinsicKind::kPodValueToX, {type_of()}, {}) {} + PodValueToX() : IntrinsicOp(IntrinsicKind::kPodValueToX, {type_of()}, {}) {} static Expr Make(Expr pod_value_ptr, const Type& type); @@ -126,7 +127,7 @@ struct PodValueToX : public IntrinsicOp { */ struct BufferCreate : public IntrinsicOp { // signature: (cinn_buffer_t*) -> void - BufferCreate(): IntrinsicOp(IntrinsicKind::kBufferCreate, {type_of()}, {}) {} + BufferCreate() : IntrinsicOp(IntrinsicKind::kBufferCreate, {type_of()}, {}) {} static Expr Make(Expr buffer); @@ -140,7 +141,7 @@ struct BufferCreate : public IntrinsicOp { */ struct GetAddr : public IntrinsicOp { // signature: (X) -> (X*) - GetAddr(): IntrinsicOp(IntrinsicKind::kGetAddr, {}, {}) {} + GetAddr() : IntrinsicOp(IntrinsicKind::kGetAddr, {}, {}) {} static Expr Make(Expr data); @@ -163,6 +164,22 @@ struct ArgsConstruct : public IntrinsicOp { llvm::SmallVector args; }; +/** + * The operation of unary computation + */ +struct UnaryIntrin : public IntrinsicOp { + UnaryIntrin() : IntrinsicOp(IntrinsicKind::kUnaryIntrin, {}, {}) {} + + static Expr Make( + const std::string& name, llvm::ArrayRef args, llvm::Intrinsic::ID id, int64_t arg_nums, const Type& type); + + static bool classof(const IntrinsicOp* s) { return s->getKind() == IntrinsicKind::kUnaryIntrin; } + + std::string name; + llvm::SmallVector args; + llvm::Intrinsic::ID id; + int64_t arg_nums; +}; } // namespace intrinsics diff --git a/cinn/ir/ir_printer.cc b/cinn/ir/ir_printer.cc index ed02a6c890e67..b37c93520b541 100644 --- a/cinn/ir/ir_printer.cc +++ b/cinn/ir/ir_printer.cc @@ -410,6 +410,20 @@ void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) { os() << ")"; } +void IrPrinter::Visit(const intrinsics::UnaryIntrin *x) { + os_ << runtime::intrisic::unary_intrin_repr << "_"; + os_ << x->name << "("; + if (!x->args.empty()) { + for (int i = 0; i < x->args.size() - 1; i++) { + Print(x->args[i]); + os_ << ", "; + } + Print(x->args.back()); + } + + os_ << ")"; +} + std::ostream &operator<<(std::ostream &os, Expr a) { std::stringstream ss; IrPrinter printer(ss); diff --git a/cinn/ir/module.cc b/cinn/ir/module.cc index 04bc6175d461c..f196147aa464b 100644 --- a/cinn/ir/module.cc +++ b/cinn/ir/module.cc @@ -28,7 +28,7 @@ Module Module::Builder::Build() { auto res = ir::Module(module_.get()); - return optim::Optimize(res); + return optim::Optimize(res, module_->target); } ir::_Module_ *Module::self() { return p_->as(); } diff --git a/cinn/lang/builtin.cc b/cinn/lang/builtin.cc index bcc8bab3f6ba0..9ba2123cccc39 100644 --- a/cinn/lang/builtin.cc +++ b/cinn/lang/builtin.cc @@ -1,5 +1,9 @@ #include "cinn/lang/builtin.h" +#include +#include +#include + #include "cinn/cinn.h" #include "cinn/common/ir_util.h" #include "cinn/ir/ir.h" @@ -51,10 +55,7 @@ EXTERN_CALL_IMP(Asin, asin); EXTERN_CALL_IMP(Asinh, asinh); EXTERN_CALL_IMP(Atan, atan); EXTERN_CALL_IMP(Atanh, atanh); -EXTERN_CALL_IMP(Isnan, isnan); EXTERN_CALL_IMP(Tanh, tanh); -EXTERN_CALL_IMP(Isfinite, isfinite); -EXTERN_CALL_IMP(Isinf, isinf); Expr min_value(const Type& type) { CHECK_EQ(type.lanes(), 1); @@ -91,5 +92,73 @@ Expr max_value(const Type& type) { return Expr(); } +Expr Abs(Expr e) { + Type type = e->type(); + Type bool_type = Bool(type.lanes()); + if (type.is_uint()) { + return e; + } else if (type.is_int()) { + auto node = e.As(); + if (node) { + return make_const(type, std::abs(node->value)); + } + return ir::Select::Make(e > make_const(e->type(), 0), e, -e); + } else if (type.is_float()) { + auto node = e.As(); + if (node) { + return make_const(type, std::fabs(node->value)); + } + return CallExtern("fabs", {e}); + } +} + +Expr IsNan(Expr e) { + Type type = e->type(); + // Type bool_type = Bool(type.lanes()); + if (type.is_int() || type.is_uint()) { + return common::make_bool(false, type.lanes()); + } else if (type.is_float()) { + auto* node = e.As(); + if (node) { + return common::make_bool(std::isnan(node->value), type.lanes()); + } + Expr arg = e; + if (type.bits() == 16) { + arg = ir::Cast::Make(Float(32), std::move(e)); + } + return CallExtern("isnan", {arg}); + } else { + LOG(FATAL) << type << "is not supported for isnan op."; + return e; + } +} + +Expr Infinity(const Type& type) { + CHECK_EQ(type.lanes(), 1U); + if (type.is_float()) { + if (type.bits() == 64) { + return make_const(type, std::numeric_limits::infinity()); + } else if (type.bits() == 32 || type.bits() == 16) { + return make_const(type, std::numeric_limits::infinity()); + } + } + LOG(FATAL) << "Cannot decide infinity for type " << type; + return Expr(); +} + +Expr IsInf(Expr e) { + Type type = e->type(); + if (type.is_int() || type.is_uint()) { + return common::make_bool(false, type.lanes()); + } else if (type.is_float()) { + return common::make_bool(is_zero(Abs(e) - Infinity(type)), type.lanes()) && !IsNan(e); + } else { + LOG(FATAL) << type << "is not supported for isinf op."; + return e; + } +} + +Expr IsFinite(Expr e) { return !IsInf(e) && !IsNan(e); } + } // namespace lang } // namespace cinn diff --git a/cinn/lang/builtin.h b/cinn/lang/builtin.h index 14a1ba5c43107..adf5d09f0de05 100644 --- a/cinn/lang/builtin.h +++ b/cinn/lang/builtin.h @@ -1,4 +1,7 @@ #pragma once + +#include + #include "cinn/common/ir_util.h" #include "cinn/ir/ir.h" #include "cinn/ir/ir_operators.h" @@ -34,10 +37,7 @@ EXTERN_CALL_DCL(Asin); EXTERN_CALL_DCL(Asinh); EXTERN_CALL_DCL(Atan); EXTERN_CALL_DCL(Atanh); -EXTERN_CALL_DCL(Isnan); EXTERN_CALL_DCL(Tanh); -EXTERN_CALL_DCL(Isfinite); -EXTERN_CALL_DCL(Isinf); inline Expr Sigmoid(Expr e) { auto one = common::make_const(e->type(), 1); @@ -53,7 +53,7 @@ inline Expr Sign(Expr e) { return ret2; } -inline Expr Abs(Expr e) { return ir::Select::Make(e > make_const(e->type(), 0), e, -e); } +Expr Abs(Expr e); inline Expr Rsqrt(Expr e) { auto one = make_const(e->type(), 1); @@ -64,6 +64,11 @@ inline Expr Negative(Expr e) { return -e; } inline Expr Identity(Expr e) { return e; } inline Expr LogicalNot(Expr e) { return !e; } inline Expr BitwiseNot(Expr e) { return ~e; } +inline Expr BitwiseAnd(Expr a, Expr b) { return a & b; } +inline Expr BitwiseOr(Expr a, Expr b) { return a | b; } +inline Expr BitwiseXor(Expr a, Expr b) { return a ^ b; } +inline Expr LeftShift(Expr a, Expr b) { return a << b; } +inline Expr RightShift(Expr a, Expr b) { return a >> b; } template inline Expr Relu(Expr e, T threshold = static_cast(0)) { @@ -115,5 +120,13 @@ inline Expr ReduceMin(Expr e, const std::vector& reduce_axis, Expr initial return ir::Reduce::Make(ir::Reduce::kMin, initial, e, reduce_axis); } +Expr IsNan(Expr e); + +Expr Infinity(const Type& type); + +Expr IsInf(Expr e); + +Expr IsFinite(Expr e); + } // namespace lang } // namespace cinn diff --git a/cinn/optim/CMakeLists.txt b/cinn/optim/CMakeLists.txt index c82979587199b..43418372d357d 100644 --- a/cinn/optim/CMakeLists.txt +++ b/cinn/optim/CMakeLists.txt @@ -21,6 +21,7 @@ set(srcs remove_nested_block.cc replace_call_with_expr.cc ir_copy.cc cast_simplify.cc compare_simplify.cc if_simplify.cc + lower_intrin.cc ) if (WITH_CUDA) list(APPEND srcs transform_gpu_forloop.cc) diff --git a/cinn/optim/ir_copy.cc b/cinn/optim/ir_copy.cc index 3e33bc8ff7e5d..08c6b433e9bec 100644 --- a/cinn/optim/ir_copy.cc +++ b/cinn/optim/ir_copy.cc @@ -392,6 +392,9 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::ArgsConstruct* op) { } return intrinsics::ArgsConstruct::Make(op->var, args); } +Expr IRCopyVisitor::Visit(const ir::intrinsics::UnaryIntrin* op) { + return intrinsics::UnaryIntrin::Make(op->name, op->args, op->id, op->arg_nums, op->type()); +} Expr IRCopy(Expr x) { IRCopyVisitor visitor; diff --git a/cinn/optim/lower_intrin.cc b/cinn/optim/lower_intrin.cc new file mode 100644 index 0000000000000..23ccff7fc8836 --- /dev/null +++ b/cinn/optim/lower_intrin.cc @@ -0,0 +1,54 @@ +#include "cinn/optim/lower_intrin.h" + +#include + +#include "cinn/backends/llvm/llvm_intrin_rule.h" +#include "cinn/cinn.h" +#include "cinn/ir/intrinsic_ops.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/registry.h" + +namespace cinn { +namespace optim { + +void LowerIntrin(Expr *e, Target target) { + if (target.arch == Target::Arch::X86) { + codegen::RegisterCpuIntrinRule(); + } + struct Mutator : ir::IRMutator { + Target target; + + explicit Mutator(Target target) : target(target) {} + + void operator()(Expr *e) { ir::IRMutator<>::Visit(e, e); } + + void Visit(const ir::Call *op, Expr *expr) override { + auto *node = expr->As(); + CHECK(node); + + if (target.arch == Target::Arch::X86) { + LowerCpuIntrisicOp(node, expr); + } + } + + void LowerCpuIntrisicOp(ir::Call *node, Expr *expr) { + if (kIntrinsicCalls.count(node->name)) { + CHECK(!node->name.empty()); + auto *func_ptr = ir::Registry::Get("lower_cpu_intrinsic_" + node->name); + CHECK(func_ptr) << "find no rule to lower cpu intrinsic for " + << "lower_cpu_intrinsic_" + node->name; + Expr ret = (*func_ptr)(Expr(node)); + if (!ret.same_as(*expr)) { + ir::IRMutator<>::Visit(&ret, &ret); + } + *expr = ret; + } + } + }; + + Mutator m(target); + m(e); +} + +} // namespace optim +} // namespace cinn diff --git a/cinn/optim/lower_intrin.h b/cinn/optim/lower_intrin.h new file mode 100644 index 0000000000000..9e32a870949a8 --- /dev/null +++ b/cinn/optim/lower_intrin.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +static const std::set kIntrinsicCalls{ + {"exp", "exp2", "sqrt", "log", "log2", "log10", "floor", + "ceil", "round", "trunc", "cos", "cosh", "tan", "tanh", + "sin", "sinh", "fabs", "isnan", "isfinite", "isinf", "left_shift", + "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}}; + +/** + * Map the Call nodes to llvm intrinsic. + * + * This will rename the external call with the function in different backends. + * + * Notes: only support cpu currently. + */ +void LowerIntrin(Expr *e, Target target); + +} // namespace optim +} // namespace cinn diff --git a/cinn/optim/map_extern_call.cc b/cinn/optim/map_extern_call.cc index 0f3a76af118c0..6179df8fba35e 100644 --- a/cinn/optim/map_extern_call.cc +++ b/cinn/optim/map_extern_call.cc @@ -27,19 +27,16 @@ void MapExternCall(Expr *e, Target target) { } void DealWithCpuIntrisics(ir::Call *node, Expr *expr) { - if (kExternFp32Calls.count(node->name)) { + if (kExternFp32CallsCPU.count(node->name)) { CHECK_GE(node->read_args.size(), 1UL); CHECK_EQ(node->read_args.front().type(), Float(32)); - *expr = lang::CallExtern("cinn_cpu_" + node->name + "_fp32", node->read_args); - } else if (kExternInt64Calls.count(node->name)) { - CHECK_GE(node->read_args.size(), 1UL); - CHECK_EQ(node->read_args.front().type(), Int(64)); - *expr = lang::CallExtern("cinn_cpu_" + node->name + "_int64", node->read_args); + auto out_type = node->type(); + *expr = lang::CallExtern(node->name + "f", node->read_args); } } void DealWithNvGpuIntrisics(ir::Call *node, Expr *expr) { - if (kExternFp32Calls.count(node->name)) { + if (kExternFp32CallsGPU.count(node->name)) { CHECK_GE(node->read_args.size(), 1UL); CHECK_EQ(node->read_args.front().type(), Float(32)); *expr = lang::CallExtern("cinn_nvgpu_" + node->name + "_fp32", node->read_args); diff --git a/cinn/optim/map_extern_call.h b/cinn/optim/map_extern_call.h index e6dd5a1bd2a52..172fe35f09ed7 100644 --- a/cinn/optim/map_extern_call.h +++ b/cinn/optim/map_extern_call.h @@ -1,11 +1,14 @@ #pragma once +#include +#include + #include "cinn/ir/ir.h" namespace cinn { namespace optim { -static const std::set kExternFp32Calls{ +static const std::set kExternFp32CallsGPU{ {"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", "cos", "cosh", "tan", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", @@ -13,8 +16,7 @@ static const std::set kExternFp32Calls{ "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not", "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}}; -static const std::set kExternInt64Calls = { - "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"}; +static const std::set kExternFp32CallsCPU = {"erf", "acos", "acosh", "asin", "asinh", "atan", "atanh"}; /** * Map the Call nodes to external function call. diff --git a/cinn/optim/optimize.cc b/cinn/optim/optimize.cc index 6914e73ab0ef6..8bd781f6ac3c5 100644 --- a/cinn/optim/optimize.cc +++ b/cinn/optim/optimize.cc @@ -13,6 +13,7 @@ #include "cinn/optim/ir_copy.h" #include "cinn/optim/ir_simplify.h" #include "cinn/optim/lower_function_call_bind_vars.h" +#include "cinn/optim/lower_intrin.h" #include "cinn/optim/map_extern_call.h" #include "cinn/optim/remove_nested_block.h" #include "cinn/optim/replace_const_param_to_integer.h" @@ -60,11 +61,12 @@ Expr Optimize(Expr e, Target target, bool runtime_debug_info) { return copied; } -ir::Module Optimize(const ir::Module& module) { +ir::Module Optimize(const ir::Module& module, const Target& target) { auto copied = IRCopy(Expr(module)); LowerFunctionCallBindVars(&copied); CallArgListToPodValue(&copied); + LowerIntrin(&copied, target); return copied.as_module_ref(); } diff --git a/cinn/optim/optimize.h b/cinn/optim/optimize.h index d075c3869e63f..1fd1c603f5958 100644 --- a/cinn/optim/optimize.h +++ b/cinn/optim/optimize.h @@ -16,7 +16,7 @@ Expr Optimize(Expr e, Target target, bool runtime_debug_info = false); /** * Optimize a Module. */ -ir::Module Optimize(const ir::Module& module); +ir::Module Optimize(const ir::Module& module, const Target& target); } // namespace optim } // namespace cinn diff --git a/cinn/pybind/pe.cc b/cinn/pybind/pe.cc index bd6b84f76a029..6b7baa006aa3e 100644 --- a/cinn/pybind/pe.cc +++ b/cinn/pybind/pe.cc @@ -40,10 +40,10 @@ void BindPE(py::module* m) { BIND_UNARY(asinh, Asinh); BIND_UNARY(atan, Atan); BIND_UNARY(atanh, Atanh); - BIND_UNARY(isnan, Isnan); + BIND_UNARY(isnan, IsNan); BIND_UNARY(tanh, Tanh); - BIND_UNARY(isfinite, Isfinite); - BIND_UNARY(isinf, Isinf); + BIND_UNARY(isfinite, IsFinite); + BIND_UNARY(isinf, IsInf); BIND_UNARY(negative, Negative); BIND_UNARY(identity, Identity); diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index c095e12578a41..5e5548d93ded3 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -9,56 +9,13 @@ extern "C" { -using namespace std; -#define CINN_IMP_CPU_FUNC_FP32(name__) \ - float cinn_cpu_##name__##_fp32(float a) { return name__(a); } - -#define CINN_IMP_CPU_FUNC_INT_BINARY(name__, rule__) \ - int cinn_cpu_##name__##_int32(int a, int b) { return a rule__ b; } - -#define CINN_IMP_CPU_FUNC_INT_UNARY(name__, rule__) \ - int cinn_cpu_##name__##_int32(int a) { return rule__(a); } - -CINN_IMP_CPU_FUNC_FP32(exp); -CINN_IMP_CPU_FUNC_FP32(erf); -CINN_IMP_CPU_FUNC_FP32(sqrt); -CINN_IMP_CPU_FUNC_FP32(log); -CINN_IMP_CPU_FUNC_FP32(log2); -CINN_IMP_CPU_FUNC_FP32(log10); -CINN_IMP_CPU_FUNC_FP32(floor); -CINN_IMP_CPU_FUNC_FP32(ceil); -CINN_IMP_CPU_FUNC_FP32(round); -CINN_IMP_CPU_FUNC_FP32(trunc); -CINN_IMP_CPU_FUNC_FP32(cos); -CINN_IMP_CPU_FUNC_FP32(cosh); -CINN_IMP_CPU_FUNC_FP32(tan); -CINN_IMP_CPU_FUNC_FP32(sin); -CINN_IMP_CPU_FUNC_FP32(sinh); -CINN_IMP_CPU_FUNC_FP32(acos); -CINN_IMP_CPU_FUNC_FP32(acosh); -CINN_IMP_CPU_FUNC_FP32(asin); -CINN_IMP_CPU_FUNC_FP32(asinh); -CINN_IMP_CPU_FUNC_FP32(atan); -CINN_IMP_CPU_FUNC_FP32(atanh); -CINN_IMP_CPU_FUNC_FP32(isnan); -CINN_IMP_CPU_FUNC_FP32(tanh); -CINN_IMP_CPU_FUNC_FP32(isfinite); -CINN_IMP_CPU_FUNC_FP32(isinf); - -CINN_IMP_CPU_FUNC_INT_BINARY(left_shift, <<); -CINN_IMP_CPU_FUNC_INT_BINARY(right_shift, >>); -CINN_IMP_CPU_FUNC_INT_BINARY(bitwise_or, |); -CINN_IMP_CPU_FUNC_INT_BINARY(bitwise_and, &); -CINN_IMP_CPU_FUNC_INT_BINARY(bitwise_xor, ^); -CINN_IMP_CPU_FUNC_INT_UNARY(bitwise_not, !); - void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out) { CINN_CHECK_EQ(x->num_elements(), out->num_elements()); int xn = x->num_elements(); auto* x_data = (float*)(x->memory); auto* out_data = (float*)(out->memory); for (int i = 0; i < x->num_elements(); i++) { - out_data[i] = cinn_cpu_tanh_fp32(x_data[i]); + out_data[i] = tanhf(x_data[i]); } } } @@ -67,47 +24,18 @@ CINN_REGISTER_HELPER(host_intrinsics) { auto host_target = cinn::common::DefaultHostTarget(); using cinn::backends::FunctionProto; -#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(func__) \ - REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_cpu_##func__##_fp32, host_target, float, float); - -#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT(func__) \ - REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_cpu_##func__##_int32, host_target, int, int); - -#define REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(func__) \ - REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_cpu_##func__##_int32, host_target, int, int, int); +#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(func__) REGISTER_EXTERN_FUNC_1_IN_1_OUT(func__, host_target, float, float); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(exp); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(erf); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sqrt); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log2); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(log10); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(floor); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(ceil); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(round); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(trunc); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(cos); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(cosh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(tan); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sin); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(sinh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(acos); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(acosh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(asin); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(asinh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atan); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atanh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isnan); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(tanh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isfinite); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isinf); +#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32_INT(func__) \ + REGISTER_EXTERN_FUNC_1_IN_1_OUT(func__, host_target, float, int); - REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(left_shift); - REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(right_shift); - REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(bitwise_or); - REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(bitwise_and); - REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(bitwise_xor); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT(bitwise_not); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(erff); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(acosf); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(acoshf); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(asinf); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(asinhf); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(atanf); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_FP32(atanhf); return true; } diff --git a/cinn/runtime/cpu/host_intrinsics.h b/cinn/runtime/cpu/host_intrinsics.h index 1cb25672e07b9..bfc0036c590a0 100644 --- a/cinn/runtime/cpu/host_intrinsics.h +++ b/cinn/runtime/cpu/host_intrinsics.h @@ -6,43 +6,6 @@ extern "C" { -#define CINN_DCL_CPU_FUNC_FP32(name__) float cinn_cpu_##name__##_fp32(float a); -#define CINN_DCL_CPU_FUNC_INT_UNARY(name__) int cinn_cpu_##name__##_int(int a); -#define CINN_DCL_CPU_FUNC_INT_BINARY(name__) int cinn_cpu_##name__##_int(int a, int b); - -CINN_DCL_CPU_FUNC_FP32(exp); -CINN_DCL_CPU_FUNC_FP32(erf); -CINN_DCL_CPU_FUNC_FP32(sqrt); -CINN_DCL_CPU_FUNC_FP32(log); -CINN_DCL_CPU_FUNC_FP32(log2); -CINN_DCL_CPU_FUNC_FP32(log10); -CINN_DCL_CPU_FUNC_FP32(floor); -CINN_DCL_CPU_FUNC_FP32(ceil); -CINN_DCL_CPU_FUNC_FP32(round); -CINN_DCL_CPU_FUNC_FP32(trunc); -CINN_DCL_CPU_FUNC_FP32(cos); -CINN_DCL_CPU_FUNC_FP32(cosh); -CINN_DCL_CPU_FUNC_FP32(tan); -CINN_DCL_CPU_FUNC_FP32(sin); -CINN_DCL_CPU_FUNC_FP32(sinh); -CINN_DCL_CPU_FUNC_FP32(acos); -CINN_DCL_CPU_FUNC_FP32(acosh); -CINN_DCL_CPU_FUNC_FP32(asin); -CINN_DCL_CPU_FUNC_FP32(asinh); -CINN_DCL_CPU_FUNC_FP32(atan); -CINN_DCL_CPU_FUNC_FP32(atanh); -CINN_DCL_CPU_FUNC_FP32(isnan); -CINN_DCL_CPU_FUNC_FP32(tanh); -CINN_DCL_CPU_FUNC_FP32(isfinite); -CINN_DCL_CPU_FUNC_FP32(isinf); - -CINN_DCL_CPU_FUNC_INT_BINARY(left_shift); -CINN_DCL_CPU_FUNC_INT_BINARY(right_shift); -CINN_DCL_CPU_FUNC_INT_BINARY(bitwise_or); -CINN_DCL_CPU_FUNC_INT_BINARY(bitwise_and); -CINN_DCL_CPU_FUNC_INT_BINARY(bitwise_xor); -CINN_DCL_CPU_FUNC_INT_UNARY(bitwise_not); - //! math extern functions //@{ void __cinn_host_tanh_v(const cinn_buffer_t* x, cinn_buffer_t* out); diff --git a/cinn/runtime/cpu/host_intrinsics_test.cc b/cinn/runtime/cpu/host_intrinsics_test.cc index 65b566eb1a6cd..b20900d88d87a 100644 --- a/cinn/runtime/cpu/host_intrinsics_test.cc +++ b/cinn/runtime/cpu/host_intrinsics_test.cc @@ -18,7 +18,7 @@ namespace cpu { TEST(tanh, basic) { Expr M(10), N(20); Placeholder x("x", {M, N}); - auto y = Compute({M, N}, [&](Expr i, Expr j) { return CallExtern("cinn_cpu_tanh_fp32", {x(i, j)}); }); + auto y = Compute({M, N}, [&](Expr i, Expr j) { return CallExtern("tanh", {x(i, j)}); }); auto stages = CreateStages({y}); diff --git a/cinn/runtime/cpu/mkl_math_test.cc b/cinn/runtime/cpu/mkl_math_test.cc index bc4f8e70c9ae3..294729471f994 100644 --- a/cinn/runtime/cpu/mkl_math_test.cc +++ b/cinn/runtime/cpu/mkl_math_test.cc @@ -24,10 +24,8 @@ cinn_buffer_t *CreateBuffer(const std::vector shape, bool random = true, in return common::BufferBuilder(Float(32), shape).set_zero().Build(); } -void TestCallElementwise(const std::string &fn_name, - float (*fn_runtime)(float), - bool is_elementwise, - int set_value = 0) { +template +void TestCallElementwise(const std::string &fn_name, FuncRuntime fn_runtime, bool is_elementwise, int set_value = 0) { Expr M(10); Expr N(10); Placeholder x("x", {M, N}); @@ -85,14 +83,10 @@ void TestCallElementwise(const std::string &fn_name, } } -#define TEST_MKL_MATH_FP32(test_name__, is_elementwise) \ - TEST(mkl_math, test_name__) { \ - TestCallElementwise("cinn_cpu_" #test_name__ "_fp32", cinn_cpu_##test_name__##_fp32, is_elementwise); \ - } -#define TEST_MKL_MATH_FP32_SET(test_name__, is_elementwise, value) \ - TEST(mkl_math, test_name__) { \ - TestCallElementwise("cinn_cpu_" #test_name__ "_fp32", cinn_cpu_##test_name__##_fp32, is_elementwise, value); \ - } +#define TEST_MKL_MATH_FP32(test_name__, is_elementwise) \ + TEST(mkl_math, test_name__) { TestCallElementwise(#test_name__, test_name__, is_elementwise); } +#define TEST_MKL_MATH_FP32_SET(test_name__, is_elementwise, value) \ + TEST(mkl_math, test_name__) { TestCallElementwise(#test_name__, test_name__, is_elementwise, value); } TEST_MKL_MATH_FP32(exp, true) TEST_MKL_MATH_FP32(erf, true) @@ -115,12 +109,12 @@ TEST_MKL_MATH_FP32(asin, true) TEST_MKL_MATH_FP32(asinh, true) TEST_MKL_MATH_FP32(atan, true) TEST_MKL_MATH_FP32(atanh, true) -TEST_MKL_MATH_FP32(isnan, true) +// TEST_MKL_MATH_FP32(isnan, true) TEST_MKL_MATH_FP32(tanh, true) -TEST_MKL_MATH_FP32(isfinite, true) -TEST_MKL_MATH_FP32(isinf, true) +// TEST_MKL_MATH_FP32(isfinite, true) +// TEST_MKL_MATH_FP32(isinf, true) -TEST(mkl_math, tanh_v_fp32) { TestCallElementwise("cinn_mkl_tanh_v_fp32", cinn_cpu_tanh_fp32, false); } +TEST(mkl_math, tanh_v_fp32) { TestCallElementwise("cinn_mkl_tanh_v_fp32", tanh, false); } TEST(cinn_cpu_mkl_gemm_fp32, test) { Expr M(30); diff --git a/cinn/runtime/intrinsic.h b/cinn/runtime/intrinsic.h index daf01bd65ba2c..fe55e4cab9ccc 100644 --- a/cinn/runtime/intrinsic.h +++ b/cinn/runtime/intrinsic.h @@ -60,6 +60,8 @@ static const char* get_address_repr = "get_address"; static const char* args_construct_repr = "cinn_args_construct"; +static const char* unary_intrin_repr = "cinn_unary_intrin"; + //! Name of the helper intrinsic used to display debug string. static const char* debug_log_repr = "cinn_print_debug_string"; diff --git a/python/tests/test_pe_elementwise.py b/python/tests/test_pe_elementwise.py index 50a516a976afa..eb54468a62530 100644 --- a/python/tests/test_pe_elementwise.py +++ b/python/tests/test_pe_elementwise.py @@ -48,11 +48,11 @@ def test_unary(self): # ("asinh", pe.asinh, np.asinh, "float32"), # ("atan", pe.atan, np.atan, "float32"), # ("atanh", pe.atanh, np.atanh, "float32"), - # TODO(wenming2014) end - ("isnan", pe.isnan, np.isnan, "float32"), + # TODO(wenming2014) en + # ("isnan", pe.isnan, np.isnan, "float32"), ("tanh", pe.tanh, np.tanh, "float32"), - ("isfinite", pe.isfinite, np.isfinite, "float32"), - ("isinf", pe.isinf, np.isinf, "float32"), + # ("isfinite", pe.isfinite, np.isfinite, "float32"), + # ("isinf", pe.isinf, np.isinf, "float32"), ("negative", pe.negative, np.negative, "float32"), # ("identity", pe.identity, np.identity, "float32"), # TODO(wenming2014) int type diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt index 15b0aa878b3e8..70b6473a8da7d 100644 --- a/tests/benchmark/CMakeLists.txt +++ b/tests/benchmark/CMakeLists.txt @@ -7,5 +7,5 @@ target_compile_options(test_matmul PRIVATE "-O3") cc_test(test_elementwise SRCS test_elementwise.cc test_utils.cc DEPS core ARGS ${global_test_args}) target_compile_options(test_elementwise PRIVATE "-O3") -cc_test(test_all_ops_defualt SRCS test_all_ops_default.cc test_utils.cc DEPS core ARGS ${global_test_args}) -target_compile_options(test_all_ops_defualt PRIVATE "-O3") +cc_test(test_all_ops_default SRCS test_all_ops_default.cc test_utils.cc DEPS core ARGS ${global_test_args}) +target_compile_options(test_all_ops_default PRIVATE "-O3") diff --git a/tests/benchmark/test_all_ops_default.cc b/tests/benchmark/test_all_ops_default.cc index 27b06b8e1667f..a389b4024ec69 100644 --- a/tests/benchmark/test_all_ops_default.cc +++ b/tests/benchmark/test_all_ops_default.cc @@ -41,8 +41,19 @@ using AttrType = std::variant> input_shapes = shapes_##shape_name__; \ + std::string op_name = #op_name__; \ + hlir::framework::NodeAttr attrs; \ + OpBenchmarkTester tester(op_name, input_shapes); \ + auto input_tensors = tester.CreateInputTensors(); \ + tester.TestOp(common::UniqName(#op_name__), input_tensors, attrs, type__); \ + } + std::vector type{Float(32)}; std::vector type1{Float(32), Float(32)}; +std::vector type2 = {Int(32)}; // add // std::vector> shapes_add = {{1024, 1024, 1024}, {1024, 1024, 1024}}; // TEST_DEFAULT(elementwise_add, add, type) @@ -106,8 +117,8 @@ std::vector> shapes_pool2d1 = {{2, 1024, 14, 14}}; TEST_DEFAULT1(pool2d, pool2d1, type, attr_store_pool2d) // softmax -// std::vector> shapes_softmax = {{1024,2048}}; -// TEST_DEFAULT(softmax, softmax, type1) +std::vector> shapes_softmax = {{1024, 2048}}; +TEST_DEFAULT(softmax, softmax, type1) std::vector> shapes_softmax1 = {{3, 1000}}; TEST_DEFAULT(softmax, softmax1, type1) @@ -143,5 +154,59 @@ std::vector axes({2, 3}); std::unordered_map attr_store_slice = {{"starts", starts}, {"ends", ends}, {"axes", axes}}; TEST_DEFAULT1(slice, slice, type, attr_store_slice) +// unary +#define TEST_DEFAULT_UNARY(op__) \ + std::vector> shapes_unary_##op__ = {{1024, 2048}}; \ + std::vector> shapes_unary_##op__##1 = {{3, 1000}}; \ + TEST_DEFAULT(op__, unary_##op__, type) \ + TEST_DEFAULT(op__, unary_##op__##1, type) + +TEST_DEFAULT_UNARY(exp) +TEST_DEFAULT_UNARY(erf) +TEST_DEFAULT_UNARY(sigmoid) +TEST_DEFAULT_UNARY(sqrt) +TEST_DEFAULT_UNARY(log) +TEST_DEFAULT_UNARY(log2) +TEST_DEFAULT_UNARY(log10) +TEST_DEFAULT_UNARY(floor) +TEST_DEFAULT_UNARY(ceil) +TEST_DEFAULT_UNARY(round) +TEST_DEFAULT_UNARY(trunc) +TEST_DEFAULT_UNARY(cos) +TEST_DEFAULT_UNARY(cosh) +TEST_DEFAULT_UNARY(tan) +TEST_DEFAULT_UNARY(tanh) +TEST_DEFAULT_UNARY(sin) +TEST_DEFAULT_UNARY(sinh) +TEST_DEFAULT_UNARY(acos) +TEST_DEFAULT_UNARY(acosh) +TEST_DEFAULT_UNARY(asin) +TEST_DEFAULT_UNARY(asinh) +TEST_DEFAULT_UNARY(atan) +TEST_DEFAULT_UNARY(atanh) + +// TEST_DEFAULT_UNARY(isnan) +// TEST_DEFAULT_UNARY(isfinite) +// TEST_DEFAULT_UNARY(isinf) + +// bitwise_not +std::vector> shapes_bitwise_not = {{1024, 2048}}; +std::vector> shapes_bitwise_not1 = {{3, 1000}}; +TEST_DEFAULT_INT(bitwise_not, bitwise_not, type2) +TEST_DEFAULT_INT(bitwise_not, bitwise_not1, type2) + +// binary bitwise +#define TEST_DEFAULT_BINARY(op__) \ + std::vector> shapes_binary_##op__ = {{1024, 2048}, {1024, 2048}}; \ + std::vector> shapes_binary_##op__##1 = {{3, 1000}, {3, 1000}}; \ + TEST_DEFAULT_INT(op__, binary_##op__, type2) \ + TEST_DEFAULT_INT(op__, binary_##op__##1, type2) + +TEST_DEFAULT_BINARY(left_shift) +TEST_DEFAULT_BINARY(right_shift) +TEST_DEFAULT_BINARY(bitwise_or) +TEST_DEFAULT_BINARY(bitwise_and) +TEST_DEFAULT_BINARY(bitwise_xor) + } // namespace tests } // namespace cinn diff --git a/tools/tvm_benchmark/test_topi_default.py b/tools/tvm_benchmark/test_topi_default.py index 8eef1f1770f83..67a3a4494e7bf 100644 --- a/tools/tvm_benchmark/test_topi_default.py +++ b/tools/tvm_benchmark/test_topi_default.py @@ -7,23 +7,28 @@ import os from tvm import topi -dtype = "float32" +dtype = ["float32", "float32", "float32", "float32"] target = "llvm" ctx = tvm.context(target, 0) repeat = 10 -def test_op(func, input_shapes, out_shape, attrs={}, name="test_op"): +def test_op(func, + input_shapes, + out_shape, + attrs={}, + name="test_op", + dtype=dtype): assert len(input_shapes) >= 1 - A = te.placeholder(input_shapes[0], name="A") + A = te.placeholder(input_shapes[0], name="A", dtype=dtype[0]) if len(input_shapes) == 1: C = func(A) elif len(input_shapes) == 2: - B = te.placeholder(input_shapes[1], name="B") + B = te.placeholder(input_shapes[1], name="B", dtype=dtype[1]) C = func(A, B) elif len(input_shapes) == 3: - B = te.placeholder(input_shapes[1], name="B") - B1 = te.placeholder(input_shapes[2], name="B1") + B = te.placeholder(input_shapes[1], name="B", dtype=dtype[1]) + B1 = te.placeholder(input_shapes[2], name="B1", dtype=dtype[2]) C = func(A, B, B1) # Default schedule s = te.create_schedule(C.op) @@ -35,14 +40,15 @@ def test_op(func, input_shapes, out_shape, attrs={}, name="test_op"): func = tvm.build(s, [A, B, B1, C], target=target, name=name) assert func print(func) - a = tvm.nd.array(numpy.random.random(input_shapes[0]).astype(dtype), ctx) + a = tvm.nd.array( + numpy.random.random(input_shapes[0]).astype(dtype[0]), ctx) if len(input_shapes) > 1: b = tvm.nd.array( - numpy.random.random(input_shapes[1]).astype(dtype), ctx) + numpy.random.random(input_shapes[1]).astype(dtype[1]), ctx) if len(input_shapes) > 2: b1 = tvm.nd.array( - numpy.random.random(input_shapes[2]).astype(dtype), ctx) - c = tvm.nd.array(numpy.zeros(out_shape, dtype=dtype), ctx) + numpy.random.random(input_shapes[2]).astype(dtype[2]), ctx) + c = tvm.nd.array(numpy.zeros(out_shape, dtype=dtype[len(dtype) - 1]), ctx) evaluator = func.time_evaluator(func.entry_name, ctx, number=repeat) print("repeat: %f" % repeat) @@ -59,8 +65,8 @@ def test_op(func, input_shapes, out_shape, attrs={}, name="test_op"): def test_elementwise(): input_shapes, out_shape = [(100, 32), (100, 32)], (100, 32) - input_shapes1, out_shape1 = [(1024, 1024, 1024), - (1024, 1024, 1024)], (1024, 1024, 1024) + # input_shapes1, out_shape1 = [(1024, 1024, 1024), + # (1024, 1024, 1024)], (1024, 1024, 1024) input_shapes2, out_shape2 = [(1024, 14, 14), (1024, 14, 14)], (1024, 14, 14) @@ -71,18 +77,18 @@ def compute_mul(A, B): return topi.multiply(A, B) test_op(compute_add, input_shapes, out_shape, name="elementwise_add") - test_op(compute_add, input_shapes1, out_shape1, name="elementwise_add") + # test_op(compute_add, input_shapes1, out_shape1, name="elementwise_add") test_op(compute_add, input_shapes2, out_shape2, name="elementwise_add") test_op(compute_mul, input_shapes, out_shape, name="elementwise_mul") - test_op(compute_mul, input_shapes1, out_shape1, name="elementwise_mul") + # test_op(compute_mul, input_shapes1, out_shape1, name="elementwise_mul") test_op(compute_mul, input_shapes2, out_shape2, name="elementwise_mul") def test_relu(): - input_shapes, out_shape = [(100, 32)], (100, 32) + input_shapes, out_shape = [(2, 512, 7, 7)], (2, 512, 7, 7) input_shapes1, out_shape1 = [(1024, 1024, 1024)], (1024, 1024, 1024) - input_shapes2, out_shape2 = [(1024, 14, 14)], (1024, 14, - 14) + input_shapes2, out_shape2 = [(1024, 14, 14)], (1024, 14, 14) + input_shapes3, out_shape3 = [(100, 32)], (100, 32) name = "relu" def compute(A): @@ -91,6 +97,7 @@ def compute(A): test_op(compute, input_shapes, out_shape, name=name) test_op(compute, input_shapes1, out_shape1, name=name) test_op(compute, input_shapes2, out_shape2, name=name) + test_op(compute, input_shapes3, out_shape3, name=name) def test_conv2d_nchw(): @@ -152,16 +159,112 @@ def compute(A): test_op(compute, input_shapes1, out_shape1, name=name) -def test_exp(): +def test_unary(): + input_shapes, out_shape = [(1024, 2048)], (1024, 2048) + input_shapes1, out_shape1 = [(3, 1000)], (3, 1000) + input_shapes2, out_shape2 = [(1024, 2047)], (1024, 2047) + + def test_unary_basic(name, func): + def compute(A): + return func(A) + + test_op(compute, input_shapes, out_shape, name=name) + test_op(compute, input_shapes1, out_shape1, name=name) + test_op(compute, input_shapes2, out_shape2, name=name) + + for opfunc in [ + topi.exp, + topi.erf, + topi.sigmoid, + topi.sqrt, + topi.log, + topi.log2, + topi.log10, + topi.floor, + topi.ceil, + topi.round, + topi.trunc, + topi.cos, + topi.cosh, + topi.tan, + topi.tanh, + topi.sin, + topi.sinh, + topi.acos, + topi.acosh, + topi.asin, + topi.asinh, + topi.atan, + topi.atanh, + ]: + test_unary_basic(str(opfunc), opfunc) + + +def test_is(): input_shapes, out_shape = [(1024, 2048)], (1024, 2048) input_shapes1, out_shape1 = [(3, 1000)], (3, 1000) - name = "exp" + input_shapes2, out_shape2 = [(1024, 2047)], (1024, 2047) + type = ["float32", "bool"] - def compute(A): - return topi.exp(A) + def test_is_basic(name, func): + def compute(A): + return func(A) + + test_op(compute, input_shapes, out_shape, name=name, dtype=type) + test_op(compute, input_shapes1, out_shape1, name=name, dtype=type) + test_op(compute, input_shapes2, out_shape2, name=name, dtype=type) + + for opfunc in [ + topi.isnan, + topi.isfinite, + topi.isinf, + ]: + test_is_basic(str(opfunc), opfunc) + + +def test_bitwise_not(): + input_shapes, out_shape = [(1024, 2048)], (1024, 2048) + input_shapes1, out_shape1 = [(3, 1000)], (3, 1000) + input_shapes2, out_shape2 = [(1024, 2047)], (1024, 2047) + type = ["int32", "int32", "int32"] + + def test_unary_basic(name, func): + def compute(A): + return func(A) + + test_op(compute, input_shapes, out_shape, name=name, dtype=type) + test_op(compute, input_shapes1, out_shape1, name=name, dtype=type) + test_op(compute, input_shapes2, out_shape2, name=name, dtype=type) + + for opfunc in [ + topi.bitwise_not, + ]: + test_unary_basic(str(opfunc), opfunc) + + +def test_bitwise_binary(): + input_shapes, out_shape = [(1024, 2048), (1024, 2048)], (1024, 2048) + input_shapes1, out_shape1 = [(3, 1000), (3, 1000)], (3, 1000) + input_shapes2, out_shape2 = [(1024, 2047), (1024, 2047)], (1024, 2047) + type = ["int32", "int32", "int32"] + + def test_binary_basic(name, func): + def compute(A, B): + return func(A, B) + + test_op(compute, input_shapes, out_shape, name=name, dtype=type) + test_op(compute, input_shapes1, out_shape1, name=name, dtype=type) + test_op(compute, input_shapes2, out_shape2, name=name, dtype=type) + + for opfunc in [ + topi.bitwise_or, + topi.bitwise_and, + topi.bitwise_xor, + topi.left_shift, + topi.right_shift, + ]: + test_binary_basic(str(opfunc), opfunc) - test_op(compute, input_shapes, out_shape, name=name) - test_op(compute, input_shapes1, out_shape1, name=name) def test_sigmoid(): input_shapes, out_shape = [(2, 672, 1, 1)], (2, 672, 1, 1) @@ -176,17 +279,19 @@ def compute(A): def test_matmul(): - # input_shapes, out_shape = [(32,32),(32,32)], (32,32) - input_shapes, out_shape = [(512, 512), (512, 512)], (512, 512) - # input_shapes, out_shape = [(1024,1024),(1024,1024)], (1024,1024) - # input_shapes1, out_shape1 = [(100,32), (32,100)], (100,100) + input_shapes, out_shape = [(32, 32), (32, 32)], (32, 32) + input_shapes1, out_shape1 = [(512, 512), (512, 512)], (512, 512) + # input_shapes2, out_shape2 = [(1024,1024),(1024,1024)], (1024,1024) + input_shapes3, out_shape3 = [(100, 32), (32, 100)], (100, 100) name = "matmul" def compute(A, B): return topi.matmul(A, B, False, False) test_op(compute, input_shapes, out_shape, name=name) - # test_op(compute, input_shapes1, out_shape1, name=name) + test_op(compute, input_shapes1, out_shape1, name=name) + # test_op(compute, input_shapes2, out_shape2, name=name) + test_op(compute, input_shapes3, out_shape3, name=name) # batch_norm @@ -212,7 +317,10 @@ def compute(A, Scale, Shift): test_depthwise_conv2d_nchw() test_pool2d() test_softmax() - test_exp() + test_unary() + test_is() + test_bitwise_not() + test_bitwise_binary() test_sigmoid() test_matmul() test_batch_norm()