Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#39 from Superjomn/fea/init-lower
Browse files Browse the repository at this point in the history
init lower
  • Loading branch information
Superjomn authored Feb 25, 2020
2 parents 484edd2 + 2115816 commit 99da6f8
Show file tree
Hide file tree
Showing 21 changed files with 317 additions and 93 deletions.
1 change: 0 additions & 1 deletion cinn/arithmetic/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

3 changes: 2 additions & 1 deletion cinn/common/context.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include <isl/cpp.h>

#include <set>
#include <string>
#include <vector>

Expand Down Expand Up @@ -35,7 +36,7 @@ class Context {
/**
* The global isl ctx.
*/
isl::ctx& isl_ctx() { return ctx_; }
isl::ctx isl_ctx() { return ctx_; }

private:
Context() : ctx_(isl_ctx_alloc()) {}
Expand Down
25 changes: 24 additions & 1 deletion cinn/common/graph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,33 @@ std::tuple<Graph::node_order_t, Graph::edge_order_t> TopologicalSort(const std::

void DFSSortUtil(const GraphNode *node, std::vector<GraphNode *> *order) {}

std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {}
std::vector<GraphNode *> DFSSort(const std::vector<GraphNode *> &nodes) {
LOG(FATAL) << "not implemented";
return {};
}

} // namespace

std::set<GraphNode *> Graph::dependencies(const std::vector<GraphNode *> &targets) {
// A naive implementation.
std::set<GraphNode *> _targets(targets.begin(), targets.end());
std::set<GraphNode *> res;
int targets_count = 0;
while (targets_count != _targets.size()) {
targets_count = _targets.size();
for (auto *node : nodes()) {
if (_targets.count(node)) continue;
for (auto &edge : node->outlinks()) {
if (_targets.count(edge->sink())) {
res.insert(edge->sink());
_targets.insert(edge->sink());
}
}
}
}
return res;
}

std::vector<const GraphNode *> Graph::nodes() const {
std::vector<const GraphNode *> res;
for (auto &s : nodes_) res.push_back(s.get());
Expand Down
14 changes: 14 additions & 0 deletions cinn/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

#include <glog/logging.h>

#include <functional>
#include <list>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/common/object.h"
Expand Down Expand Up @@ -119,6 +122,9 @@ class Graph {
//! Return the graph's DFS order.
std::vector<GraphNode*> dfs_order();

//! Return the dependency nodes of a set of nodes.
std::set<GraphNode*> dependencies(const std::vector<GraphNode*>& nodes);

std::vector<const GraphNode*> nodes() const;
std::vector<GraphNode*> nodes();

Expand All @@ -136,3 +142,11 @@ class Graph {

} // namespace common
} // namespace cinn

namespace std {
template <>
struct hash<cinn::common::GraphNode> {
size_t operator()(const cinn::common::GraphNode& x) { return reinterpret_cast<size_t>(hash<std::string>()(x.id())); }
};

} // namespace std
2 changes: 2 additions & 0 deletions cinn/common/shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct Shared {
inline T* get() const { return p_; }
inline T& operator*() const { return *p_; }
inline T* operator->() const { return p_; }
inline T* self() { return p_; }
inline const T* self() const { return p_; }
// @}

inline bool defined() const { return p_; }
Expand Down
2 changes: 1 addition & 1 deletion cinn/lang/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(srcs buffer.cc compute.cc placeholder.cc tensor.cc module.cc)
set(srcs buffer.cc compute.cc placeholder.cc tensor.cc module.cc lower.cc)

foreach(cpp ${srcs})
set(core_src
Expand Down
104 changes: 104 additions & 0 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#include "cinn/lang/lower.h"

#include "cinn/ir/ir_printer.h"
#include "cinn/optim/replace_call_with_expr.h"
#include "cinn/poly/ast_gen.h"

namespace cinn {
namespace lang {

using ir::Tensor;
using poly::Stage;

Expr LowerGroup(const poly::detail::Group& group, const std::map<std::string, Expr>& tuple_to_expr) {
std::vector<poly::Stage*> stages;
for (auto& node : group.nodes) {
stages.push_back(node->stage.get());
}

poly::PolyScheduler scheduler(stages);
// TODO Schedule it.
scheduler.BuildSchedule();

isl::set context(Context::Global().isl_ctx(), "{:}");
poly::AstGen gen(context, stages, scheduler);
isl::ast_node ast = gen.Build();

ir::Expr e;
poly::IslAstNodeToCinnExpr(ast, &e);

for (auto& statement : tuple_to_expr) {
auto axis_ast_map = gen.axis2ast(statement.first);
Expr statement_candi_expr = tuple_to_expr.at(statement.first);

std::map<std::string, Expr> axis;
for (auto& item : axis_ast_map) {
poly::IslAstExprToCinnExpr(item.second, &axis[item.first]);
}
optim::ReplaceCallWithExpr(&e, statement.first, statement_candi_expr, axis);
}

return e;
}

std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
// make sure the graph's start-points in the args.

auto stages = poly::GatherStagesInTensors(args);
auto graph = poly::CreateGraph(stages);

// Create a dic for stages and tensors.
std::map<std::string, Stage*> stage_dic;
std::map<std::string, Tensor> tensor_dic;
for (auto& tensor : args) tensor_dic.emplace(tensor->name, tensor);
for (auto& stage : stages) stage_dic.emplace(stage->id(), stage);
CHECK_EQ(tensor_dic.size(), stage_dic.size());
CHECK_EQ(args.size(), stage_dic.size()) << "tensor should duplicate name";

std::set<std::string> args_names;
for (auto& arg : args) {
args_names.insert(arg->name);
}
CHECK_EQ(args.size(), args_names.size()) << "Tensor should have unique name";

// collect the graph nodes of `args`
std::vector<common::GraphNode*> input_graph_nodes;
for (auto& node : graph->nodes()) {
if (args_names.count(node->id())) {
input_graph_nodes.push_back(node);
}
}

auto depend_node_set = graph->dependencies(input_graph_nodes);
// collect start points in the depend_node_set
for (auto& node : depend_node_set) {
CHECK(args_names.count(node->id())) << "The dependency tensor [" << node->id() << "] not in the inputs";
}

auto schedule = poly::CreateSchedule(stages);

// generate the expressions for each group.
std::vector<Expr> block;
for (auto& group : schedule->gened_groups()) {
std::map<std::string, Expr> tuple_to_expr;
for (auto& node : group.nodes) {
auto& tensor = tensor_dic.at(node->id());
tuple_to_expr[tensor->name] = tensor->body();
}

Expr group_expr = LowerGroup(group, tuple_to_expr);
VLOG(3) << "group expr: " << group_expr;
block.push_back(group_expr);
}

// prepare arguments
std::vector<Argument> arguments;
for (auto& arg : args) {
arguments.emplace_back(arg->name, Argument::Kind::kBuffer, arg->type(), arg->shape.size());
}

return {LoweredFunc(name, arguments, block)};
}

} // namespace lang
} // namespace cinn
19 changes: 19 additions & 0 deletions cinn/lang/lower.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/**
* Lower lowerise the statements to LoweredFuncs.
*/

#pragma once
#include "cinn/ir/function.h"
#include "cinn/ir/ir.h"
#include "cinn/lang/module.h"
#include "cinn/lang/tensor.h"
#include "cinn/poly/schedule.h"

namespace cinn {
namespace lang {
using ir::Tensor;

std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args);

} // namespace lang
} // namespace cinn
6 changes: 6 additions & 0 deletions cinn/lang/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,11 @@ void Module::Append(const Module &module) { self()->submodules.push_back(module)

void Module::Compile(const backends::Outputs &outputs) const {}

LoweredFunc::LoweredFunc(const std::string &name, const std::vector<Argument> &args, const std::vector<Expr> &body) {
this->name = name;
this->args = args;
this->body = ir::Block::Make(body);
}

} // namespace lang
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/lang/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ struct LoweredFunc {

LoweredFunc(const std::string& name, const std::vector<Argument>& args, const Expr& body)
: name(name), args(args), body(body) {}
LoweredFunc(const std::string& name, const std::vector<Argument>& args, const std::vector<Expr>& body);
};

} // namespace lang
Expand Down
8 changes: 8 additions & 0 deletions cinn/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ std::vector<Expr *> _Tensor_::expr_fields() {
}
return res;
}

std::vector<const Expr *> _Tensor_::expr_fields() const {
std::vector<const Expr *> res;
const char *func_type = operaion->As<ir::_Operation_>()->func_type();
Expand All @@ -125,10 +126,17 @@ std::vector<const Expr *> _Tensor_::expr_fields() const {
_Tensor_::~_Tensor_() {
if (stage) {
delete stage;
stage = nullptr;
}
}

const _Operation_ *Operation::operator->() const { return static_cast<_Operation_ *>(get()); }

Expr _Tensor_::body() const {
if (is_placeholder_node()) return Expr();
if (is_compute_node()) return operaion->As<ir::ComputeOp>()->body.front();
NOT_IMPLEMENTED;
}

} // namespace ir
} // namespace cinn
16 changes: 16 additions & 0 deletions cinn/lang/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "cinn/common/graph_utils.h"
Expand Down Expand Up @@ -115,6 +116,9 @@ class _Tensor_ : public ExprNode<_Tensor_> {
bool is_placeholder_node() const;
const char* operation_type() const;

//! The expression generate this tensor, will be empty if it is a PlaceHolder.
Expr body() const;

std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;

Expand Down Expand Up @@ -163,3 +167,15 @@ class _Operation_ : public ir::FunctionBase {

} // namespace ir
} // namespace cinn

namespace std {

template <>
struct hash<cinn::ir::Tensor> {
inline size_t operator()(const cinn::ir::Tensor& x) {
// We treat the tensor's name as the unique identifier.
return std::hash<std::string>()(x->name);
}
};

} // namespace std
8 changes: 4 additions & 4 deletions cinn/optim/replace_call_with_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ TEST(ReplaceCallWithExpr, basic) {
tuple_to_expr["A"] = ir::Store::Make(A_buf, A_value, Expr(i) * 100 * 100 + Expr(j) * 100 + Expr(k));
tuple_to_expr["B"] = ir::Store::Make(A_buf, B_value, Expr(i) * 100 * 100 + Expr(j) * 100 + Expr(k));

isl::ctx ctx(isl_ctx_alloc());
auto *A = make_shared<Stage>(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
auto *B = make_shared<Stage>(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));
isl::ctx ctx = Context::Global().isl_ctx();
auto *A = make_shared<Stage>(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
auto *B = make_shared<Stage>(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));

Iterator A_i0, A_i1;
Iterator B_i0, B_i1;

std::tie(A_i0, A_i1) = A->Split(Iterator("i"), 4);
std::tie(B_i0, B_i1) = B->Split(Iterator("i"), 4);

Scheduler scheduler;
PolyScheduler scheduler;
scheduler.AddStage(*A);
scheduler.AddStage(*B);
scheduler.After(*A, *B, 3);
Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ void AstGen::InitIslAstConfig() {
isl_options_set_ast_build_allow_else(ctx().get(), 1);
}

AstGen::AstGen(const isl::set& context, const std::vector<Stage*>& stages, const Scheduler& scheduler)
AstGen::AstGen(const isl::set& context, const std::vector<Stage*>& stages, const PolyScheduler& scheduler)
: context_(context), scheduler_(scheduler) {
for (auto* x : stages) stages_.emplace_back(x);
InitIslAstConfig();
Expand Down
4 changes: 2 additions & 2 deletions cinn/poly/ast_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace poly {
*/
class AstGen {
public:
AstGen(const isl::set& context, const std::vector<Stage*>& stages, const Scheduler& scheduler);
AstGen(const isl::set& context, const std::vector<Stage*>& stages, const PolyScheduler& scheduler);

/**
* Set forloop iterator names.
Expand Down Expand Up @@ -61,7 +61,7 @@ class AstGen {
private:
isl::set context_;
std::vector<Shared<Stage>> stages_;
const Scheduler& scheduler_;
const PolyScheduler& scheduler_;
std::vector<std::string> iterator_names_;
//! tuple name -> { axis -> isl_ast }
std::map<std::string, std::map<std::string, isl::ast_expr>> transformed_indice_map_;
Expand Down
8 changes: 4 additions & 4 deletions cinn/poly/ast_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ namespace cinn {
namespace poly {

TEST(ast_gen, basic) {
isl::ctx ctx(isl_ctx_alloc());
auto* A = make_shared<Stage>(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
auto* B = make_shared<Stage>(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));
isl::ctx ctx = Context::Global().isl_ctx();
auto* A = make_shared<Stage>(isl::set(ctx, "{ A[i,j,k]: 0<i,j,k<100 }"));
auto* B = make_shared<Stage>(isl::set(ctx, "{ B[i,j,k]: 0<i,j,k<100 }"));

Iterator A_i0, A_i1;
Iterator B_i0, B_i1;

std::tie(A_i0, A_i1) = A->Split(Iterator("i"), 4);
std::tie(B_i0, B_i1) = B->Split(Iterator("i"), 4);

Scheduler scheduler;
PolyScheduler scheduler;
scheduler.AddStage(*A);
scheduler.AddStage(*B);
scheduler.After(*A, *B, 3);
Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct Domain {
//! The ISL context.
isl::ctx ctx;

Domain(isl::ctx, std::string id, std::vector<Dim> dims) : ctx(ctx), id(std::move(id)), dims(std::move(dims)) {}
Domain(isl::ctx ctx, std::string id, std::vector<Dim> dims) : ctx(ctx), id(std::move(id)), dims(std::move(dims)) {}

//! The ISL format representation, such as '{ S[i]: 0<=i<=20 }'.
std::string __str__() const;
Expand Down
Loading

0 comments on commit 99da6f8

Please sign in to comment.