Skip to content

Commit

Permalink
develop and optimize unary and bitwise ops with llvm intrinsics and i… (
Browse files Browse the repository at this point in the history
PaddlePaddle#296)

* develop and optimize unary and bitwise ops with LLVM intrinsics and IR optimization

* fix code style

* opt ops using extern call(acos,acosh,asin,asinh,atan,atanh,erf)

* fix code styles
  • Loading branch information
wenming2014 authored Nov 28, 2020
1 parent 98d2ff3 commit 92f72ad
Show file tree
Hide file tree
Showing 42 changed files with 1,007 additions and 331 deletions.
12 changes: 12 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,18 @@ void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) {
os() << ")";
}

/**
 * Emit a call expression for a UnaryIntrin node.
 *
 * The output has the form `<unary_intrin_repr>_<name>(arg0, arg1, ...)`, where each argument
 * is printed with Print() and arguments are separated by ", ".
 *
 * @param op The intrinsic node to print; `op->args` may be empty, which yields `name()`.
 */
void CodeGenC::Visit(const ir::intrinsics::UnaryIntrin *op) {
  os() << runtime::intrisic::unary_intrin_repr << "_";
  os() << op->name << "(";
  for (size_t i = 0; i < op->args.size(); ++i) {
    if (i != 0) {
      os() << ", ";
    }
    Print(op->args[i]);
  }
  // Close the argument list; without this the generated call has an unbalanced "(".
  os() << ")";
}

std::string ReadWholeFile(const std::string &path) {
CHECK(!path.empty());
std::ifstream file(path);
Expand Down
2 changes: 1 addition & 1 deletion cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ TEST(CodeGenC, call_extern) {
Placeholder<float> x("x", {M});

ir::Tensor y = Compute(
{M}, [=](Var i) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {x(i)}); }, "y");
{M}, [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, "y");

auto stages = CreateStages({y});

Expand Down
32 changes: 22 additions & 10 deletions cinn/backends/extern_func_protos.cc
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
#include "cinn/backends/extern_func_protos.h"

#include <string>
#include <vector>

namespace cinn {
namespace backends {

// Populates the registry with one FunctionProto per extern function name listed below.
// NOTE(review): this span comes from a diff view rendered without +/- markers, so lines that
// look like the pre-change version (`extern_funcs_fp32`, `extern_funcs_int64` and their loops)
// are interleaved with the post-change lines -- confirm against the repository before relying on it.
ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
static const std::vector<std::string> extern_funcs_fp32 = {
"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor",
"ceil", "round", "trunc", "cos", "cosh", "tan", "sin", "sinh",
"acos", "acosh", "asin", "asinh", "atan", "atanh", "isnan", "tanh",
"isfinite", "isinf", "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"};
static const std::vector<std::string> extern_funcs_int64 = {
// Names registered further down with signature (float32) -> float32.
static const std::vector<std::string> extern_funcs_fp32_unary = {
"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", "cos",
"cosh", "tan", "tanh", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", "fabs"};
// Names registered further down with signature (float32) -> bool.
static const std::vector<std::string> extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"};
// Names registered further down with signature (int32, int32) -> int32.
// NOTE(review): "bitwise_not" appears both here and in extern_funcs_int_int_unary, so it is
// passed to Register() twice under the same name with different signatures -- confirm which
// registration is intended and how Register() treats a duplicate name.
static const std::vector<std::string> extern_funcs_int_binary = {
"left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"};
for (int i = 0; i < extern_funcs_fp32.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_fp32[i], {Float(32)}, Float(32));
// Names registered further down with signature (int32) -> int32.
static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32));
Register(proto->name, proto);
}
for (int i = 0; i < extern_funcs_float_bool_unary.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_float_bool_unary[i], {Float(32)}, Bool());
Register(proto->name, proto);
}
for (int i = 0; i < extern_funcs_int64.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_int64[i], {Int(64)}, Int(64));
for (int i = 0; i < extern_funcs_int_binary.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_int_binary[i], {Int(32), Int(32)}, Int(32));
Register(proto->name, proto);
}
for (int i = 0; i < extern_funcs_int_int_unary.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32));
Register(proto->name, proto);
}

// Also registers the prototype returned by detail::CreateTanhVProto() under its own name.
auto* n = detail::CreateTanhVProto();
Register(n->name, n);
}
Expand Down
116 changes: 112 additions & 4 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

#include <glog/logging.h>
#include <glog/stl_logging.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/IR/Instruction.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Metadata.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>

#include <algorithm>
Expand Down Expand Up @@ -78,8 +81,11 @@ int NextPowerOfTwo(int x) {

} // namespace

CodeGenLLVM::CodeGenLLVM(llvm::Module *m, llvm::IRBuilder<> *b, const std::shared_ptr<SymbolTable> &symbol_table)
: m_(m), b_(b), symbol_table_(symbol_table) {
CodeGenLLVM::CodeGenLLVM(llvm::Module *m,
llvm::IRBuilder<> *b,
const std::shared_ptr<SymbolTable> &symbol_table,
const Target &target)
: m_(m), b_(b), symbol_table_(symbol_table), target_(target) {
if (!symbol_table.get()) {
symbol_table_ = std::make_shared<SymbolTable>();
}
Expand All @@ -88,8 +94,7 @@ CodeGenLLVM::CodeGenLLVM(llvm::Module *m, llvm::IRBuilder<> *b, const std::share
md_builder_ = std::make_unique<llvm::MDBuilder>(b_->getContext());
md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa");
md_tbaa_alias_set_ = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_);

InitTarget(common::DefaultHostTarget());
InitTarget(target_);
}

CodeGenLLVM::~CodeGenLLVM() {}
Expand Down Expand Up @@ -1141,6 +1146,11 @@ llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, int begin, int lanes)
}

void CodeGenLLVM::InitTarget(const Target &target) {
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
switch (target.arch) {
case Target::Arch::X86:
if (target.bits == Target::Bit::k32) {
Expand Down Expand Up @@ -1283,6 +1293,104 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::ArgsConstruct *op) {
return Call(callee, std::move(args));
}

/**
 * Look up (or create) the declaration of the LLVM intrinsic `id` in the current module.
 *
 * For a non-overloaded intrinsic the declaration is returned directly. For an overloaded one,
 * the overload types are deduced by matching a function type built from `ret_type` and
 * `arg_types` against the intrinsic's descriptor table: first as a fixed-arity signature, then
 * as a var-arg signature over successively longer prefixes of `arg_types`.
 *
 * @param id        The LLVM intrinsic id.
 * @param ret_type  The expected return type of the call.
 * @param arg_types The types of the call arguments taking part in the signature.
 * @return The intrinsic declaration, or nullptr when no signature matches.
 */
llvm::Function *CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id,
                                              llvm::Type *ret_type,
                                              llvm::ArrayRef<llvm::Type *> arg_types) {
  if (!llvm::Intrinsic::isOverloaded(id)) {
    return llvm::Intrinsic::getDeclaration(m_, id, {});
  }

  llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> descriptors;
  llvm::Intrinsic::getIntrinsicInfoTableEntries(id, descriptors);
  llvm::SmallVector<llvm::Type *, 4> resolved_types;

  // Matches `candidate` against the descriptor table, refilling `resolved_types` on every call.
  // A signature match is downgraded to an argument mismatch when the var-arg check reports an error.
  const auto match_signature = [&](llvm::FunctionType *candidate, bool is_var_arg) {
    resolved_types.clear();
    llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> remaining(descriptors);
    auto result = llvm::Intrinsic::matchIntrinsicSignature(candidate, remaining, resolved_types);
    if (result == llvm::Intrinsic::MatchIntrinsicTypes_Match &&
        llvm::Intrinsic::matchIntrinsicVarArg(is_var_arg, remaining)) {
      result = llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
    }
    return result;
  };

  // Attempt 1: fixed-arity signature built from all argument types.
  const auto fixed_result = match_signature(llvm::FunctionType::get(ret_type, arg_types, false), false);
  if (fixed_result == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
    return llvm::Intrinsic::getDeclaration(m_, id, resolved_types);
  }
  if (fixed_result == llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet) {
    // A return-type mismatch cannot be repaired by changing the argument list.
    return nullptr;
  }

  // Attempt 2: var-arg signatures, starting from no fixed arguments and adding one argument
  // type per round until every argument type has been included.
  llvm::SmallVector<llvm::Type *, 4> prefix_types;
  for (size_t taken = 0;; ++taken) {
    auto *var_arg_ty = llvm::FunctionType::get(ret_type, prefix_types, true);
    if (match_signature(var_arg_ty, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
      return llvm::Intrinsic::getDeclaration(m_, id, resolved_types);
    }
    if (taken == arg_types.size()) {
      break;
    }
    prefix_types.push_back(arg_types[taken]);
  }
  return nullptr;
}

/**
 * Lower a UnaryIntrin node to LLVM IR.
 *
 * Nodes whose `id` equals -1 carry no LLVM intrinsic id and are dispatched on `name` to plain
 * IR instructions (bitwise and/or/xor/not, shifts, isnan). Any other node becomes a call to the
 * LLVM intrinsic `id`; its declaration is resolved from the node's return type and the types of
 * the first `arg_nums` arguments.
 *
 * @param op The intrinsic node to lower.
 * @return The LLVM value produced by the emitted instruction or call.
 */
llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::UnaryIntrin *op) {
  const std::string func_name = op->name;
  if (op->id == -1) {
    // Operands are visited in separate statements, left to right, so the order in which their
    // IR is emitted does not depend on the compiler's argument evaluation order.
    if (func_name == "bitwise_and") {
      CHECK_GE(op->args.size(), 2U);
      llvm::Value *lhs = Visit(&op->args[0]);
      llvm::Value *rhs = Visit(&op->args[1]);
      return b_->CreateAnd(lhs, rhs);
    } else if (func_name == "bitwise_or") {
      CHECK_GE(op->args.size(), 2U);
      llvm::Value *lhs = Visit(&op->args[0]);
      llvm::Value *rhs = Visit(&op->args[1]);
      return b_->CreateOr(lhs, rhs);
    } else if (func_name == "bitwise_xor") {
      CHECK_GE(op->args.size(), 2U);
      llvm::Value *lhs = Visit(&op->args[0]);
      llvm::Value *rhs = Visit(&op->args[1]);
      return b_->CreateXor(lhs, rhs);
    } else if (func_name == "bitwise_not") {
      CHECK_GE(op->args.size(), 1U);
      return b_->CreateNot(Visit(&op->args[0]));
    } else if (func_name == "left_shift") {
      CHECK_GE(op->args.size(), 2U);
      llvm::Value *lhs = Visit(&op->args[0]);
      llvm::Value *rhs = Visit(&op->args[1]);
      return b_->CreateShl(lhs, rhs);
    } else if (func_name == "right_shift") {
      CHECK_GE(op->args.size(), 2U);
      llvm::Value *lhs = Visit(&op->args[0]);
      llvm::Value *rhs = Visit(&op->args[1]);
      // Arithmetic shift when the first operand's type reports is_int(), logical shift otherwise.
      if (op->args[0]->type().is_int()) {
        return b_->CreateAShr(lhs, rhs);
      }
      return b_->CreateLShr(lhs, rhs);
    } else if (func_name == "isnan") {
      CHECK_GE(op->args.size(), 1U);
      llvm::Value *v = Visit(&op->args[0]);
      // An unordered self-comparison is true exactly when the value is NaN.
      return b_->CreateFCmpUNO(v, v);
    } else {
      // -1 is not a valid LLVM intrinsic id; fail with a clear message instead of passing it
      // on to the intrinsic lookup below.
      LOG(FATAL) << "Unsupported intrinsic without an LLVM intrinsic id: " << func_name;
    }
  }

  CHECK(!op->args.empty());
  const llvm::Intrinsic::ID id = op->id;
  const size_t num_signature   = static_cast<size_t>(op->arg_nums);
  std::vector<llvm::Value *> arg_value;
  std::vector<llvm::Type *> arg_type;
  arg_value.reserve(op->args.size());
  for (size_t i = 0; i < op->args.size(); ++i) {
    arg_value.push_back(Visit(&op->args[i]));
    // Only the leading `arg_nums` arguments take part in resolving the intrinsic's signature.
    if (i < num_signature) {
      arg_type.push_back(arg_value.back()->getType());
    }
  }
  llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_);
  llvm::Function *fn      = GetIntrinsicDecl(id, return_type, arg_type);
  CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {});
  return b_->CreateCall(fn, arg_value);
}

llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) {
auto to_type = op->GetOutputType(0);
llvm::Function *callee{};
Expand Down
10 changes: 8 additions & 2 deletions cinn/backends/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class SymbolTable {
};

struct SymbolTableGuard {
SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); }
explicit SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); }

~SymbolTableGuard() { symbol_table_.PopScope(); }

Expand All @@ -97,7 +97,8 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {
public:
explicit CodeGenLLVM(llvm::Module *m,
llvm::IRBuilder<> *b,
const std::shared_ptr<SymbolTable> &symbol_table = nullptr);
const std::shared_ptr<SymbolTable> &symbol_table = nullptr,
const Target &target = common::DefaultHostTarget());

// Common llvm types
// @{
Expand Down Expand Up @@ -146,6 +147,10 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {

virtual llvm::Value *GetVar(const std::string &name, bool lazy = true);

llvm::Function *GetIntrinsicDecl(llvm::Intrinsic::ID id,
llvm::Type *ret_type,
llvm::ArrayRef<llvm::Type *> arg_types);

// Constants
// @{
inline llvm::Value *llvm_int32_constant(int v) { return llvm::ConstantInt::get(ll_int32_ty(), v); }
Expand Down Expand Up @@ -200,6 +205,7 @@ class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {
llvm::MDNode *md_tbaa_alias_set_{nullptr};

int naive_vec_alignment_{0};
Target target_;
};
namespace detail {
Expr StridedRampBase(Expr e, int stride);
Expand Down
4 changes: 2 additions & 2 deletions cinn/backends/llvm/execution_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ TEST(ExecutionEngine, call_extern) {
{M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out");

ir::Tensor res = Compute(
{M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("cinn_cpu_tanh_fp32", {add_out(i, j)}); }, "res");
{M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("tanh", {add_out(i, j)}); }, "res");

auto stages = CreateStages({add_out, res});

Expand Down Expand Up @@ -297,7 +297,7 @@ TEST(ExecutionEngine, call_extern) {
auto *cd = reinterpret_cast<float *>(cb->memory);
for (int m = 0; m < kM; m++) {
for (int n = 0; n < kN; n++) {
ASSERT_NEAR(cd[m * kN + n], cinn_cpu_tanh_fp32(ad[m * kN + n] + bd[m * kN + n]), 1e-5);
ASSERT_NEAR(cd[m * kN + n], tanh(ad[m * kN + n] + bd[m * kN + n]), 1e-5);
}
}
}
Expand Down
Loading

0 comments on commit 92f72ad

Please sign in to comment.