From 0339776519b95ac5679bd889ffb9ac0179d0c341 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 21 Nov 2023 11:13:09 +0800 Subject: [PATCH] [CINN] Strong constraint branch adapt to pir (#58993) * Strong Constraint Branch * NoInlineTranslator (#84) * Adapt adt to pir * Move FLAGS_cinn_enable_map_expr_schedule location * Apply new group schedule * Remove useless log * Remove adt unittest * Solve merge conflicts * Fix typo * Fix merge conflicts * Add unit test * Fix cmake * Add test_cinn_sub_graph_map_expr * Save current workspace * Remove unittest to cinn directory * Refactor unittest cmake * Refactor unittest cmake * Add unit test at Cmake * Restore test_cinn_sub_graph.py * Fix unittest * Refine codes according to comment * Refine codes according to comment --- paddle/cinn/adt/CMakeLists.txt | 62 ++++----- paddle/cinn/adt/adapter_tensor.cc | 44 +++++++ paddle/cinn/adt/adapter_tensor.h | 45 ++----- paddle/cinn/adt/equation_value.h | 5 - paddle/cinn/adt/generate_map_expr.cc | 118 +++++++++--------- paddle/cinn/adt/generate_map_expr.h | 22 ++-- paddle/cinn/adt/inline_translator.h | 37 +----- paddle/cinn/adt/inline_translator_trait.h | 58 +++++++++ paddle/cinn/adt/kgroup.cc | 4 +- paddle/cinn/adt/kgroup.h | 12 +- paddle/cinn/adt/m_expr.h | 23 ++-- paddle/cinn/adt/map_expr_ctx.h | 6 +- paddle/cinn/adt/naive_op_equation_context.cc | 45 ++++--- paddle/cinn/adt/naive_op_equation_context.h | 8 +- paddle/cinn/adt/no_inline_translator.h | 83 ++++++++++++ paddle/cinn/adt/print_utils/CMakeLists.txt | 27 ++-- .../cinn/adt/print_utils/print_equations.cc | 14 ++- paddle/cinn/adt/print_utils/print_map_expr.cc | 18 +-- paddle/cinn/adt/schedule_descriptor.cc | 4 +- .../transforms/cinn_group_lowering_pass.cc | 7 ++ paddle/cinn/hlir/framework/graph.h | 16 --- .../cinn/hlir/framework/op_lowering_impl.cc | 94 -------------- paddle/cinn/hlir/framework/op_lowering_impl.h | 38 ------ paddle/cinn/hlir/framework/pir/group.h | 31 ++++- .../hlir/framework/pir/op_lowering_impl.cc | 105 +++++++++++++++- .../hlir/framework/pir/op_lowering_impl.h | 33 +++++ paddle/cinn/hlir/pe/CMakeLists.txt | 7 +- paddle/cinn/hlir/pe/map_expr_to_ir.cc | 108 +++++++++++----- paddle/cinn/hlir/pe/map_expr_to_ir.h | 1 + .../st_shape_group_scheduler.cc | 7 ++ .../group_schedule/st_shape_group_scheduler.h | 2 + paddle/cinn/pybind/frontend.cc | 2 - paddle/cinn/runtime/flags.cc | 11 +- .../framework/paddle2cinn/cinn_compiler.cc | 2 - test/CMakeLists.txt | 2 +- test/cinn/CMakeLists.txt | 25 +--- test/cinn/adt/CMakeLists.txt | 21 ++++ test/cinn/adt/test_add_inline.py | 60 --------- test/cinn/adt/test_broadcast_expr.py | 59 --------- test/cinn/adt/test_cinn_sub_graph_map_expr.py | 76 +++++++++++ test/cinn/adt/test_naive_add.py | 57 --------- test/cinn/adt/test_naive_reduce.py | 50 -------- test/cinn/adt/test_reduce_fusion.py | 62 --------- test/cinn/adt/test_reduce_schedule_mesh.py | 59 --------- 44 files changed, 752 insertions(+), 818 deletions(-) create mode 100644 paddle/cinn/adt/adapter_tensor.cc create mode 100644 paddle/cinn/adt/inline_translator_trait.h create mode 100644 paddle/cinn/adt/no_inline_translator.h create mode 100644 test/cinn/adt/CMakeLists.txt delete mode 100755 test/cinn/adt/test_add_inline.py delete mode 100755 test/cinn/adt/test_broadcast_expr.py create mode 100644 test/cinn/adt/test_cinn_sub_graph_map_expr.py delete mode 100755 test/cinn/adt/test_naive_add.py delete mode 100644 test/cinn/adt/test_naive_reduce.py delete mode 100644 test/cinn/adt/test_reduce_fusion.py delete mode 100644 test/cinn/adt/test_reduce_schedule_mesh.py diff --git a/paddle/cinn/adt/CMakeLists.txt b/paddle/cinn/adt/CMakeLists.txt index 8dbb3bc8769839..e74c21bb0e7949 100644 --- a/paddle/cinn/adt/CMakeLists.txt +++ b/paddle/cinn/adt/CMakeLists.txt @@ -1,35 +1,39 @@ -add_subdirectory(print_utils) +if(NOT CINN_ONLY) + add_subdirectory(print_utils) -core_gather_headers() + core_gather_headers() -gather_srcs( - cinnapi_src - SRCS - anchor_sd_equation_context.cc - equation_function.cc - equation_solver.cc - equation_value.cc - generate_map_expr.cc - get_sub_reshape_dim_ranges.cc - igroup.cc - index_expr_infer_context.cc - kgroup.cc - m_ir.cc - naive_bidirection_equation_generator.cc - naive_op_equation_context.cc - partition_op_stmts.cc - schedule_descriptor.cc - schedule_dim.cc - schedule_mesh.cc - simplify_value.cc - write_broadcast_disabled_bidirection_equation_generator.cc) + gather_srcs( + cinnapi_src + SRCS + adapter_tensor.cc + anchor_sd_equation_context.cc + equation_function.cc + equation_solver.cc + equation_value.cc + generate_map_expr.cc + get_sub_reshape_dim_ranges.cc + igroup.cc + index_expr_infer_context.cc + kgroup.cc + m_ir.cc + naive_bidirection_equation_generator.cc + naive_op_equation_context.cc + partition_op_stmts.cc + schedule_descriptor.cc + schedule_dim.cc + schedule_mesh.cc + simplify_value.cc + write_broadcast_disabled_bidirection_equation_generator.cc) -cinn_cc_test(equation_value_match_trait_test SRCS - equation_value_match_trait_test.cc DEPS gtest glog) + cinn_cc_test(equation_value_match_trait_test SRCS + equation_value_match_trait_test.cc DEPS gtest glog) -cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) + cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) -cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS - cinncore) + cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS + cinncore) -message(STATUS "ADT srcs: ${cinnapi_src}") + message(STATUS "ADT srcs: ${cinnapi_src}") + +endif() diff --git a/paddle/cinn/adt/adapter_tensor.cc b/paddle/cinn/adt/adapter_tensor.cc new file mode 100644 index 00000000000000..464c45780dbecd --- /dev/null +++ b/paddle/cinn/adt/adapter_tensor.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/adt/adapter_tensor.h" +#include "glog/logging.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" + +namespace cinn::adt::adapter { + +std::size_t Tensor::GetRank() const { + return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data) + .size(); +} + +std::vector Tensor::GetShape() const { + std::vector ret{}; + for (int dim_size : + cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) { + ret.emplace_back(dim_size); + } + return ret; +} + +std::size_t Tensor::GetNumel() const { + std::size_t ret = 1; + for (int dim_size : + cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) { + ret = ret * dim_size; + } + return ret; +} + +} // namespace cinn::adt::adapter diff --git a/paddle/cinn/adt/adapter_tensor.h b/paddle/cinn/adt/adapter_tensor.h index 2a6cc941afb89e..dbd2c2dcecfdbb 100644 --- a/paddle/cinn/adt/adapter_tensor.h +++ b/paddle/cinn/adt/adapter_tensor.h @@ -13,59 +13,28 @@ // limitations under the License. #pragma once -#include "glog/logging.h" #include "paddle/cinn/adt/adt.h" -#include "paddle/cinn/hlir/framework/graph.h" -#include "paddle/cinn/hlir/framework/node.h" +#include "paddle/pir/core/value.h" namespace cinn::adt::adapter { struct Tensor final { - const hlir::framework::NodeData* node_data; - const hlir::framework::Graph* graph; + ::pir::Value node_data; bool operator==(const Tensor& other) const { - return this->node_data == other.node_data && this->graph == other.graph; + return this->node_data == other.node_data; } - std::size_t GetRank() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - return shape_dict.at(node_data->id()).size(); - } + std::size_t GetRank() const; - const std::vector& GetShape() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - return shape_dict.at(node_data->id()); - } + std::vector GetShape() const; - std::size_t GetNumel() const { - const auto& shape_dict = - graph->GetAttrs>( - "infershape"); - CHECK(shape_dict.count(node_data->id())) - << "Can't find " << node_data->id() << " 's shape!"; - std::vector shape = shape_dict.at(node_data->id()); - std::size_t ret = 1; - for (int32_t dim_size : shape) { - ret = ret * dim_size; - } - return ret; - } + std::size_t GetNumel() const; }; inline std::size_t GetHashValueImpl(const Tensor& tensor) { - return hash_combine( - std::hash()(tensor.node_data), - std::hash()(tensor.graph)); + return std::hash<::pir::Value>()(tensor.node_data); } } // namespace cinn::adt::adapter diff --git a/paddle/cinn/adt/equation_value.h b/paddle/cinn/adt/equation_value.h index 7aa6c2b7c3155b..6c1fef21a93dde 100644 --- a/paddle/cinn/adt/equation_value.h +++ b/paddle/cinn/adt/equation_value.h @@ -20,11 +20,6 @@ #include "paddle/cinn/adt/equation.h" #include "paddle/cinn/adt/match.h" -namespace cinn::hlir::framework { -class Node; -class NodeData; -} // namespace cinn::hlir::framework - namespace cinn::adt { DEFINE_ADT_TAG(tPointer); diff --git a/paddle/cinn/adt/generate_map_expr.cc b/paddle/cinn/adt/generate_map_expr.cc index b435acbcbcfb95..4180c9174a45fb 100644 --- a/paddle/cinn/adt/generate_map_expr.cc +++ b/paddle/cinn/adt/generate_map_expr.cc @@ -26,7 +26,11 @@ #include "paddle/cinn/adt/print.h" #include "paddle/cinn/adt/schedule_descriptor.h" #include "paddle/cinn/adt/tree.h" +#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/runtime/flags.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" #include "glog/logging.h" @@ -84,79 +88,65 @@ using LoopDescriptor4IterVarT = std::function; using AnchorTensor = Variable; using FakeOpPlaceHolders = List; -Op MakeOp(const hlir::framework::Node* op) { return {op}; } +Op MakeOp(const ::pir::Operation* op) { return {op}; } template -void VisitEachInputTensor(const hlir::framework::Node* op, - const DoEachT& DoEach) { - for (const auto& graph_edge : op->inlinks_in_order()) { - DoEach(graph_edge->source()->safe_as()); +void VisitEachInputTensor(const ::pir::Operation* op, const DoEachT& DoEach) { + for (std::size_t i = 0; i < op->num_operands(); ++i) { + DoEach(op->operand_source(i)); } } -List MakeOpStmtInputList(const hlir::framework::Node* op, - const hlir::framework::Graph* graph) { +List MakeOpStmtInputList(const ::pir::Operation* op) { List ret{}; - VisitEachInputTensor(op, [&](const auto* tensor) { - ret->emplace_back(adapter::Tensor{tensor, graph}); + VisitEachInputTensor(op, [&](const ::pir::Value& tensor) { + ret->emplace_back(adapter::Tensor{tensor}); }); return ret; } template -void VisitEachOutputTensor(const hlir::framework::Node* op, - const DoEachT& DoEach) { - for (const auto& graph_edge : op->outlinks_in_order()) { - DoEach(graph_edge->sink()->safe_as()); +void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) { + for (std::size_t i = 0; i < op->num_results(); ++i) { + DoEach(const_cast<::pir::Operation*>(op)->result(i)); } } -List MakeOpStmtOutputList(const hlir::framework::Node* op, - const hlir::framework::Graph* graph) { +List MakeOpStmtOutputList(const ::pir::Operation* op) { List ret{}; - VisitEachOutputTensor(op, [&](const auto* tensor) { - ret->emplace_back(adapter::Tensor{tensor, graph}); + VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) { + ret->emplace_back(adapter::Tensor{tensor}); }); return ret; } template -void VisitEachOpStmt( - const std::shared_ptr& group, - const DoEachT& DoEach) { - // Note - for (const auto* op : group->CollectNodes()) { - DoEach(OpStmt{MakeOp(op), - MakeOpStmtInputList(op, group->graph_), - MakeOpStmtOutputList(op, group->graph_)}); +void VisitEachOpStmt(const std::shared_ptr& group, + const DoEachT& DoEach) { + for (const auto* op : group->CollectOps()) { + DoEach( + OpStmt{MakeOp(op), MakeOpStmtInputList(op), MakeOpStmtOutputList(op)}); } } -hlir::framework::OpPatternKind GetOpPatternKind( - const hlir::framework::Node* node) { - static const hlir::framework::OpValueType& - op_pattern_dict = - hlir::framework::Operator::GetAttrs( - "OpPattern"); - auto kind = op_pattern_dict[node->op()]; - return kind; +hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) { + return hlir::framework::pir::CompatibleInfo::OpKind(*node); } bool CollectRewritedReductionOpStmts(const OpStmt& op_stmt, List* ret) { const auto& [op, inputs, outputs] = op_stmt.tuple(); - CHECK(op.Has()); - if (GetOpPatternKind(op.Get()) == + CHECK(op.Has()); + if (GetOpPatternKind(op.Get()) == hlir::framework::OpPatternKind::kReduction) { - tReduceInit init_op{ - op.Get()}; + tReduceInit init_op{ + op.Get()}; (*ret)->emplace_back(OpStmt{init_op, List{}, outputs}); - tReduceAcc acc_op{ - op.Get()}; + tReduceAcc acc_op{op.Get()}; (*ret)->emplace_back(OpStmt{acc_op, inputs, outputs}); return true; } else { @@ -172,7 +162,7 @@ void CollectRewritedOpStmts(const OpStmt& op_stmt, List* ret) { } List MakeOpStmts( - const std::shared_ptr& group) { + const std::shared_ptr& group) { List ret{}; VisitEachOpStmt(group, [&](const auto& op_stmt) { @@ -213,7 +203,7 @@ std::shared_ptr MakeIGroup(const AnchorGroup& igroup_spec) { } std::vector> GenerateIGroups( - const std::shared_ptr& group) { + const std::shared_ptr& group) { std::vector> ret{}; List op_stmts = MakeOpStmts(group); @@ -227,7 +217,7 @@ std::vector> GenerateIGroups( } std::shared_ptr GenerateKGroups( - const std::shared_ptr& group, + const std::shared_ptr& group, const std::vector>& igroups) { CHECK_EQ(igroups.size(), 1); return std::make_shared(group, igroups); @@ -343,36 +333,34 @@ Tensor GetAnchorTensor(const std::shared_ptr& igroup) { } template -void VisitInputTensor(const hlir::framework::Graph::Group& group, +void VisitInputTensor(const hlir::framework::pir::Group& group, const DoEachT& DoEach) { - for (const auto* node_data : group.GetInputNodeDatas()) { - DoEach(node_data, group.graph_); + for (const ::pir::Value& node_data : group.GetInputOpValues()) { + DoEach(node_data); } } template -void VisitOutputTensor(const hlir::framework::Graph::Group& group, +void VisitOutputTensor(const hlir::framework::pir::Group& group, const DoEachT& DoEach) { - for (const auto& node_data : group.GetOutputNodeDatas()) { - DoEach(node_data, group.graph_); + for (const ::pir::Value& node_data : group.GetOutputOpValues()) { + DoEach(node_data); } } List MakeInputTensors(const std::shared_ptr& kgroup) { List ret{}; - VisitInputTensor(*kgroup->cinn_group(), - [&](const auto* node_data, const auto* graph) { - ret->emplace_back(adapter::Tensor{node_data, graph}); - }); + VisitInputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) { + ret->emplace_back(adapter::Tensor{node_data}); + }); return ret; } List MakeOutputTensors(const std::shared_ptr& kgroup) { List ret{}; - VisitOutputTensor(*kgroup->cinn_group(), - [&](const auto* node_data, const auto* graph) { - ret->emplace_back(adapter::Tensor{node_data, graph}); - }); + VisitOutputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) { + ret->emplace_back(adapter::Tensor{node_data}); + }); return ret; } @@ -437,7 +425,7 @@ MapExpr GenerateMapExpr(const std::shared_ptr& kgroup) { } // namespace MapExpr GenerateMapExpr( - const std::shared_ptr& group) { + const std::shared_ptr& group) { const auto& igroups = GenerateIGroups(group); const auto& kgroup = GenerateKGroups(group, igroups); @@ -445,18 +433,26 @@ MapExpr GenerateMapExpr( return GenerateMapExpr(kgroup); } -namespace {} // namespace - void TryGenerateMapExprFromGraph( - const std::shared_ptr& graph) { + const hlir::framework::pir::GroupList& groups) { if (!FLAGS_cinn_enable_map_expr) { return; } - for (const auto& fusion_group : graph->fusion_groups) { + for (const auto& fusion_group : groups) { const auto& map_expr = GenerateMapExpr(fusion_group); VLOG(1) << ToTxtString(map_expr, fusion_group->group_id); fusion_group->set_map_expr_ctx(std::make_shared(map_expr)); } } +void TryGenerateMapExprFromGroup( + const std::shared_ptr& fusion_group) { + if (!FLAGS_cinn_enable_map_expr) { + return; + } + const auto& map_expr = GenerateMapExpr(fusion_group); + VLOG(1) << ToTxtString(map_expr, fusion_group->group_id); + fusion_group->set_map_expr_ctx(std::make_shared(map_expr)); +} + } // namespace cinn::adt diff --git a/paddle/cinn/adt/generate_map_expr.h b/paddle/cinn/adt/generate_map_expr.h index c604dd9f070c06..61b5906c8138a3 100644 --- a/paddle/cinn/adt/generate_map_expr.h +++ b/paddle/cinn/adt/generate_map_expr.h @@ -14,19 +14,25 @@ #pragma once +#include + #include "paddle/cinn/adt/m_expr.h" -#include "paddle/cinn/adt/m_ir.h" -#include "paddle/cinn/hlir/framework/graph.h" -namespace cinn::adt { +namespace cinn::hlir::framework::pir { + +struct Group; +using GroupList = std::vector>; -class IGroup; -class KGroup; +} // namespace cinn::hlir::framework::pir + +namespace cinn::adt { MapExpr GenerateMapExpr( - const std::shared_ptr& group); + const std::shared_ptr& group); + +void TryGenerateMapExprFromGraph(const hlir::framework::pir::GroupList& groups); -void TryGenerateMapExprFromGraph( - const std::shared_ptr& graph); +void TryGenerateMapExprFromGroup( + const std::shared_ptr& fusion_group); } // namespace cinn::adt diff --git a/paddle/cinn/adt/inline_translator.h b/paddle/cinn/adt/inline_translator.h index 5298d17ffadcda..2cd3a44bc7dd05 100644 --- a/paddle/cinn/adt/inline_translator.h +++ b/paddle/cinn/adt/inline_translator.h @@ -15,47 +15,12 @@ #pragma once #include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/adt/inline_translator_trait.h" #include "paddle/cinn/adt/m_expr.h" #include "paddle/cinn/adt/tree.h" namespace cinn::adt { -template