Skip to content

Commit

Permalink
run test
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Sep 20, 2023
1 parent 96ce84d commit 8950eaf
Show file tree
Hide file tree
Showing 19 changed files with 321 additions and 174 deletions.
21 changes: 15 additions & 6 deletions paddle/cinn/pybind/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@ namespace cinn {
namespace pybind {
void TensorStore(Expr tensor, Expr value, const std::vector<Expr>& indices) {
// TODO(6clc): Check the compatibility of data types for tensor and value
IRContext find_sch_block = IRBuilder::CurrentIRBuilder().data_->FindContext<ScheduleBlockContextNode>();
if( !find_sch_block.data_.defined()){
IRContext sch_block(new ScheduleBlockContextNode());
sch_block.data_->EnterWithContext();
LinkToParentContext(ir::Store::Make(tensor, value, indices));
sch_block.data_->ExitWithContext();
return;
}
LinkToParentContext(ir::Store::Make(tensor, value, indices));
}
std::vector<Var> AxisMap(std::string kinds, std::vector<Expr> iter_expression) {
std::vector<Var> rets;
std::vector<Expr> AxisMap(const std::string& kinds, const std::vector<Expr>& iter_expression) {
std::vector<Expr> rets;
CHECK_EQ(kinds.size(), iter_expression.size());
int n = iter_expression.size();
rets.reserve(n);
Expand All @@ -16,7 +24,7 @@ std::vector<Var> AxisMap(std::string kinds, std::vector<Expr> iter_expression) {

// TODO(6clc): set bound of IterVar

Var iter_var = ir::_Var_::Make("", common::Int(32));
Var iter_var = ir::_Var_::Make("iter_tmp", common::Int(32));
if (c == 'S') {
iter_var->is_reduce_axis = false;
} else if (c == 'R') {
Expand All @@ -27,6 +35,7 @@ std::vector<Var> AxisMap(std::string kinds, std::vector<Expr> iter_expression) {
}
rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i]));
}
return rets;
}
Var SetScheduleBlockIterVar(Var iter_var, Expr expr) {
IRContext cur_context =
Expand All @@ -35,10 +44,10 @@ Var SetScheduleBlockIterVar(Var iter_var, Expr expr) {
cur_context.As<ScheduleBlockContextNode>();
cur_context_node->iter_vars.push_back(iter_var);
cur_context_node->iter_values.push_back(expr);
return iter_var;
return iter_var.operator Expr();
}

Expr Arg(std::string name, Var var) {
Expr Arg(const std::string &name, Var var) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
var->name = name;
Expand All @@ -47,7 +56,7 @@ Expr Arg(std::string name, Var var) {
return var.operator Expr();
}

Expr Arg(std::string name, ir::Buffer buffer) {
Expr Arg(const std::string &name, ir::Buffer buffer) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
buffer->name = "_" + name;
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/pybind/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ namespace pybind {

template IRContext IRBuilderNode::GetLastContext<ScheduleBlockContextNode>() const;
Var SetScheduleBlockIterVar(Var iter_var, Expr expr);
std::vector<Var> AxisMap(std::string kinds, std::vector<Expr> iter_expression);
std::vector<Expr> AxisMap(const std::string& kinds, const std::vector<Expr>& iter_expression);
void TensorStore(Expr tensor, Expr value, const std::vector<Expr> &indices);
Expr Arg(std::string name, Var var);
Expr Arg(std::string name, ir::Buffer buffer);
Expr Arg(const std::string &name, Var var);
Expr Arg(const std::string &name, ir::Buffer buffer);
IRContext Sequential(Expr min, Expr extent);
} // namespace pybind
} // namespace cinn
36 changes: 24 additions & 12 deletions paddle/cinn/pybind/ir/ir_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ void BindLoweredFunc(py::module *m) {
[](const ir::LoweredFunc &self) -> std::string {
return utils::GetStreamCnt(Expr(self));
})
.def("__repr__", [](const ir::LoweredFunc &self) -> std::string {
return llvm::formatv(
"<LoweredFunc {0}>", self.get(), self->name.c_str());
});
.def("__repr__",
[](const ir::LoweredFunc &self) -> std::string {
return llvm::formatv(
"<LoweredFunc {0}>", self.get(), self->name.c_str());
})
.def("body", [](const ir::LoweredFunc &self) { return self->body; });
}

void BindNode(py::module *m) {
Expand Down Expand Up @@ -642,6 +644,8 @@ void BindIrTensor(py::module *m) {
[](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) {
return self(a, b, c, d);
})
.def("__getitem__",
[](ir::Tensor &self, std::vector<Expr> idx) { return self(idx); })
.def("Expr", [](ir::Tensor &self) { return self.operator Expr(); });

DefineExprNode<ir::_Tensor_>(m, "_Tensor_");
Expand Down Expand Up @@ -825,9 +829,20 @@ void BindIrContext(py::module *m) {
[](IRContext &self) {
return self.data_->safe_as<ForContextNode>()->loop_var;
})
.def_static("MakeLowerFunctionContext", [](std::string &name) {
return IRContext(new LowerFuncContextNode(name));
});
.def_static("MakeLowerFunctionContext",
[](std::string &name) {
return IRContext(new LowerFuncContextNode(name));
})
.def_static("MakeScheduleBlockContext",
[](std::string &name) {
return IRContext(new ScheduleBlockContextNode(name));
})
.def_static("MakeIfContext",
[](Expr expr) { return IRContext(new IfContextNode(expr)); })
.def_static("MakeElseContext",
[]() { return IRContext(new ElseContextNode()); })
.def_static("MakeThenContext",
[]() { return IRContext(new ThenContextNode()); });

py::class_<IRBuilder> ir_builder(*m, "IRBuilder");
ir_builder.def(py::init<>())
Expand All @@ -837,13 +852,10 @@ void BindIrContext(py::module *m) {
return self.data_->GetResult().as_lowered_func_ref();
});

py::class_<ScheduleBlockContextNode> sch_block_ctx(*m,
"ScheduleBlockContext");

m->def("AxisMap", &AxisMap);
m->def("TensorStore", &TensorStore);
m->def("Arg", py::overload_cast<std::string, Var>(&Arg));
m->def("Arg", py::overload_cast<std::string, ir::Buffer>(&Arg));
m->def("Arg", py::overload_cast<const std::string &, Var>(&Arg));
m->def("Arg", py::overload_cast<const std::string &, ir::Buffer>(&Arg));
m->def("Sequential", py::overload_cast<Expr, Expr>(&Sequential));
}
} // namespace
Expand Down
46 changes: 39 additions & 7 deletions paddle/cinn/pybind/ir/ir_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,61 @@ void ScheduleBlockContextNode::ExitWithContext() {

void ForContextNode::ExitWithContext() {
IRContextNode::ExitWithContext();
LinkToParentContext(
ir::For::Make(loop_var, min, extent, ir::ForType::Serial, ir::DeviceAPI::UNK, ir::Block::Make(exprs)));
LinkToParentContext(ir::For::Make(loop_var,
min,
extent,
ir::ForType::Serial,
ir::DeviceAPI::UNK,
ir::Block::Make(exprs)));
}

void LowerFuncContextNode::ExitWithContext() {
IRContextNode::ExitWithContext();
// TODO(6clc): implement Private Fields for intrinstic function, like
// allreduce
Expr body = ir::ScheduleBlockRealize::Make(
{}, ir::ScheduleBlock::Make({},{}, {}, "root", ir::Block::Make(exprs))
);
ir::LoweredFunc lower_func =
ir::_LoweredFunc_::Make(name, args, ir::Block::Make(exprs));
ir::_LoweredFunc_::Make(name, args, ir::Block::Make({body}));
IRBuilder ir_builder = IRBuilder::CurrentIRBuilder();
ir_builder.data_->result = lower_func.operator Expr();
}

void IfContextNode::ExitWithContext() {
IRContextNode::ExitWithContext();
if (!exprs.empty()) {
LOG(FATAL) << "Expr not be either in ThenBlock or ElseBlock in if";
}
if (!true_case.defined()){
LOG(FATAL) << "Expr not be defined in ThenBlock";
}
LinkToParentContext(ir::IfThenElse::Make(condition, true_case, false_case));
}


void ThenContextNode::ExitWithContext(){
IRContextNode::ExitWithContext();
IRContext for_ctx = IRBuilder::CurrentIRBuilder().data_->GetLastContext<IfContextNode>();
for_ctx.data_->safe_as<IfContextNode>()->true_case = ir::Block::Make(exprs);
}

void ElseContextNode::ExitWithContext(){
IRContextNode::ExitWithContext();
IRContext for_ctx = IRBuilder::CurrentIRBuilder().data_->GetLastContext<IfContextNode>();
for_ctx.data_->safe_as<IfContextNode>()->false_case = ir::Block::Make(exprs);

}

Expr IRBuilderNode::GetResult() const {
CHECK(result.defined()) << "No result generated in IRBuilder";
return result;
}

void IRBuilderNode::Reset(){
contexts.clear();
result.Reset();
}
void IRBuilderNode::Reset() {
contexts.clear();
result.Reset();
}

IRBuilder::IRBuilder() {
common::Shared<IRBuilderNode> n(new IRBuilderNode());
Expand Down
52 changes: 49 additions & 3 deletions paddle/cinn/pybind/ir/ir_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace pybind {
Expand Down Expand Up @@ -59,8 +60,15 @@ class IRContext {
CHECK(data_.get()) << "IrContext holds null";
auto* ctx_node = data_.get()->safe_as<TIRContextNode>();
if (!ctx_node) {
LOG(FATAL) << "TypeConvertError: convert " << data_.get()->type_info()
<< " to " << TIRContextNode::__type_info__;
// TODO(6clc):
std::stringstream err_msg;
err_msg << "TypeConvertError: convert " << data_.get()->type_info()
<< " to " << TIRContextNode::__type_info__;

CINN_THROW(err_msg.str());
// CINN_THROW(...) << "TypeConvertError: convert " <<
// data_.get()->type_info()
// << " to " << TIRContextNode::__type_info__;
}
return ctx_node;
}
Expand Down Expand Up @@ -124,7 +132,6 @@ class ForContextNode : public IRContextNode {
static constexpr const char* __type_info__ = "ForContextNode";
};


class LowerFuncContextNode : public IRContextNode {
public:
//! The name of this function.
Expand All @@ -142,6 +149,44 @@ class LowerFuncContextNode : public IRContextNode {
static constexpr const char* __type_info__ = "LowerFuncContextNode";
};

class IfContextNode : public IRContextNode {
public:
Expr condition;
Expr true_case;
Expr false_case;

public:
IfContextNode() = default;
IfContextNode(Expr condition)
: condition(condition), true_case(Expr()), false_case(Expr()) {}
const char* type_info() const override { return __type_info__; }

void ExitWithContext() final;

public:
static constexpr const char* __type_info__ = "IfContextNode";
};

class ThenContextNode : public IRContextNode {
public:
ThenContextNode() = default;
const char* type_info() const override { return __type_info__; }

void ExitWithContext() final;

public:
static constexpr const char* __type_info__ = "ThenContextNode";
};

class ElseContextNode : public IRContextNode {
public:
ElseContextNode() = default;
const char* type_info() const override { return __type_info__; }
void ExitWithContext() final;

public:
static constexpr const char* __type_info__ = "ElseContextNode";
};

class IRBuilderNode : public common::Object {
public:
Expand Down Expand Up @@ -190,6 +235,7 @@ IRContext IRBuilderNode::FindContext() const {
return *it;
}
}
return IRContext();
}

} // namespace pybind
Expand Down
3 changes: 2 additions & 1 deletion python/cinn/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
if isinstance(fn, CinnLowerLevelIrJit):
llir_func = ast_to_llir(fn, jit_inputs_signature)
else:
raise Exception("Current Only support compile from CinnLowerLevelIrJit")
raise Exception(
"Current Only support compile from CinnLowerLevelIrJit")

if just_convert:
return llir_func
Expand Down
Loading

0 comments on commit 8950eaf

Please sign in to comment.