diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc
index 904a39e6a332f..888b1c4fe40ac 100644
--- a/cinn/backends/codegen_c_test.cc
+++ b/cinn/backends/codegen_c_test.cc
@@ -253,17 +253,17 @@ TEST(CodeGenC, matmul) {
 #include
 #include
 
-cinn_buffer_t* _C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 50 });
+cinn_buffer_t* _C_init = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 50 });
 void matmul(void* _args, int32_t num_args)
 {
   const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
   const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
-  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
-  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_t* _C_init = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
+  cinn_buffer_malloc((void*)(0), _C_init);
   const float* A = ((const float*)(_A->memory));
   const float* B = ((const float*)(_B->memory));
-  float* C = ((float*)(_C->memory));
-  float* C_init = ((float*)(_C->memory));
+  float* C = ((float*)(_C_init->memory));
+  float* C_init = ((float*)(_C_init->memory));
   for (int32_t i = 0; i < 100; i += 1) {
     for (int32_t j = 0; j < 50; j += 1) {
       C_init[((50 * i) + j)] = 0;
@@ -272,7 +272,7 @@ void matmul(void* _args, int32_t num_args)
       };
     };
   };
-  cinn_buffer_free((void*)(0), _C);
+  cinn_buffer_free((void*)(0), _C_init);
 }
 
 void main(void* _args, int32_t num_args)
@@ -354,17 +354,17 @@ TEST(CodeGenC, matmul_tile) {
 #include
 #include
 
-cinn_buffer_t* _C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 500 }, 32/*align*/);
+cinn_buffer_t* _C_init = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 100, 500 }, 32/*align*/);
 void matmul(void* _args, int32_t num_args)
 {
   const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
   const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
-  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
-  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_t* _C_init = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
+  cinn_buffer_malloc((void*)(0), _C_init);
   const float* A = ((const float*)(_A->memory));
   const float* B = ((const float*)(_B->memory));
-  float* C = ((float*)(_C->memory));
-  float* C_init = ((float*)(_C->memory));
+  float* C = ((float*)(_C_init->memory));
+  float* C_init = ((float*)(_C_init->memory));
   for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) {
     for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) {
       for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) {
@@ -379,7 +379,7 @@ void matmul(void* _args, int32_t num_args)
       };
     };
   };
-  cinn_buffer_free((void*)(0), _C);
+  cinn_buffer_free((void*)(0), _C_init);
 }
 
 )ROC";
diff --git a/cinn/common/graph_utils.cc b/cinn/common/graph_utils.cc
index c055e74ab1681..e7765d923cb04 100644
--- a/cinn/common/graph_utils.cc
+++ b/cinn/common/graph_utils.cc
@@ -7,6 +7,7 @@
 #include
 #include
 
+#include "cinn/common/common.h"
 #include "cinn/utils/dot_lang.h"
 
 namespace cinn {
@@ -54,7 +55,7 @@ std::vector<GraphNode*> Graph::nodes() {
   return res;
 }
 
-std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> Graph::topological_order() {
+std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> Graph::topological_order() const {
   std::vector<GraphNode*> node_order;
   std::vector<GraphEdge*> edge_order;
   std::deque<GraphNode*> queue;
@@ -67,7 +68,7 @@ std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> Graph::topological_order() {
 
   // insert start points first.
   for (auto *n : start_points()) {
-    queue.push_back(n);
+    queue.push_back(&Reference(n));
   }
 
   // start to visit
diff --git a/cinn/common/graph_utils.h b/cinn/common/graph_utils.h
index 1f3b6cb7e979f..4f72df2f3ff58 100644
--- a/cinn/common/graph_utils.h
+++ b/cinn/common/graph_utils.h
@@ -86,7 +86,6 @@ class GraphNode : public Object {
   }
 
   void UnLinkTo(GraphNode* other) {
-    LOG(INFO) << "Unlink " << this->id() << " to " << other->id();
     if (other == this) return;
     // remove outlink
     {
@@ -168,7 +167,7 @@ class Graph {
   std::vector<GraphNode*> start_points();
 
   //! Return the graph's nodes and edges(visited) in topological order.
-  std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> topological_order();
+  std::tuple<std::vector<GraphNode*>, std::vector<GraphEdge*>> topological_order() const;
 
   //! Return the graph's DFS order.
   std::vector<GraphNode*> dfs_order();
diff --git a/cinn/common/union_find.h b/cinn/common/union_find.h
index bf90d008b8e64..a55513dbb63c2 100644
--- a/cinn/common/union_find.h
+++ b/cinn/common/union_find.h
@@ -16,6 +16,7 @@ namespace common {
 
 struct UnionFindNode : public Object {
   UnionFindNode* parent{};
+  std::string cluster_info;
 
   std::tuple<UnionFindNode*, int> GetRoot() {
     auto* p = this;
diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc
index d36ced8745aed..b6440774ceda9 100644
--- a/cinn/ir/tensor.cc
+++ b/cinn/ir/tensor.cc
@@ -2,6 +2,7 @@
 
 #include
 
+#include "cinn/cinn.h"
 #include "cinn/common/cas.h"
 #include "cinn/common/common.h"
 #include "cinn/common/ir_util.h"
@@ -29,7 +30,6 @@ Tensor _Tensor_::Make(const std::string &name,
   n->reduce_axis = reduce_axis;
   n->set_type(dtype);
   n->operation = fn;
-  n->InitStage();
   n->InitAxis();
 
   return Tensor(n);
@@ -115,47 +115,6 @@ PlaceholderOp *_Tensor_::get_placeholder_op() const {
   return operation->as<PlaceholderOp>();
 }
 
-void _Tensor_::InitStage() {
-  // Avoid duplicate init.
-  if (stage_shared) {
-    auto &shared_stage = *static_cast<Shared<poly::Stage> *>(stage_shared);
-    for (auto &depend : buffer_depended_tensor_names()) {
-      shared_stage->add_extra_depend_stage(depend);
-    }
-    return;
-  }
-
-  stage_shared = new Shared<poly::Stage>;
-  auto &shared_stage = *static_cast<Shared<poly::Stage> *>(stage_shared);
-  auto *op = operation->as<_Operation_>();
-  if (is_compute_node()) {
-    auto &body = op->as<ComputeOp>()->body;
-    CHECK_EQ(body.size(), 1UL) << "only support functional programming";
-    shared_stage = poly::Stage::New(GenerateIslDomain(), body.front(), this);
-  } else if (is_call_node()) {
-    if (!is_extern_call_node()) {
-      shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-    } else {
-      shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-    }
-  } else {
-    shared_stage = poly::Stage::New(GenerateIslDomain(), body(), this);
-  }
-
-  shared_stage->set_extra_depend_stages(buffer_depended_tensor_names_);
-  auto depend_tensor_names = DependingTensorNames();
-  for (auto &x : depend_tensor_names) shared_stage->add_extra_depend_stage(x);
-}
-
-void _Tensor_::DropStage() {
-  if (stage_shared) {
-    delete static_cast<Shared<poly::Stage> *>(stage_shared);
-    stage_shared = nullptr;
-  }
-}
-
-bool _Tensor_::is_faked() const { return false; }
-
 void _Tensor_::InitAxis() const {
   // CHECK(!domain_without_reduce_axis().empty());
   axis_ = common::GenDefaultAxis(domain_without_reduce_axis().size());
@@ -244,11 +203,7 @@ std::vector<Expr *> _Tensor_::expr_fields() const {
   return res;
 }
 
-_Tensor_::~_Tensor_() {
-  if (stage_shared) {
-    delete static_cast<Shared<poly::Stage> *>(stage_shared);
-  }
-}
+_Tensor_::~_Tensor_() {}
 
 Expr _Tensor_::body() const {
   if (is_placeholder_node()) return Expr();
@@ -315,9 +270,6 @@ void _Tensor_::Bind(lang::Buffer &buffer) {
   CHECK(!buffer->binded_tensor_names().empty());
   this->buffer = buffer.buffer();
   CHECK(this->buffer.defined());
-
-  // Reset stage to nullptr to tell others this tensor should be inlined.
-  InitStage();
 }
 
 void _Tensor_::Bind(const Buffer &buffer) {
@@ -452,6 +404,39 @@ bool _Tensor_::Uses(const Tensor &other) {
   return !loads.empty();
 }
 
+ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stages) const {
+  CHECK(!stages[this]->inlined());
+  auto op = BufferShareOp::Make();
+  auto n = make_shared<_Tensor_>();
+  auto selft = Tensor(const_cast<ir::_Tensor_ *>(this));
+
+  n->name = Context::Global().NewName(name + "_reshape");
+  n->shape = shape;
+  n->domain = shape;
+  n->set_type(type());
+  n->operation = op;
+  n->InitAxis();
+
+  auto t = Tensor(n);
+  stages->InsertLazily(t);
+
+  stages[n]->ShareBufferWith(stages[this]);
+  stages[n]->CtrlDepend(selft);
+  return t;
+}
+
+ir::Tensor _Tensor_::ReshapeCopied(const std::vector<Expr> &shape, poly::StageMap stages) const {
+  auto t = ir::Tensor(const_cast<ir::_Tensor_ *>(this));
+  auto copied = Compute(
+      domain,
+      [=](const std::vector<Expr> &axis) { return t(axis); },
+      Context::Global().NewName(this->name + "_copied"));
+  stages->InsertLazily(copied);
+  auto res = copied->Reshape(shape, stages);
+  stages->InsertLazily(res);
+  return res;
+}
+
 Shared<poly::Stage> CreateStage(Tensor tensor) {
   return poly::Stage::New(tensor->GenerateIslDomain(), tensor->body(), tensor.self());
 }
diff --git a/cinn/ir/tensor.h b/cinn/ir/tensor.h
index 9b08a9398f44c..7a071355e2cdc 100644
--- a/cinn/ir/tensor.h
+++ b/cinn/ir/tensor.h
@@ -98,9 +98,6 @@ struct WriteCacheRelation;
  * 2. never try to change a tensor's name, that will cause chaos.
  */
 class _Tensor_ : public ExprNode<_Tensor_> {
-  //! a pointer to Shared<poly::Stage>, use void* to avoid cyclic definition dependency.
-  void* stage_shared{};
-
  public:
   //! Shape of this tensor(buffer).
   std::vector<Expr> shape;
@@ -151,6 +148,18 @@ class _Tensor_ : public ExprNode<_Tensor_> {
    */
   std::set<std::string> DependingTensorNames();
 
+  /**
+   * Get a new tensor with the \p shape, but with the underlying buffer shared.
+   * NOTE the tensor to Reshape should not be an inlined computation.
+   */
+  ir::Tensor Reshape(const std::vector<Expr>& shape, poly::StageMap stages) const;
+
+  /**
+   * Get a new tensor with the \p shape and a newly allocated buffer.
+   * NOTE the tensor to Reshape should not be an inlined computation.
+   */
+  ir::Tensor ReshapeCopied(const std::vector<Expr>& shape, poly::StageMap stages) const;
+
   /**
    * Tell whether this tensor has same shape with \p other.
    */
@@ -221,16 +230,6 @@ class _Tensor_ : public ExprNode<_Tensor_> {
   void WithBuffer(const std::string& memory_type, const Type& type = Void());
 
  private:
-  //! Create the polyhedral element for analysis.
-  //! It is based on the shape.
-  void InitStage();
-
-  //! Free the memory for stage.
-  void DropStage();
-
-  void FakeStage();
-  bool is_faked() const;
-
   //! Initialize the axis field after the shape field is assigned.
   void InitAxis() const;
 
diff --git a/cinn/ir/tensor_test.cc b/cinn/ir/tensor_test.cc
index 7cb0689d54830..0997167ed7ff1 100644
--- a/cinn/ir/tensor_test.cc
+++ b/cinn/ir/tensor_test.cc
@@ -34,9 +34,9 @@ TEST(Tensor, inlined) {
   auto stages = CreateStages({D});
   stages[C]->ComputeInline();
 
-  auto funcs = lang::Lower("func_C", stages, {A, B, D});
-  std::cout << "output: \n" << funcs << std::endl;
-  auto out = GetStreamCnt(funcs);
+  auto func = lang::Lower("func_C", stages, {A, B, D});
+  std::cout << "output: \n" << func << std::endl;
+  auto out = GetStreamCnt(func);
   EXPECT_EQ(Trim(out), Trim(R"ROC(
 function func_C (_A, _B, _D)
 {
@@ -61,5 +61,105 @@ TEST(Tensor, IsDependOnStatement) {
   ASSERT_FALSE(t->IsDependOnStatement("XXX"));
 }
 
+TEST(Tensor, Reshape) {
+  Context::Global().ResetNameId();
+  Expr M(100);
+  Expr N(100);
+  Placeholder<float> A("A", {M, N});
+
+  auto stages = CreateStages({A});
+
+  auto A1 = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages);
+  auto B = Compute(A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; });
+
+  stages[A1]->ShareBufferWith(stages[A]);
+  stages->InsertLazily(B);
+
+  auto func = lang::Lower("fn", stages, {A, B});
+
+  lang::Module::Builder builder("some_modue", common::DefaultHostTarget());
+  builder.AddFunction(func);
+
+  backends::CodeGenC codegenc(common::DefaultHostTarget());
+  codegenc.SetInlineBuiltinCodes(false);
+  auto source = codegenc.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
+  LOG(INFO) << "source:\n" << source;
+
+  auto target_source = R"ROC(
+#include
+#include
+
+void fn(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _tensor_3 = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _tensor_3);
+  cinn_buffer_malloc((void*)(0), _A);
+  const float* A_reshape_2 = ((const float*)(_A->memory));
+  float* tensor_3 = ((float*)(_tensor_3->memory));
+  for (int32_t i = 0; i < 10; i += 1) {
+    for (int32_t j = 0; j < 10; j += 1) {
+      for (int32_t k = 0; k < 100; k += 1) {
+        tensor_3[((1000 * i) + ((100 * j) + k))] = (2 * A_reshape_2[((1000 * i) + ((100 * j) + k))]);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _tensor_3);
+}
+)ROC";
+
+  ASSERT_EQ(Trim(target_source), Trim(source));
+}
+
+TEST(Tensor, ReshapeCopied) {
+  Context::Global().ResetNameId();
+  Expr M(100);
+  Expr N(100);
+  Placeholder<float> A("A", {M, N});
+
+  auto stages = CreateStages({A});
+
+  auto A1 = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages);
+  auto B = Compute(A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; });
+
+  stages[A1]->ShareBufferWith(stages[A]);
+  stages->InsertLazily(B);
+
+  auto func = lang::Lower("fn", stages, {A, B});
+
+  lang::Module::Builder builder("some_modue", common::DefaultHostTarget());
+  builder.AddFunction(func);
+
+  backends::CodeGenC codegenc(common::DefaultHostTarget());
+  codegenc.SetInlineBuiltinCodes(false);
+  auto source = codegenc.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
+  LOG(INFO) << "source:\n" << source;
+
+  auto target_source = R"ROC(
+#include
+#include
+
+void fn(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _tensor_4 = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _tensor_4);
+  cinn_buffer_malloc((void*)(0), _A);
+  const float* A_copied_2_reshape_3 = ((const float*)(_A->memory));
+  float* tensor_4 = ((float*)(_tensor_4->memory));
+  for (int32_t i = 0; i < 10; i += 1) {
+    for (int32_t j = 0; j < 10; j += 1) {
+      for (int32_t k = 0; k < 100; k += 1) {
+        tensor_4[((1000 * i) + ((100 * j) + k))] = (2 * A_copied_2_reshape_3[((1000 * i) + ((100 * j) + k))]);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _tensor_4);
+}
+)ROC";
+
+  ASSERT_EQ(Trim(target_source), Trim(source));
+}
+
 }  // namespace ir
 }  // namespace cinn
diff --git a/cinn/lang/lower.cc b/cinn/lang/lower.cc
index 61e741bafa653..2fcb65152c3c6 100644
--- a/cinn/lang/lower.cc
+++ b/cinn/lang/lower.cc
@@ -31,7 +31,8 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
   std::unordered_set<std::string> temp_buffer_names;  // used to avoid duplication.
   std::vector<ir::Buffer> temp_buffers;
   auto all_tensors = ir::CollectIRNodes(body, [&](const Expr* x) {
-    return x->as_tensor() && !stage_map[x->as_tensor()]->inlined() && !tensor_arg_names.count(x->as_tensor()->name);
+    return x->as_tensor() && x->as_tensor()->buffer.defined() && !stage_map[x->as_tensor()]->inlined() &&
+           !tensor_arg_names.count(x->as_tensor()->name);
   });
   for (auto& e : all_tensors) {
     if (!temp_buffer_names.count(e.as_tensor()->buffer->name)) {
diff --git a/cinn/lang/lower_impl.cc b/cinn/lang/lower_impl.cc
index 2f9c8cdd17df1..aae543287e742 100644
--- a/cinn/lang/lower_impl.cc
+++ b/cinn/lang/lower_impl.cc
@@ -313,7 +313,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(Expr fn_body)
   for (auto& tensor : tensor_args_) {
     auto* tensor_node = tensor.As<ir::_Tensor_>();
     bool is_output = teller.IsWrite(tensor->name);
-    VLOG(5) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name;
+    VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name;
 
     // avoid duplicate
     if (!tensor_node->buffer.defined()) continue;
@@ -378,7 +378,9 @@ std::unordered_map<std::string, Tensor> LowerImpl::GenAllTensorMap() {
 
 ir::LoweredFunc LowerImpl::operator()() {
   std::vector<poly::Stage*> stages;
+  std::map<std::string, ir::Tensor> all_tensor_map;
   for (auto& t : CollectAllTensors()) {
+    all_tensor_map[t->name] = t;
     if (!stages_[t]->inlined()) stages.push_back(stages_[t]);
   }
 
@@ -388,7 +390,7 @@ ir::LoweredFunc LowerImpl::operator()() {
 
   auto func_body = GenerateFunctionBody(schedule.get());
 
-  auto tensor_map = optim::InitialAssignBuffer(&func_body, stages_);
+  auto tensor_map = optim::InitialAssignBuffer(&func_body, stages_, all_tensor_map, comp_graph());
 
   // copy the tensor(with buffer assigned) back to func's args.
   {
     for (auto& arg : tensor_args_) {
diff --git a/cinn/optim/buffer_assign.cc b/cinn/optim/buffer_assign.cc
index fc45a645a499d..37225951e1283 100644
--- a/cinn/optim/buffer_assign.cc
+++ b/cinn/optim/buffer_assign.cc
@@ -3,6 +3,7 @@
 #include "cinn/common/union_find.h"
 #include "cinn/ir/ir_mutator.h"
 #include "cinn/ir/ir_printer.h"
+#include "cinn/lang/lower_impl.h"
 #include "cinn/optim/ir_replace.h"
 
 namespace cinn {
@@ -24,17 +25,12 @@ const char* BufferUFNode::__type_info__ = "BufferUFNode";
 
 struct IRReplaceTensorMutator : ir::IRMutator<> {
   const std::map<std::string, ir::Tensor>& tensor_map;
   IRReplaceTensorMutator(const std::map<std::string, ir::Tensor>& tensor_map) : tensor_map(tensor_map) {}
-  void operator()(Expr* expr) {
-    LOG(INFO) << "original expr: " << *expr;
-    ir::IRMutator<>::Visit(expr, expr);
-  }
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
 
   void Visit(const ir::_Tensor_* op, Expr* expr) override {
     auto it = tensor_map.find(op->name);
     if (it != tensor_map.end()) {
-      LOG(INFO) << "unify tensor " << *expr;
       *expr = Expr(it->second);
-      LOG(INFO) << "unified to " << expr->as_tensor();
     }
   }
 };
@@ -43,73 +39,79 @@ struct IRReplaceTensorMutator : ir::IRMutator<> {
 
 std::map<std::string, ir::Tensor> InitialAssignBuffer(Expr* expr,
                                                       poly::StageMap stages,
-                                                      const std::vector<std::vector<std::string>>& buffer_shared) {
-  std::map<std::string, ir::Tensor> tensor_map;
-
-  auto tensor_exprs = ir::CollectIRNodes(*expr, [&](const Expr* x) {
-    auto* t = x->as_tensor();
-    return t && (!stages[t]->meta.compute_inline) && !t->buffer.defined();
-  });
+                                                      const std::map<std::string, ir::Tensor>& all_tensor_map,
+                                                      const common::Graph* comp_graph) {
+  // The tensor map helps to reserve only one tensor instance for each tensor name.
+  std::map<std::string, ir::Tensor> buffer_updated_tensor;
 
-  if (tensor_exprs.empty()) return tensor_map;
+  for (auto& item : all_tensor_map) {
+    if (stages[item.second]->inlined()) continue;
+    buffer_updated_tensor[item.second->name] = item.second;
+  }
 
   // union-find to cluster the tensors with the same buffer.
   common::UnionFind<BufferUFNode> union_find;
 
-  // The tensor map helps to reserve only one tensor instance for a tensor(called the same name).
-  for (auto& e : tensor_exprs) {
-    tensor_map[e.as_tensor()->name] = e.as_tensor_ref();
-  }
   // unify all the tensor occurance with a global one, e.g. there are multiple tensor B exists in the expression,
   // replace them with a shared one.
   ir::CollectIRNodes(*expr, [&](const Expr* x) -> bool {
     auto* t = x->as_tensor();
-    if (t && tensor_map.count(t->name)) {
-      Reference(x) = Expr(tensor_map.at(t->name));
+    if (t && !stages[t]->inlined()) {
+      Reference(x) = Expr(all_tensor_map.at(t->name));
     }
     return false;
   });
 
-  auto existing_tensors = ir::CollectIRNodes(*expr, [&](const Expr* x) {
-    auto* t = x->as_tensor();
-    return t && !stages[t]->meta.compute_inline && !t->buffer.defined();
-  });
-  CHECK_EQ(existing_tensors.size(), tensor_map.size())
-      << "some of the tensors named same are not unified to one object";
-
   std::map<std::string, BufferUFNode*> uf_map;
-  for (auto& item : tensor_map) {
+  for (auto& item : all_tensor_map) {
     auto* n = union_find.AddNode(new BufferUFNode(item.second->name));
     uf_map[item.second->name] = n->safe_as<BufferUFNode>();
   }
 
-  for (auto& item : tensor_map) {
+  for (auto& item : buffer_updated_tensor) {
     auto* cur_n = uf_map[item.first];
-    if (!stages[item.second]->meta.tensors_to_share_buffer_with.empty()) {
-      for (auto& other : stages[item.second]->meta.tensors_to_share_buffer_with) {
-        // we might intialize the buffer in args.
-        auto* other_n = uf_map[other];
-        if (!other_n) continue;
-
-        VLOG(3) << "share buffer between " << item.first << " " << other_n->tensor_name;
-        cur_n->Union(other_n);
-      }
-    }
-  }
+    for (auto& other : stages[item.second]->meta.tensors_to_share_buffer_with) {
+      // we might initialize the buffer in args.
+      auto* other_n = uf_map[other];
+      if (!other_n) continue;
+
+      VLOG(3) << "share buffer between " << item.first << " " << other_n->tensor_name;
+      cur_n->Union(other_n);
+    }
+  }
+
+  // determine which tensor should own the initial buffer that is shared across the cluster: we take a
+  // topological order of the computational graph, and find out which tensor comes first in each cluster.
+
+  auto [topo_order, topo_edges] = comp_graph->topological_order();
+  for (common::GraphNode* n : topo_order) {
+    auto nn = n->safe_as<lang::detail::CompuGraphNode>();
+    CHECK(nn);
+    {
+      auto it = uf_map.find(nn->tensor->name);
+      CHECK(it != uf_map.end());
+      auto& cluster_info = std::get<0>(it->second->GetRoot())->cluster_info;
+      if (cluster_info.empty()) {  // buffer owner (a tensor) of this cluster is not set yet.
+        cluster_info = nn->tensor->name;
+      }
    }
+  }
 
   for (auto& cluster : union_find.GetClusters()) {
-    VLOG(5) << "get cluster size " << cluster.size();
-    auto& first_tensor = tensor_map.at(cluster[0]->safe_as<BufferUFNode>()->tensor_name);
-    first_tensor->WithBuffer();
-    VLOG(3) << "first_tensor: " << first_tensor->name << " buffer " << first_tensor->buffer;
-    for (int i = 1; i < cluster.size(); i++) {
-      auto& tensor = tensor_map.at(cluster[i]->safe_as<BufferUFNode>()->tensor_name);
-      tensor->Bind(first_tensor->buffer);
-      VLOG(3) << "tensor [" << tensor->name << "] bind buffer [" << first_tensor->buffer << "]";
+    auto* cluster_root = std::get<0>(cluster[0]->GetRoot());
+    auto root_tensor = all_tensor_map.at(cluster_root->cluster_info);
+    if (!root_tensor->buffer.defined() && !root_tensor->type().is_void()) root_tensor->WithBuffer();
+
+    for (auto* n : cluster) {
+      auto& tensor = all_tensor_map.at(n->safe_as<BufferUFNode>()->tensor_name);
+      if (tensor != root_tensor) {
+        Reference(&tensor)->Bind(root_tensor->buffer);
+        VLOG(3) << "tensor " << tensor->name << " bind buffer [" << tensor->buffer->name << "]";
+      }
     }
   }
 
-  return tensor_map;
+  return buffer_updated_tensor;
 }
 
 }  // namespace optim
 }  // namespace cinn
diff --git a/cinn/optim/buffer_assign.h b/cinn/optim/buffer_assign.h
index 25fddac6d9a73..75e5aee73be80 100644
--- a/cinn/optim/buffer_assign.h
+++ b/cinn/optim/buffer_assign.h
@@ -9,11 +9,12 @@ namespace optim {
 
 /**
  * Assign buffer for tensors those are not marked as compute_inline.
 * @param expr
- * @param buffer_shared the clusters that each cluster share the same buffer.
+ * @param stages The stage map.
 */
 std::map<std::string, ir::Tensor> InitialAssignBuffer(Expr* expr,
                                                       poly::StageMap stages,
-                                                      const std::vector<std::vector<std::string>>& buffer_shared = {});
+                                                      const std::map<std::string, ir::Tensor>& all_tensor_map,
+                                                      const common::Graph* comp_graph);
 
 }  // namespace optim
 }  // namespace cinn
diff --git a/cinn/optim/compute_inline_expand.cc b/cinn/optim/compute_inline_expand.cc
index 00dac1fade692..f4c6e934f12a7 100644
--- a/cinn/optim/compute_inline_expand.cc
+++ b/cinn/optim/compute_inline_expand.cc
@@ -91,9 +91,6 @@ void ComputeInlineExpand(Expr *expr, poly::StageMap stages) {
 
     inline_tensors = ir::CollectLoadTensors(
         *expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); });
-
-    LOG(INFO) << "inline tensor size: " << inline_tensors.size();
-    LOG(INFO) << "expr: " << *expr;
   }
 }
 
diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc
index fde664e22066f..e4b0b50042ae8 100644
--- a/cinn/poly/stage.cc
+++ b/cinn/poly/stage.cc
@@ -546,6 +546,7 @@ void Stage::ShareBufferWith(Stage *other) {
   CHECK(tensor_);
   CHECK(!other->meta.compute_inline);
   CHECK(!meta.compute_inline);
+  meta.tensors_to_share_buffer_with.insert(other->id());
 
   other->meta.tensors_to_share_buffer_with.insert(tensor_->name);
 }
@@ -629,7 +630,7 @@ const Stage *_StageMap_::operator[](const ir::Tensor &tensor) const {
   return data_.at(tensor->name).get();
 }
 Stage *_StageMap_::operator[](const ir::_Tensor_ *tensor) {
-  CHECK(data_.count(tensor->name));
+  CHECK(data_.count(tensor->name)) << "StageMap has no stage for tensor [" << tensor->name << "]";
   return data_[tensor->name].get();
 }
 const Stage *_StageMap_::operator[](const ir::_Tensor_ *tensor) const {
diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h
index 0429503fe48ac..011cb129c1bb8 100644
--- a/cinn/poly/stage.h
+++ b/cinn/poly/stage.h
@@ -97,7 +97,7 @@ struct TensorScheduleMeta {
 
   bool compute_inline{false};
 
-  //! Name of the tensors thouse share buffer with `this` tensor.
+  //! Names of the tensors that share a buffer with `this` tensor.
   std::set<std::string> tensors_to_share_buffer_with;
 };
 
diff --git a/cinn/pybind/ir.cc b/cinn/pybind/ir.cc
index 2aa4c4e12d7ee..173c3b8eb37df 100644
--- a/cinn/pybind/ir.cc
+++ b/cinn/pybind/ir.cc
@@ -518,6 +518,8 @@ void BindIrTensor(py::module *m) {
       .def("buffer_depended_tensor_names", &ir::_Tensor_::buffer_depended_tensor_names)
       .def(py::init<>())
       .def("has_expression", &ir::_Tensor_::has_expression)
+      .def("reshape", &ir::_Tensor_::Reshape)
+      .def("reshape_copied", &ir::_Tensor_::ReshapeCopied)
       .def("with_buffer",
            py::overload_cast<const Type &>(&ir::_Tensor_::WithBuffer),
           py::arg("type") = Type::type_t::Void)
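
Usage sketch for the new _Tensor_::Reshape / ReshapeCopied API, distilled from the tests added to cinn/ir/tensor_test.cc above. Reshape returns a view that shares the original tensor's buffer, while ReshapeCopied first materializes a copy and reshapes that. The 10x10x100 target shape and the function name "fn" are only illustrative; the surrounding setup mirrors the test code rather than defining a canonical recipe.

    // Assumes the same headers and namespaces as the added tests.
    Expr M(100), N(100);
    Placeholder<float> A("A", {M, N});
    auto stages = CreateStages({A});

    // View A as 10x10x100; the underlying buffer is shared with A.
    auto A1 = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages);
    stages[A1]->ShareBufferWith(stages[A]);  // as done in the added test
    // ReshapeCopied would instead copy A into a fresh buffer before reshaping:
    //   auto A2 = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages);

    // Consume the reshaped tensor like any other tensor and lower it.
    auto B = Compute(A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; });
    stages->InsertLazily(B);
    auto func = lang::Lower("fn", stages, {A, B});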
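The buffer-assignment strategy in the rewritten InitialAssignBuffer can be summarized outside of CINN: tensors that must share storage are clustered with union-find, and within each cluster the member that appears first in a topological order of the computation graph owns the buffer; the others bind to it. The standalone sketch below illustrates only that idea; the UnionFind struct, tensor names, and share edge are simplified stand-ins for CINN's common::UnionFind, BufferUFNode::cluster_info, and Graph::topological_order() shown in the diff.

    // Standalone illustration only; it does not use CINN types.
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Minimal union-find over tensor names.
    struct UnionFind {
      std::map<std::string, std::string> parent;
      void Add(const std::string& x) { parent.emplace(x, x); }
      std::string Find(const std::string& x) {
        if (parent[x] == x) return x;
        return parent[x] = Find(parent[x]);  // path compression
      }
      void Union(const std::string& a, const std::string& b) { parent[Find(a)] = Find(b); }
    };

    int main() {
      // Tensors listed in a topological order of the computation graph (producers first).
      std::vector<std::string> topo = {"A", "A_reshape", "B"};
      // Pairs recorded by ShareBufferWith: A_reshape must reuse A's buffer.
      std::vector<std::pair<std::string, std::string>> share = {{"A_reshape", "A"}};

      UnionFind uf;
      for (const auto& t : topo) uf.Add(t);
      for (const auto& p : share) uf.Union(p.first, p.second);

      // The first member of each cluster in topological order becomes the buffer owner
      // (this mirrors the cluster_info field set on the union-find root).
      std::map<std::string, std::string> owner;  // cluster root -> owning tensor
      for (const auto& t : topo) owner.emplace(uf.Find(t), t);

      for (const auto& t : topo) {
        std::cout << t << " binds to the buffer of " << owner[uf.Find(t)] << "\n";
      }
      // Prints: A and A_reshape bind to A's buffer; B gets its own.
      return 0;
    }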