Enable tensor reshape (PaddlePaddle#192)
Superjomn authored Aug 31, 2020
1 parent c1147e5 commit 18be7b3
Showing 15 changed files with 229 additions and 138 deletions.
24 changes: 12 additions & 12 deletions cinn/backends/codegen_c_test.cc
@@ -253,17 +253,17 @@ TEST(CodeGenC, matmul) {
#include <cinn_runtime.h>
#include <stdio.h>
-cinn_buffer_t* _C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 50 });
+cinn_buffer_t* _C_init = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 50 });
void matmul(void* _args, int32_t num_args)
{
const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
-cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
-cinn_buffer_malloc((void*)(0), _C);
+cinn_buffer_t* _C_init = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
+cinn_buffer_malloc((void*)(0), _C_init);
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
-float* C = ((float*)(_C->memory));
-float* C_init = ((float*)(_C->memory));
+float* C = ((float*)(_C_init->memory));
+float* C_init = ((float*)(_C_init->memory));
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 50; j += 1) {
C_init[((50 * i) + j)] = 0;
@@ -272,7 +272,7 @@ void matmul(void* _args, int32_t num_args)
};
};
};
-cinn_buffer_free((void*)(0), _C);
+cinn_buffer_free((void*)(0), _C_init);
}
void main(void* _args, int32_t num_args)
@@ -354,17 +354,17 @@ TEST(CodeGenC, matmul_tile) {
#include <cinn_runtime.h>
#include <stdio.h>
-cinn_buffer_t* _C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 500 }, 32/*align*/);
+cinn_buffer_t* _C_init = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 500 }, 32/*align*/);
void matmul(void* _args, int32_t num_args)
{
const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
-cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
-cinn_buffer_malloc((void*)(0), _C);
+cinn_buffer_t* _C_init = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
+cinn_buffer_malloc((void*)(0), _C_init);
const float* A = ((const float*)(_A->memory));
const float* B = ((const float*)(_B->memory));
-float* C = ((float*)(_C->memory));
-float* C_init = ((float*)(_C->memory));
+float* C = ((float*)(_C_init->memory));
+float* C_init = ((float*)(_C_init->memory));
for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) {
@@ -379,7 +379,7 @@ void matmul(void* _args, int32_t num_args)
};
};
};
-cinn_buffer_free((void*)(0), _C);
+cinn_buffer_free((void*)(0), _C_init);
}
)ROC";

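The expected output in this test changes because the generated matmul now routes everything through a single buffer: the `_C` argument handling is renamed to `_C_init`, and both host pointers `C` and `C_init` alias the same `_C_init->memory`. A minimal standalone C sketch of that aliasing pattern (illustrative only, not CINN-generated code):

#include <assert.h>
#include <stdlib.h>

int main(void) {
  /* One allocation, two views -- like C and C_init in the generated matmul. */
  float *buf = malloc(100 * 50 * sizeof(float));
  float *C = buf;
  float *C_init = buf;
  C_init[0] = 0.0f;     /* the init stage writes through one view... */
  assert(C[0] == 0.0f); /* ...and the result is visible through the other */
  free(buf);
  return 0;
}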
5 changes: 3 additions & 2 deletions cinn/common/graph_utils.cc
@@ -7,6 +7,7 @@
#include <set>
#include <stack>

#include "cinn/common/common.h"
#include "cinn/utils/dot_lang.h"

namespace cinn {
@@ -54,7 +55,7 @@ std::vector<GraphNode *> Graph::nodes() {
return res;
}

-std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() {
+std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() const {
std::vector<GraphNode *> node_order;
std::vector<GraphEdge *> edge_order;
std::deque<GraphNode *> queue;
@@ -67,7 +68,7 @@ std::tuple<std::vector<GraphNode *>, std::vector<GraphEdge *>> Graph::topological_order() {

// insert start points first.
for (auto *n : start_points()) {
-  queue.push_back(n);
+  queue.push_back(&Reference(n));
}

// start to visit
3 changes: 1 addition & 2 deletions cinn/common/graph_utils.h
@@ -86,7 +86,6 @@ class GraphNode : public Object {
}

void UnLinkTo(GraphNode* other) {
-  LOG(INFO) << "Unlink " << this->id() << " to " << other->id();
if (other == this) return;
// remove outlink
{
@@ -168,7 +167,7 @@ class Graph {
std::vector<GraphNode*> start_points();

//! Return the graph's nodes and edges(visited) in topological order.
-std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> topological_order();
+std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> topological_order() const;

//! Return the graph's DFS order.
std::vector<GraphNode*> dfs_order();
1 change: 1 addition & 0 deletions cinn/common/union_find.h
@@ -16,6 +16,7 @@ namespace common {

struct UnionFindNode : public Object {
UnionFindNode* parent{};
+  std::string cluster_info;

std::tuple<UnionFindNode*, int /*height*/> GetRoot() {
auto* p = this;
85 changes: 35 additions & 50 deletions cinn/ir/tensor.cc
@@ -2,6 +2,7 @@

#include <cstring>

#include "cinn/cinn.h"
#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/ir_util.h"
@@ -29,7 +30,6 @@ Tensor _Tensor_::Make(const std::string &name,
n->reduce_axis = reduce_axis;
n->set_type(dtype);
n->operation = fn;
-  n->InitStage();
n->InitAxis();

return Tensor(n);
@@ -115,47 +115,6 @@ PlaceholderOp *_Tensor_::get_placeholder_op() const {
return operation->as<PlaceholderOp>();
}

-void _Tensor_::InitStage() {
-  // Avoid duplicate init.
-  if (stage_shared) {
-    auto &shared_stage = *static_cast<Shared<poly::Stage> *>(stage_shared);
-    for (auto &depend : buffer_depended_tensor_names()) {
-      shared_stage->add_extra_depend_stage(depend);
-    }
-    return;
-  }
-
-  stage_shared = new Shared<poly::Stage>;
-  auto &shared_stage = *static_cast<Shared<poly::Stage> *>(stage_shared);
-  auto *op = operation->as<_Operation_>();
-  if (is_compute_node()) {
-    auto &body = op->as<ComputeOp>()->body;
-    CHECK_EQ(body.size(), 1UL) << "only support functional programming";
-    shared_stage = poly::Stage::New(GenerateIslDomain(), body.front(), this);
-  } else if (is_call_node()) {
-    if (!is_extern_call_node()) {
-      shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-    } else {
-      shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-    }
-  } else {
-    shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-  }
-
-  shared_stage->set_extra_depend_stages(buffer_depended_tensor_names_);
-  auto depend_tensor_names = DependingTensorNames();
-  for (auto &x : depend_tensor_names) shared_stage->add_extra_depend_stage(x);
-}
-
-void _Tensor_::DropStage() {
-  if (stage_shared) {
-    delete static_cast<Shared<poly::Stage> *>(stage_shared);
-    stage_shared = nullptr;
-  }
-}
-
-bool _Tensor_::is_faked() const { return false; }

void _Tensor_::InitAxis() const {
// CHECK(!domain_without_reduce_axis().empty());
axis_ = common::GenDefaultAxis(domain_without_reduce_axis().size());
@@ -244,11 +203,7 @@ std::vector<const Expr *> _Tensor_::expr_fields() const {
return res;
}

-_Tensor_::~_Tensor_() {
-  if (stage_shared) {
-    delete static_cast<Shared<poly::Stage> *>(stage_shared);
-  }
-}
+_Tensor_::~_Tensor_() {}

Expr _Tensor_::body() const {
if (is_placeholder_node()) return Expr();
@@ -315,9 +270,6 @@ void _Tensor_::Bind(lang::Buffer &buffer) {
CHECK(!buffer->binded_tensor_names().empty());
this->buffer = buffer.buffer();
CHECK(this->buffer.defined());

-  // Reset stage to nullptr to tell others this tensor should be inlined.
-  InitStage();
}

void _Tensor_::Bind(const Buffer &buffer) {
@@ -452,6 +404,39 @@ bool _Tensor_::Uses(const Tensor &other) {
return !loads.empty();
}

+ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stages) const {
+  CHECK(!stages[this]->inlined());
+  auto op    = BufferShareOp::Make();
+  auto n     = make_shared<_Tensor_>();
+  auto selft = Tensor(const_cast<ir::_Tensor_ *>(this));
+
+  n->name      = Context::Global().NewName(name + "_reshape");
+  n->shape     = shape;
+  n->domain    = shape;
+  n->set_type(type());
+  n->operation = op;
+  n->InitAxis();
+
+  auto t = Tensor(n);
+  stages->InsertLazily(t);
+
+  stages[n]->ShareBufferWith(stages[this]);
+  stages[n]->CtrlDepend(selft);
+  return t;
+}
+
+ir::Tensor _Tensor_::ReshapeCopied(const std::vector<Expr> &shape, poly::StageMap stages) const {
+  auto t      = ir::Tensor(const_cast<ir::_Tensor_ *>(this));
+  auto copied = Compute(
+      domain,
+      [=](const std::vector<Expr> &axis) { return t(axis); },
+      Context::Global().NewName(this->name + "_copied"));
+  stages->InsertLazily(copied);
+  auto res = copied->Reshape(shape, stages);
+  stages->InsertLazily(res);
+  return res;
+}

Shared<poly::Stage> CreateStage(Tensor tensor) {
return poly::Stage::New(tensor->GenerateIslDomain(), tensor->body(), tensor.self());
}
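A usage sketch for the two new entry points (not part of this commit; `Compute`, `CreateStages`, and all names and shapes below are assumptions following CINN's lang API of this period):

// Hypothetical driver: build a 100x50 tensor, then view/copy it as 50x100.
auto A = Compute(
    {Expr(100), Expr(50)},
    [](const std::vector<Expr> &axis) { return Expr(1.f); },  // constant fill
    "A");
auto stages = CreateStages({A});
// Shares A's underlying buffer and control-depends on A.
auto A_view = A->Reshape({Expr(50), Expr(100)}, stages);
// First copies A into a freshly allocated buffer, then reshapes the copy.
auto A_copy = A->ReshapeCopied({Expr(50), Expr(100)}, stages);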
25 changes: 12 additions & 13 deletions cinn/ir/tensor.h
@@ -98,9 +98,6 @@ struct WriteCacheRelation;
* 2. never try to change a tensor's name, that will cause chaos.
*/
class _Tensor_ : public ExprNode<_Tensor_> {
-  //! a pointer to Shared<Stage>, use void* to avoid cyclic definition dependency.
-  void* stage_shared{};
-
public:
//! Shape of this tensor(buffer).
std::vector<Expr> shape;
@@ -151,6 +148,18 @@ class _Tensor_ : public ExprNode<_Tensor_> {
*/
std::set<std::string> DependingTensorNames();

+  /**
+   * Get a new tensor with the given \p shape, sharing the underlying buffer.
+   * NOTE the tensor to Reshape should not be an inlined computation.
+   */
+  ir::Tensor Reshape(const std::vector<Expr>& shape, poly::StageMap stages) const;
+
+  /**
+   * Get a new tensor with the given \p shape and a newly allocated buffer.
+   * NOTE the tensor to Reshape should not be an inlined computation.
+   */
+  ir::Tensor ReshapeCopied(const std::vector<Expr>& shape, poly::StageMap stages) const;

/**
* Tell whether this tensor has same shape with \p other.
*/
@@ -221,16 +230,6 @@ class _Tensor_ : public ExprNode<_Tensor_> {
void WithBuffer(const std::string& memory_type, const Type& type = Void());

private:
-  //! Create the polyhedral element for analysis.
-  //! It is based on the shape.
-  void InitStage();
-
-  //! Free the memory for stage.
-  void DropStage();
-
-  void FakeStage();
-  bool is_faked() const;

//! Initialize the axis field after the shape field is assigned.
void InitAxis() const;
