support unary op vectorize (PaddlePaddle#317)
wenming2014 authored Dec 31, 2020
1 parent bca7614 commit c9b5f37
Showing 9 changed files with 83 additions and 24 deletions.
2 changes: 1 addition & 1 deletion cinn/backends/llvm/codegen_llvm.cc
@@ -1374,7 +1374,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::UnaryIntrin *op) {
}
}
CHECK(!op->args.empty());
-llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_);
+llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_, true);
llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type);
CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {});
return b_->CreateCall(fn, arg_value);
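Context for the new `true` argument (a sketch under stated assumptions, not code from this patch): LLVM's overloaded math intrinsics only resolve to their SIMD variants when handed a genuine vector type, so the return type of a vectorized unary intrinsic must be <8 x float> rather than [8 x float]. A minimal illustration, assuming LLVM 11+ and a hypothetical helper name:

#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Module.h>

// Illustrative sketch: with a FixedVectorType, the overloaded intrinsic
// lookup resolves llvm.exp to its SIMD variant llvm.exp.v8f32.
llvm::Function *DeclareVecExp(llvm::Module *m) {
  auto *vec_ty = llvm::FixedVectorType::get(
      llvm::Type::getFloatTy(m->getContext()), 8);
  return llvm::Intrinsic::getDeclaration(m, llvm::Intrinsic::exp, {vec_ty});
}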
4 changes: 2 additions & 2 deletions cinn/backends/llvm/llvm_intrin_rule.h
@@ -24,8 +24,8 @@ inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
CHECK(node);
CHECK_GE(node->read_args.size(), arg_nums);
if (add_float_suffix) {
-CHECK_EQ(node->type(), Float(32));
-*rv = ir::intrinsics::UnaryIntrin::Make(node->name + "f", node->read_args, id, arg_nums, Float(32));
+CHECK(node->type().is_float());
+*rv = ir::intrinsics::UnaryIntrin::Make(node->name + "f", node->read_args, id, arg_nums, node->type());
} else {
*rv = ir::intrinsics::UnaryIntrin::Make(node->name, node->read_args, id, arg_nums, node->type());
}
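Note the generalization here: CHECK_EQ(node->type(), Float(32)) rejected any multi-lane node such as float32x8, while CHECK(node->type().is_float()) tests only the element kind, so scalar and vectorized float types both pass, and the intrinsic now carries the node's full type (lanes included) instead of a hard-coded Float(32).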
10 changes: 7 additions & 3 deletions cinn/backends/llvm/llvm_util.cc
@@ -9,7 +9,7 @@
namespace cinn {
namespace backends {

-llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m) {
+llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m, bool is_vec) {
llvm::Type *ir_type = nullptr;
if (type.is_cpp_const()) {
// TODO(fc500110) support it latter.
@@ -51,9 +51,13 @@ llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m) {
}
CHECK(ir_type) << "LLVM can't convert type: " << type;

-// C array.
+// C array / vector.
if (type.lanes() > 1) {
-ir_type = llvm::ArrayType::get(ir_type, type.lanes());
+if (is_vec) {
+ir_type = llvm::FixedVectorType::get(ir_type, type.lanes());
+} else {
+ir_type = llvm::ArrayType::get(ir_type, type.lanes());
+}
}

if (type.is_cpp_handle()) {
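The two lowerings this flag selects between can be printed directly (standalone snippet for illustration; assumes LLVM 11+, where llvm::FixedVectorType was split out of VectorType):

#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/Support/raw_ostream.h>

// Illustrative only: a lanes=8 float type lowered both ways.
int main() {
  llvm::LLVMContext ctx;
  llvm::Type *f32 = llvm::Type::getFloatTy(ctx);
  llvm::ArrayType::get(f32, 8)->print(llvm::outs());        // prints: [8 x float]
  llvm::outs() << "\n";
  llvm::FixedVectorType::get(f32, 8)->print(llvm::outs());  // prints: <8 x float>
  llvm::outs() << "\n";
  return 0;
}

Only the second form is a first-class SIMD value accepted by vector arithmetic and vector intrinsics, which is why the vectorized code path asks for is_vec = true while in-memory C arrays keep the old lowering.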
2 changes: 1 addition & 1 deletion cinn/backends/llvm/llvm_util.h
@@ -32,7 +32,7 @@ std::string DumpToString(const T &entity) {

inline llvm::StringRef AsStringRef(std::string_view str) { return llvm::StringRef(str.data(), str.size()); }

-llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m);
+llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m, bool is_vec = false);

template <typename T>
llvm::Type *llvm_type_of(llvm::Module *m);
5 changes: 5 additions & 0 deletions cinn/hlir/op/elementwise.cc
@@ -62,6 +62,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwise(const framework::NodeAttr &at
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
}
+} else if (target.arch == Target::Arch::X86) {
+Expr Out = arg_pack[0];
+poly::StageMap stages = arg_pack[1];
+CHECK(Out.as_tensor());
+pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.back(), target);
}
*ret = arg_pack;
});
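For intuition, this is the loop shape an injective CPU schedule aims for (hand-written illustration, not CINN's generated code; the lane count is an assumption):

#include <cmath>

// Illustrative only: an elementwise op with a fixed-width inner loop the
// compiler can turn into SIMD instructions, plus a scalar remainder.
void tanh_elementwise(const float *x, float *y, int n) {
  constexpr int kLanes = 8;  // assumed vector width
  int i = 0;
  for (; i + kLanes <= n; i += kLanes) {
    for (int l = 0; l < kLanes; ++l) {  // vectorizable inner loop
      y[i + l] = std::tanh(x[i + l]);
    }
  }
  for (; i < n; ++i) y[i] = std::tanh(x[i]);  // scalar remainder
}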
3 changes: 2 additions & 1 deletion cinn/ir/ir.cc
@@ -464,6 +464,7 @@ Expr Load::Make(Expr tensor, const std::vector<Expr> &indices) {
auto node = make_shared<Load>();
node->tensor = tensor;
node->indices = indices;
+node->set_type(node->type());
return Expr(node);
}
Type Load::type() const {
@@ -662,7 +663,7 @@ void Select::Verify() const {
CHECK(condition.defined());
CHECK(true_value.defined());
CHECK(false_value.defined());
-CHECK_EQ(condition.type(), type_of<bool>()) << "Select Node's condition should be a boolean";
+CHECK(condition.type().is_bool()) << "Select Node's condition should be a boolean";
CHECK_EQ(true_value.type(), false_value.type())
<< "Select Node's true_value and false_value should have the same type";
}
29 changes: 17 additions & 12 deletions cinn/lang/builtin.cc
@@ -32,10 +32,16 @@ Expr logic_or(const std::vector<Expr>& conds) {

//! extern call op
#define EXTERN_CALL_IMP(name__, target__) \
-Expr name__(Expr e) { return CallExtern(#target__, {e}); }
+Expr name__(Expr e) { return ir::Call::Make(e->type(), #target__, {e}, {}, ir::CallType::Extern); }
+
+#define EXTERN_CALL_IMP_NO_VEC(name__, target__) \
+Expr name__(Expr e) { \
+return ir::Call::Make( \
+e->type(), #target__, {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}}); \
+}

EXTERN_CALL_IMP(Exp, exp);
-EXTERN_CALL_IMP(Erf, erf);
+EXTERN_CALL_IMP_NO_VEC(Erf, erf);
EXTERN_CALL_IMP(Sqrt, sqrt);
EXTERN_CALL_IMP(Log, log);
EXTERN_CALL_IMP(Log2, log2);
@@ -45,17 +51,17 @@ EXTERN_CALL_IMP(Ceil, ceil);
EXTERN_CALL_IMP(Round, round);
EXTERN_CALL_IMP(Trunc, trunc);
EXTERN_CALL_IMP(Cos, cos);
-EXTERN_CALL_IMP(Sin, sin);
EXTERN_CALL_IMP(Cosh, cosh);
EXTERN_CALL_IMP(Tan, tan);
+EXTERN_CALL_IMP(Sin, sin);
+EXTERN_CALL_IMP(Sinh, sinh);
-EXTERN_CALL_IMP(Acos, acos);
-EXTERN_CALL_IMP(Acosh, acosh);
-EXTERN_CALL_IMP(Asin, asin);
-EXTERN_CALL_IMP(Asinh, asinh);
-EXTERN_CALL_IMP(Atan, atan);
-EXTERN_CALL_IMP(Atanh, atanh);
EXTERN_CALL_IMP(Tanh, tanh);
-EXTERN_CALL_IMP(Sinh, sinh);
+EXTERN_CALL_IMP_NO_VEC(Acos, acos);
+EXTERN_CALL_IMP_NO_VEC(Acosh, acosh);
+EXTERN_CALL_IMP_NO_VEC(Asin, asin);
+EXTERN_CALL_IMP_NO_VEC(Asinh, asinh);
+EXTERN_CALL_IMP_NO_VEC(Atan, atan);
+EXTERN_CALL_IMP_NO_VEC(Atanh, atanh);

Expr min_value(const Type& type) {
CHECK_EQ(type.lanes(), 1);
@@ -114,7 +120,6 @@ Expr Abs(Expr e) {

Expr IsNan(Expr e) {
Type type = e->type();
-// Type bool_type = Bool(type.lanes());
if (type.is_int() || type.is_uint()) {
return common::make_bool(false, type.lanes());
} else if (type.is_float()) {
@@ -126,7 +131,7 @@ Expr IsNan(Expr e) {
if (type.bits() == 16) {
arg = ir::Cast::Make(Float(32), std::move(e));
}
return CallExtern("isnan", {arg});
return CallExtern("isnan", {arg}, {{"vectorizable", false}});
} else {
LOG(FATAL) << type << "is not supported for isnan op.";
return e;
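Expanding both macros by hand shows the only difference is the trailing attribute map (expansion written out for illustration; Expr, ir::Call, and ir::FunctionRef are CINN types from this file's context):

// EXTERN_CALL_IMP(Exp, exp) expands to a plain extern call, which the
// vectorizer is free to widen:
Expr Exp(Expr e) { return ir::Call::Make(e->type(), "exp", {e}, {}, ir::CallType::Extern); }

// EXTERN_CALL_IMP_NO_VEC(Erf, erf) attaches {"vectorizable", false}, so
// the call stays scalar even inside a vectorized loop:
Expr Erf(Expr e) {
  return ir::Call::Make(
      e->type(), "erf", {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}});
}

The ops moved to EXTERN_CALL_IMP_NO_VEC (erf and the inverse trig/hyperbolic functions) stay scalar, presumably because they lack vectorized runtime counterparts; the same {{"vectorizable", false}} attribute is added to isnan above.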
1 change: 0 additions & 1 deletion cinn/optim/optimize.cc
@@ -36,7 +36,6 @@ Expr Optimize(Expr e, Target target, bool runtime_debug_info) {
CastSimplify(&copied);
Simplify(&copied);
VectorizeLoops(&copied, Target());
-EliminateBroadcastInForloop(&copied);
UnrollLoop(&copied);
#ifdef CINN_WITH_CUDA
RemoveGpuForloopsAxis(&copied);
51 changes: 48 additions & 3 deletions cinn/optim/vectorize_loops.cc
@@ -137,8 +137,15 @@ class Vectorizer : public IRMutator<Expr *> {
}
}
if (!need_visit) return;

-*expr = Load::Make(node->tensor, node->indices);
+int lanes = 0;
+for (auto &idx : node->indices) {
+lanes = std::max(idx.type().lanes(), lanes);
+}
+std::vector<Expr> new_indices;
+for (auto &idx : node->indices) {
+new_indices.push_back(Widen(idx, lanes));
+}
+*expr = Load::Make(node->tensor, new_indices);
}

void Visit(const Store *op, Expr *expr) override {
@@ -173,7 +180,45 @@ class Vectorizer : public IRMutator<Expr *> {
*expr = Store::Make(node->tensor, node->value, new_indices);
}

-void Visit(const Call *op, Expr *expr) override { LOG(ERROR) << "Ignore widen Call node"; }
+void Visit(const Call *op, Expr *expr) override {
+std::vector<Expr> read_args = op->read_args;
+std::vector<Expr> write_args = op->write_args;
+auto *node = expr->As<Call>();
+ir::IRMutator<>::Visit(op, expr);
+bool is_changed = false;
+int lanes = 0;
+for (int i = 0; i < node->read_args.size(); i++) {
+lanes = std::max(node->read_args[i].type().lanes(), lanes);
+if (!node->read_args[i].same_as(read_args[i])) {
+is_changed = true;
+}
+}
+for (int i = 0; i < node->write_args.size(); i++) {
+lanes = std::max(node->write_args[i].type().lanes(), lanes);
+if (!node->write_args[i].same_as(write_args[i])) {
+is_changed = true;
+}
+}
+if (!is_changed) return;
+
+for (int i = 0; i < read_args.size(); i++) {
+node->read_args[i] = Widen(node->read_args[i], lanes);
+}
+for (int i = 0; i < write_args.size(); i++) {
+node->write_args[i] = Widen(node->write_args[i], lanes);
+}
+
+CHECK(!read_args.empty());
+Type type = op->type().with_lanes(lanes);
+*expr = Call::Make(type,
+node->name,
+node->read_args,
+node->write_args,
+node->call_type,
+node->func,
+node->value_index,
+node->attrs);
+}

void Visit(const Let *op, Expr *expr) override {
auto *node = expr->As<Let>();
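Both new visitors compute the maximum lane count across sub-expressions and Widen the rest up to it. Widen is defined elsewhere in vectorize_loops.cc; a sketch of the usual shape of such a helper (an assumption, not the verbatim function):

// Hedged sketch: pass through expressions already at `lanes`, otherwise
// broadcast a scalar to the requested lane count.
Expr WidenSketch(Expr e, int lanes) {
  if (e.type().lanes() == lanes) return e;
  CHECK_EQ(e.type().lanes(), 1) << "can only widen a scalar expression";
  return ir::Broadcast::Make(e, lanes);
}

With that in place, a call such as pow(x[i], y[i]) inside a vectorized loop gets both arguments widened to the same lane count and is rebuilt via Call::Make with op->type().with_lanes(lanes), replacing the old behavior of merely logging "Ignore widen Call node".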
