Skip to content

Commit

Permalink
make simple lower test case works
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn committed Feb 26, 2020
1 parent 99da6f8 commit a8167a1
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 27 deletions.
5 changes: 5 additions & 0 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ std::vector<const Expr *> PolyFor::expr_fields() const { return {&init, &conditi
bool Var::operator==(const Var &o) const { return o->name == operator->()->name; }
bool Var::operator!=(const Var &o) const { return !(*this == o); }

Var &Var::operator=(_Var_ *x) {
*this = Var(x);
return *this;
}

} // namespace ir

namespace common {
Expand Down
2 changes: 2 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ struct Var : public IrNodeRef {
bool operator==(const Var& o) const;
bool operator!=(const Var& o) const;

Var& operator=(_Var_* x);

const _Var_* operator->() const { return get(); }
_Var_* operator->() { return get(); }
const _Var_* get() const { return static_cast<const _Var_*>(ptr()); }
Expand Down
13 changes: 7 additions & 6 deletions cinn/ir/ir_operators.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once
#include "cinn/ir/ir.h"

namespace cinn {
Expand Down Expand Up @@ -80,12 +81,12 @@ Expr operator>=(POD a, Expr b) {
}

//--
Expr operator+(Expr a, Expr b) { return Add::Make(a, b); }
Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); }
Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); }
Expr operator/(Expr a, Expr b) { return Div::Make(a, b); }
Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); }
Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); }
inline Expr operator+(Expr a, Expr b) { return Add::Make(a, b); }
inline Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); }
inline Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); }
inline Expr operator/(Expr a, Expr b) { return Div::Make(a, b); }
inline Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); }
inline Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); }

} // namespace ir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/lang/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ endforeach()
cc_test(test_compute SRCS compute_test.cc DEPS core)
cc_test(test_placeholder SRCS placeholder_test.cc DEPS core)
cc_test(test_tensor SRCS tensor_test.cc DEPS core)
cc_test(test_lower SRCS lower_test.cc DEPS core)
1 change: 1 addition & 0 deletions cinn/lang/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/schedule.h"

Expand Down
21 changes: 17 additions & 4 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "cinn/lang/lower.h"

#include <map>
#include <set>

#include "cinn/ir/ir_printer.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/optim/replace_call_with_expr.h"
#include "cinn/poly/ast_gen.h"

Expand Down Expand Up @@ -78,19 +82,28 @@ std::vector<LoweredFunc> Lower(const std::string& name, const std::vector<Tensor
auto schedule = poly::CreateSchedule(stages);

// generate the expressions for each group.
std::vector<Expr> block;
std::vector<Expr> exprs;
CHECK_GT(schedule->gened_groups().size(), 0) << "no group is generated";
for (auto& group : schedule->gened_groups()) {
CHECK_GT(group.nodes.size(), 0) << "group is empty";
std::map<std::string, Expr> tuple_to_expr;
for (auto& node : group.nodes) {
auto& tensor = tensor_dic.at(node->id());
tuple_to_expr[tensor->name] = tensor->body();
auto& tensor = tensor_dic.at(node->id());
// NOTE here just schedule the compute node.
if (!tensor->is_compute_node()) continue;

tuple_to_expr[tensor->name] = tensor->tensor_store_expanded_body();
}

Expr group_expr = LowerGroup(group, tuple_to_expr);
VLOG(3) << "group expr: " << group_expr;
block.push_back(group_expr);
exprs.push_back(group_expr);
}

Expr block = ir::Block::Make(exprs);
// call passes
optim::RemoveNestedBlock(&block);

// prepare arguments
std::vector<Argument> arguments;
for (auto& arg : args) {
Expand Down
3 changes: 3 additions & 0 deletions cinn/lang/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
*/

#pragma once
#include <string>
#include <vector>

#include "cinn/ir/function.h"
#include "cinn/ir/ir.h"
#include "cinn/lang/module.h"
Expand Down
44 changes: 44 additions & 0 deletions cinn/lang/lower_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "cinn/lang/lower.h"

#include <gtest/gtest.h>

#include "cinn/lang/compute.h"
#include "cinn/lang/placeholder.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace lang {

TEST(lower, basic) {
const int M = 100;
const int N = 15;

Placeholder<float> A("A", {Expr(M), Expr(N)});

auto B = Compute(
{M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B");

auto lower_funcs = Lower("cal_B", {A, B});

LOG(INFO) << "lower_size " << lower_funcs.size();

#define TEST_SOUTPUT(x, out) LOG(INFO) << "\n" << x; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));

auto out = R"ROC(
{
poly_for (0, (c1 <= 99), 1)
{
poly_for (0, (c3 <= 14), 1)
{
A(c1, c3)
B[((c1 * 15) + c3)] = (A(c1, c3) + 1)
}
}
}
)ROC";
TEST_SOUTPUT(lower_funcs.front().body, out);
}

} // namespace lang
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/lang/placeholder.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Placeholder {
Expr operator()(const std::vector<Expr> &indices) const;
// @}

operator ir::Tensor() { return tensor_; }

private:
ir::Tensor tensor_;
};
Expand Down
44 changes: 44 additions & 0 deletions cinn/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,44 @@
#include <cstring>

#include "cinn/common/common.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/ir/operation.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace ir {

namespace detail {

Expr ExpandTo1DIndice(const std::vector<Expr> &shape, const std::vector<Expr> &indices) {
CHECK_EQ(shape.size(), indices.size());
Expr res = indices.front() * shape[1];
for (int i = 1; i < shape.size() - 1; i++) {
res = res + indices[i] * shape[i + 1];
}
if (shape.size() > 1) res = res + indices.back();
return res;
}

Expr ExpandTo1DIndice(const std::vector<int> &shape, const std::vector<Expr> &indices) {
std::vector<Expr> shape_;
for (int v : shape) shape_.push_back(Expr(v));
return ExpandTo1DIndice(shape, indices);
}

} // namespace detail

Tensor _Tensor_::Make(const std::string &name, const std::vector<Expr> &shape, FunctionRef fn) {
CHECK(!shape.empty()) << "Tensor shape is set empty";
CHECK(!name.empty()) << "Tensor name is set empty";
auto n = make_shared<_Tensor_>();
n->name = name;
n->shape = shape;
n->operaion = fn;
n->InitStage();
n->InitAxis();
n->SetDefaultBindedBuffer();
return Tensor(n);
}

Expand All @@ -26,8 +51,13 @@ Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::map<std::string, IrNodeRef> &attrs,
const std::vector<Expr> &body) {
CHECK(!shape.empty()) << "Tensor shape is set empty";
CHECK(!name.empty()) << "Tensor name is set empty";

auto op = ComputeOp::Make(name, tag, attrs, axis, body, shape);
auto *compute_op = const_cast<ComputeOp *>(op->As<ComputeOp>());

CHECK_EQ(axis.size(), shape.size()) << "axis not match the dimension in shape";
compute_op->axis = axis;

auto n = make_shared<_Tensor_>();
Expand All @@ -36,6 +66,7 @@ Tensor _Tensor_::Make(const std::string &name,
n->shape = shape;
n->set_type(dtype);
n->InitStage();
n->SetDefaultBindedBuffer();
return Tensor(n);
}

Expand Down Expand Up @@ -76,6 +107,12 @@ void _Tensor_::InitStage() {
}
}

void _Tensor_::InitAxis() {
CHECK(!shape.empty());
CHECK(axis.empty()) << "duplicate init axis";
axis = common::GenDefaultAxis(shape.size());
}

isl::set _Tensor_::GenerateIslDomain() {
CHECK(!shape.empty()) << "shape should be set";
std::vector<poly::Dim> dims;
Expand Down Expand Up @@ -138,5 +175,12 @@ Expr _Tensor_::body() const {
NOT_IMPLEMENTED;
}

Expr _Tensor_::tensor_store_expanded_body() const {
CHECK(!is_placeholder_node()) << "placeholder should not expand store";
std::vector<Expr> axis_;
for (auto &a : axis) axis_.push_back(Expr(a));
return ir::Store::Make(buffer_var, body(), detail::ExpandTo1DIndice(shape, axis_));
}

} // namespace ir
} // namespace cinn
31 changes: 26 additions & 5 deletions cinn/lang/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace ir {
namespace detail {
constexpr bool LE(int a, int b) { return a <= b; }
constexpr bool GE(int a, int b) { return a >= b; }

//! Expand milti-dim indices to 1-dim index.
Expr ExpandTo1DIndice(const std::vector<int>& shape, const std::vector<Expr>& indices);
Expr ExpandTo1DIndice(const std::vector<Expr>& shape, const std::vector<Expr>& indices);

} // namespace detail

class _Tensor_;
Expand Down Expand Up @@ -54,11 +59,6 @@ class Tensor : public ir::IrNodeRef {
//! Get number of dimensions.
inline size_t ndims() const;

inline const _Tensor_* operator->() const { return As<_Tensor_>(); }
inline _Tensor_* operator->() { return As<_Tensor_>(); }

inline operator Expr() const { return Expr(get()); }

/**
* Take elements from the tensor.
* This take one or multiple expressions as indices.
Expand All @@ -82,6 +82,11 @@ class Tensor : public ir::IrNodeRef {
* @return The result expression representing a tensor read.
*/
Expr operator()(const std::vector<Expr>& indices) const;

inline const _Tensor_* operator->() const { return As<_Tensor_>(); }
inline _Tensor_* operator->() { return As<_Tensor_>(); }

inline operator Expr() const { return Expr(get()); }
};

/**
Expand All @@ -99,6 +104,8 @@ class _Tensor_ : public ExprNode<_Tensor_> {
std::string name;
//! Polyhedral element for analysis and schedule.
poly::Stage* stage{};
//! The binded buffer, for each tensor if it is not inline.
Var buffer_var;

//! Generate a tensor from a computation.
static Tensor Make(const std::string& name,
Expand All @@ -112,12 +119,19 @@ class _Tensor_ : public ExprNode<_Tensor_> {
//! Generate a tensor from a function.
static Tensor Make(const std::string& name, const std::vector<Expr>& shape, FunctionRef fn);

//! Tell the operation type.
// @{
bool is_compute_node() const;
bool is_placeholder_node() const;
const char* operation_type() const;
// @}

//! The expression generate this tensor, will be empty if it is a PlaceHolder.
Expr body() const;
//! Get the expression with `store(tensor)` inserted into the body.
Expr tensor_store_expanded_body() const;

Expr inline_expanded(const std::vector<Expr>& indices);

std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;
Expand All @@ -133,6 +147,13 @@ class _Tensor_ : public ExprNode<_Tensor_> {
//! It is based on the shape.
void InitStage();

//! Initialize the axis field after the shape field is assigned.
void InitAxis();

//! Bind the tensor to a buffer by default.
//! NOTE it should called by all the Make.
void SetDefaultBindedBuffer() { buffer_var = ir::_Var_::Make(name, type()).As<_Var_>(); }

isl::set GenerateIslDomain();
};

Expand Down
25 changes: 17 additions & 8 deletions cinn/poly/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ std::vector<Group> PartitionGraphByIterationDomain(common::Graph* graph) {
std::set<DataFlowGraphNode*> groups_gathered;
std::vector<DataFlowGraphNode*> groups_in_topo_order;

std::vector<common::GraphNode*> nodes;
std::vector<common::GraphEdge*> edges;
std::vector<common::GraphNode*> nodes_in_order;
std::vector<common::GraphEdge*> edges_in_order;
std::map<DataFlowGraphNode*, std::vector<DataFlowGraphNode*>> node_groups;
std::tie(nodes, edges) = graph->topological_order();
for (auto* n : nodes) {
std::tie(nodes_in_order, edges_in_order) = graph->topological_order();
for (auto* n : nodes_in_order) {
auto* node = n->As<DataFlowGraphNode>();
auto* ancestor = node->group_ancestor();
if (!groups_gathered.count(ancestor)) {
Expand All @@ -122,14 +122,22 @@ std::vector<Group> PartitionGraphByIterationDomain(common::Graph* graph) {

std::vector<Group> groups;
// preparing result
for (auto* n : groups_in_topo_order) {
for (auto* ancestor : groups_in_topo_order) {
Group group;
for (auto* c : node_groups[n]) {
for (auto* c : node_groups[ancestor]) {
group.nodes.push_back(c);
}
groups.emplace_back(group);
}

// NOTE DEBUG
// check there are same count of nodes both in the orginal graph and the groups.
// @{
int num_node_in_groups = 0;
for (auto& group : groups) num_node_in_groups += group.nodes.size();
CHECK_EQ(num_node_in_groups, graph->num_nodes());
// @}

return groups;
}

Expand All @@ -148,9 +156,10 @@ std::unique_ptr<common::Graph> CreateGraph(const std::vector<Stage*>& stages) {
}
}

auto* graph = new common::Graph;
std::unique_ptr<common::Graph> graph(new common::Graph);
for (auto& item : id2stage) graph->RegisterNode(item.first, item.second.get());
return std::unique_ptr<common::Graph>(graph);
VLOG(3) << "created graph:\n" << graph->Visualize();
return graph;
}

} // namespace poly
Expand Down
Loading

0 comments on commit a8167a1

Please sign in to comment.