Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#31 from Superjomn/fea/link-isl-ast-to-ir
Browse files Browse the repository at this point in the history
fea/link isl ast to ir
  • Loading branch information
Superjomn authored Feb 18, 2020
2 parents c47783d + 1acdad8 commit 7cf7c03
Show file tree
Hide file tree
Showing 23 changed files with 534 additions and 135 deletions.
39 changes: 22 additions & 17 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ Expr Cast::Make(Type t, Expr v) {
return Expr(node);
}

void Cast::Accept(IRVisitor *v) const { v->IRVisitorBase::Visit(&this->v); }

Expr Add::Make(Expr a, Expr b) {
auto node = make_shared<Add>(a, b);
return Expr(node);
Expand Down Expand Up @@ -118,7 +120,7 @@ Expr _Var_::Make(const std::string &name, const Type &type) {
return Expr(node);
}

For::For(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body) {
For::For(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Expr body) : ExprNode(Type()) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(body.defined());
Expand All @@ -130,38 +132,38 @@ For::For(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt bod
this->body = std::move(body);
}

Stmt For::Make(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body) {
Expr For::Make(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Expr body) {
auto node = make_shared<For>(min, extent, for_type, device_api, body);
return Stmt(node);
return Expr(node);
}

Stmt Block::Make(const std::vector<Stmt> &stmts) {
Expr Block::Make(const std::vector<Expr> &stmts) {
auto node = make_shared<Block>();
node->stmts = stmts;
return Stmt(node);
return Expr(node);
}

Stmt IfThenElse::Make(Expr condition, Stmt true_case, Stmt false_case) {
Expr IfThenElse::Make(Expr condition, Expr true_case, Expr false_case) {
auto node = make_shared<IfThenElse>(condition, true_case, false_case);
return Stmt(node);
return Expr(node);
}

Stmt Store::Make(Var buffer_var, Expr value, Expr index) {
Expr Store::Make(Var buffer_var, Expr value, Expr index) {
auto node = make_shared<Store>();
node->buffer_var = buffer_var;
node->value = value;
node->index = index;
return Stmt(node);
return Expr(node);
}

Stmt Alloc::Make(Var buffer_var, Type type, const std::vector<Expr> &extents, Expr condition, Stmt body) {
Expr Alloc::Make(Var buffer_var, Type type, const std::vector<Expr> &extents, Expr condition, Expr body) {
auto node = make_shared<Alloc>();
node->buffer_var = buffer_var;
node->type = type;
node->extents = extents;
node->condition = condition;
node->body = body;
return Stmt(node);
node->set_type(type);
return Expr(node);
}

int32_t Alloc::ConstantAllocationSize() const {
Expand All @@ -180,10 +182,10 @@ int32_t Alloc::ConstantAllocationSize(const std::string &name, const std::vector
return res;
}

Stmt Free::Make(Var var) {
Expr Free::Make(Var var) {
auto node = make_shared<Free>();
node->var = var;
return Stmt(node);
return Expr(node);
}

void _Range_::Accept(IRVisitor *v) const { v->Visit(this); }
Expand Down Expand Up @@ -226,8 +228,8 @@ Expr Call::Make(Type type,
return Expr(node);
}

Stmt PolyFor::Make(
Var iterator, Expr init_val, Expr condition, Expr inc, ForType for_type, DeviceAPI device_api, Stmt body) {
Expr PolyFor::Make(
Var iterator, Expr init_val, Expr condition, Expr inc, ForType for_type, DeviceAPI device_api, Expr body) {
auto n = make_shared<PolyFor>();
n->iterator = iterator;
n->init = init_val;
Expand All @@ -236,9 +238,12 @@ Stmt PolyFor::Make(
n->for_type = for_type;
n->device_api = device_api;
n->body = body;
return Stmt(n);
return Expr(n);
}

bool Var::operator==(const Var &o) const { return o->name == operator->()->name; }
bool Var::operator!=(const Var &o) const { return !(*this == o); }

} // namespace ir

namespace common {
Expand Down
69 changes: 38 additions & 31 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct Cast : public UnaryOpNode<Cast> {

static Expr Make(Type t, Expr v);

void Accept(IRVisitor* v) const override;

static const IrNodeTy _node_type_ = IrNodeTy::Cast;
};

Expand Down Expand Up @@ -266,6 +268,9 @@ struct Var : public IrNodeRef {

operator Expr() { return Expr(get()); }

bool operator==(const Var& o) const;
bool operator!=(const Var& o) const;

const _Var_* operator->() const { return get(); }
_Var_* operator->() { return get(); }
const _Var_* get() const { return static_cast<const _Var_*>(ptr()); }
Expand Down Expand Up @@ -313,11 +318,13 @@ struct Load : public ExprNode<Load> {
/**
* Store a `value` to the buffer at a given `index`.
*/
struct Store : public StmtNode<Store> {
struct Store : public ExprNode<Store> {
Var buffer_var;
Expr value, index;

static Stmt Make(Var buffer_var, Expr value, Expr index);
Store() : ExprNode(Type()) {}

static Expr Make(Var buffer_var, Expr value, Expr index);

static const IrNodeTy _node_type_ = IrNodeTy::Store;
};
Expand All @@ -326,15 +333,16 @@ struct Store : public StmtNode<Store> {
* Allocate a buffer with the given type and size. The buffer lives for at most the duration of the body statement,
* within which it is freed.
*/
struct Alloc : public StmtNode<Alloc> {
struct Alloc : public ExprNode<Alloc> {
Var buffer_var;
Type type;
//! Dimensions of this buffer (as a multi-dimensional array).
std::vector<Expr> extents;
Expr condition;
Stmt body;
Expr body;

Alloc() : ExprNode(Type()) {}

static Stmt Make(Var buffer_var, Type type, const std::vector<Expr>& extents, Expr condition, Stmt body);
static Expr Make(Var buffer_var, Type type, const std::vector<Expr>& extents, Expr condition, Expr body);

int32_t ConstantAllocationSize() const;
static int32_t ConstantAllocationSize(const std::string& name, const std::vector<Expr>& extents);
Expand All @@ -345,26 +353,28 @@ struct Alloc : public StmtNode<Alloc> {
/**
* Free the resources associated with the given buffer.
*/
struct Free : public StmtNode<Free> {
struct Free : public ExprNode<Free> {
Var var;

static Stmt Make(Var var);
Free() : ExprNode(Type()) {}

static Expr Make(Var var);

static const IrNodeTy _node_type_ = IrNodeTy::Free;
};

struct IfThenElse : public StmtNode<IfThenElse> {
struct IfThenElse : public ExprNode<IfThenElse> {
Expr condition;
Stmt true_case;
Stmt false_case;
Expr true_case;
Expr false_case;

IfThenElse(Expr condition, Stmt true_case, Stmt false_case)
: condition(condition), true_case(true_case), false_case(false_case) {
IfThenElse(Expr condition, Expr true_case, Expr false_case)
: ExprNode(Type()), condition(condition), true_case(true_case), false_case(false_case) {
CHECK(condition.defined());
CHECK(true_case.defined());
}

static Stmt Make(Expr condition, Stmt true_case, Stmt false_case);
static Expr Make(Expr condition, Expr true_case, Expr false_case);

static const IrNodeTy _node_type_ = IrNodeTy::IfThenElse;
};
Expand All @@ -380,7 +390,7 @@ enum class ForType : int {
Unrolled = 3,
};

struct For : public StmtNode<For> {
struct For : public ExprNode<For> {
//! The loop variable.
Expr loop_var;
//! The minimum value of the iteration.
Expand All @@ -390,19 +400,19 @@ struct For : public StmtNode<For> {
//! The type of the for loop.
ForType for_type;

Stmt body;
Expr body;

DeviceAPI device_api;

For(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);
For(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Expr body);

static Stmt Make(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);
static Expr Make(Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Expr body);

static const IrNodeTy _node_type_ = IrNodeTy::For;
};

//! Polyhedral forloop, which condition is more complex than the normal `For`.
struct PolyFor : public StmtNode<PolyFor> {
struct PolyFor : public ExprNode<PolyFor> {
//! The iterator variable.
Var iterator;
// Initial value of the iterator.
Expand All @@ -412,13 +422,15 @@ struct PolyFor : public StmtNode<PolyFor> {
//! Increase the iterator.
Expr inc;
//! The forloop body.
Stmt body;
Expr body;

ForType for_type;
DeviceAPI device_api;

static Stmt Make(
Var iterator, Expr init_val, Expr condition, Expr inc, ForType for_type, DeviceAPI device_api, Stmt body);
PolyFor() : ExprNode(Type()) {}

static Expr Make(
Var iterator, Expr init_val, Expr condition, Expr inc, ForType for_type, DeviceAPI device_api, Expr body);

static const IrNodeTy _node_type_ = IrNodeTy::PolyFor;
};
Expand All @@ -429,12 +441,12 @@ struct Module : public ExprNode<Module> {
static const IrNodeTy _node_type_ = IrNodeTy::Module;
};

struct Block : public StmtNode<Block> {
std::vector<Stmt> stmts;
struct Block : public ExprNode<Block> {
std::vector<Expr> stmts;

Block() = default;
Block() : ExprNode(Type()) {}

static Stmt Make(const std::vector<Stmt>& stmts);
static Expr Make(const std::vector<Expr>& stmts);

static const IrNodeTy _node_type_ = IrNodeTy::Block;
};
Expand Down Expand Up @@ -563,11 +575,6 @@ struct Builder {
Expr MakeExpr(Args... args) {
return IRType::Make(args...);
}

template <typename IRType, typename... Args>
Stmt MakeStmt(Args... args) {
return IRType::Make(args...);
}
};

} // namespace ir
Expand Down
20 changes: 12 additions & 8 deletions cinn/ir/ir_mutator.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "cinn/ir/ir_mutator.h"

#include "cinn/ir/ir_printer.h"

namespace cinn {
namespace ir {

void IRMutator::Visit(const Expr *expr, Expr *op) { IRVisitorBase::Visit(expr, op); }

#define UNARY_OP_IMPL(op__) \
void IRMutator::Visit(const op__ *expr, Expr *op) { \
auto *node = op->As<op__>(); \
Expand All @@ -25,17 +29,19 @@ NODETY_BINARY_OP_FOR_EACH(BINARY_OP_IMPL)
void IRMutator::Visit(const IntImm *expr, Expr *op) {}
void IRMutator::Visit(const UIntImm *expr, Expr *op) {}
void IRMutator::Visit(const FloatImm *expr, Expr *op) {}
void IRMutator::Visit(const Cast *expr, Expr *op) {
auto *node = op->As<Cast>();
Visit(&node->v, &node->v);
}
void IRMutator::Visit(const For *expr, Expr *op) {
auto *node = op->As<For>();
IRVisitorBase::Visit(&node->min, &node->min);
IRVisitorBase::Visit(&node->extent, &node->extent);
Expr tmp(node->body);
IRVisitorBase::Visit(&node->body, &tmp);
IRVisitorBase::Visit(&node->body, &node->body);
}
void IRMutator::Visit(const PolyFor *expr, Expr *op) {
auto *node = op->As<PolyFor>();
Expr tmp(node->body);
IRVisitorBase::Visit(&node->body, &tmp);
IRVisitorBase::Visit(&node->body, &node->body);
IRVisitorBase::Visit(&node->condition, &node->condition);
IRVisitorBase::Visit(&node->inc, &node->inc);
}
Expand All @@ -56,15 +62,13 @@ void IRMutator::Visit(const IfThenElse *expr, Expr *op) {
void IRMutator::Visit(const Block *expr, Expr *op) {
auto *node = op->As<Block>();
for (auto &expr : node->stmts) {
Expr tmp(expr);
IRVisitorBase::Visit(&expr, &tmp);
IRVisitorBase::Visit(&expr, &expr);
}
}
void IRMutator::Visit(const Call *expr, Expr *op) {
auto *node = op->As<Call>();
for (auto &expr : node->args) {
Expr tmp(expr);
IRVisitorBase::Visit(&expr, &tmp);
IRVisitorBase::Visit(&expr, &expr);
}
}
void IRMutator::Visit(const Module *expr, Expr *op) {}
Expand Down
2 changes: 2 additions & 0 deletions cinn/ir/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace ir {

class IRMutator : public IRVisitorBase<void, Expr*> {
public:
void Visit(const Expr* expr, Expr* op) override;

#define __(op__) void Visit(const op__* expr, Expr* op) override;
NODETY_FORALL(__)
#undef __
Expand Down
4 changes: 4 additions & 0 deletions cinn/ir/ir_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ Expr operator>=(POD a, Expr b) {
}

//--
Expr operator+(Expr a, Expr b) { return Add::Make(a, b); }
Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); }
Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); }
Expr operator/(Expr a, Expr b) { return Div::Make(a, b); }
Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); }
Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); }

Expand Down
1 change: 0 additions & 1 deletion cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace cinn {
namespace ir {

void IrPrinter::Print(Expr e) { e.Accept(reinterpret_cast<IRVisitor *>(this)); }
void IrPrinter::Print(Stmt s) { s.Accept(reinterpret_cast<IRVisitor *>(this)); }
void IrPrinter::Print(const std::vector<Expr> &exprs, const std::string &splitter) {
for (int i = 0; i < exprs.size() - 1; i++) {
Print(exprs[i]);
Expand Down
2 changes: 0 additions & 2 deletions cinn/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ struct IrPrinter : public IRVisitor {

//! Emit an expression on the output stream.
void Print(Expr e);
//! Emit a statement on the output stream.
void Print(Stmt s);
//! Emit a expression list with , splitted.
void Print(const std::vector<Expr> &exprs, const std::string &splitter = ", ");
//! Emit a binary operator
Expand Down
6 changes: 1 addition & 5 deletions cinn/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,10 @@ struct IRVisitorBase {
LOG(FATAL) << "not supported NodeTy";
#undef __
}
return RetTy();
}
// @}

virtual RetTy Visit(const ir::Stmt* expr, Args... args) {
Expr tmp(*expr);
return Visit(&tmp, args...);
}

protected:
#define __(op__) virtual RetTy Visit(const ir::op__* op, Args... args) = 0;
NODETY_FORALL(__)
Expand Down
Loading

0 comments on commit 7cf7c03

Please sign in to comment.