Skip to content

Commit

Permalink
IR add PrimitiveNode (PaddlePaddle#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Jul 27, 2020
1 parent f335e5a commit 7111279
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,8 @@ void CodeGenC::PrintStackVecType(Type type, int lanes) {
os() << "StackedVec<" << GetTypeRepr(type) << "," << lanes << ">";
}

void CodeGenC::Visit(const ir::PrimitiveNode *op){NOT_IMPLEMENTED}

std::string ReadWholeFile(const std::string &path) {
CHECK(!path.empty());
std::ifstream file(path);
Expand Down
2 changes: 2 additions & 0 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Block *op) {
return ret;
}

llvm::Value *CodeGenLLVM::Visit(const ir::PrimitiveNode *) { NOT_IMPLEMENTED return nullptr; }

llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) {
if (op->name == runtime::intrisic::buffer_create) {
} else if (op->name == runtime::intrisic::get_address_repr) {
Expand Down
7 changes: 7 additions & 0 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -552,5 +552,12 @@ lang::Module _Module_::Make(const std::string &name, Target target) {
return lang::Module(n);
}

Expr PrimitiveNode::Make(const std::string &name, const std::map<std::string, attr_t> &attrs) {
auto *n = make_shared<PrimitiveNode>();
n->name = name;
n->attrs = attrs;
return Expr(n);
}

} // namespace ir
} // namespace cinn
22 changes: 22 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <string>
#include <vector>

#include <variant>
#include "cinn/common/shared.h"
#include "cinn/common/type.h"
#include "cinn/ir/function_base.h"
Expand Down Expand Up @@ -826,6 +827,27 @@ struct _Module_ : public ExprNode<_Module_> {
static const IrNodeTy _node_type_ = IrNodeTy::_Module_;
};

/**
* \brief PrimitiveNode holds the contept of Primitive in CINN.
* A Primitive is a basic Call to some Expr function, it is introduced to create several level of coarsed-grained IR
* nodes for better IR optimization and hardware adaption.
*/
struct PrimitiveNode : public ExprNode<PrimitiveNode> {
// NOTE attr_t only support POD, can not contain Expr or other IR nodes, or the IRVisitor or IRCopy on PrimitiveNode
// will result in undefined behavior.
using attr_t = std::variant<int, float, bool, std::string>;

std::string name;
//! the inputs of the PrimitiveNode, the vector<vector<Expr>> can hold variadic arguments.
std::vector<std::vector<Expr>> arguments;
//! the attribute of this PrimitiveNode.
std::map<std::string, attr_t> attrs;

static Expr Make(const std::string& name, const std::map<std::string, attr_t>& attrs);

static const IrNodeTy _node_type_ = IrNodeTy::PrimitiveNode;
};

class _Range_;
class Range : public IrNodeRef {
public:
Expand Down
9 changes: 9 additions & 0 deletions cinn/ir/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ void IRMutator<T>::Visit(const Sum *expr, T op) {
IRVisitorBase<void, T>::Visit(&x, &x);
}
}
template <typename T>
void IRMutator<T>::Visit(const PrimitiveNode *expr, T op) {
auto *node = op->template As<PrimitiveNode>();
for (auto &args : node->arguments) {
for (auto &arg : args) {
IRVisitorBase<void, T>::Visit(&arg, &arg);
}
}
}

} // namespace ir
} // namespace cinn
15 changes: 15 additions & 0 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,21 @@ void IrPrinter::Visit(const Sum *x) {
os() << ")";
}

void IrPrinter::Visit(const PrimitiveNode *x) {
os() << x->name << "(";
std::vector<std::string> args_repr;
for (auto &args : x->arguments) {
std::vector<std::string> arg_repr;
for (auto &arg : args) {
arg_repr.push_back(utils::GetStreamCnt(arg));
}
args_repr.push_back(utils::Join(arg_repr, ","));
}

os() << utils::Join(args_repr, ",");
os() << ")";
}

std::ostream &operator<<(std::ostream &os, Expr a) {
std::stringstream ss;
IrPrinter printer(ss);
Expand Down
1 change: 1 addition & 0 deletions cinn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Var;
macro__(Product) \
macro__(Sum) \
macro__(Activate) \
macro__(PrimitiveNode) \

#define NODETY_FORALL(__m) \
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
Expand Down
13 changes: 13 additions & 0 deletions cinn/optim/ir_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,19 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
return n;
}

Expr Visit(const ir::PrimitiveNode* op) override {
std::vector<std::vector<Expr>> arguments;
for (auto& args : op->arguments) {
arguments.push_back(Visit(args));
}

auto n = common::make_shared<ir::PrimitiveNode>();
n->name = op->name;
n->attrs = op->attrs; // attrs are PODs
n->arguments = arguments;
return Expr(n);
}

#define OP_BINARY_HANDLE(op__) \
Expr Visit(const ir::op__* op) override { \
auto a = IRVisitorBase::Visit(&op->a()); \
Expand Down

0 comments on commit 7111279

Please sign in to comment.