From 33f5cb36a1f12f7e241b02eb495642ede3e0b910 Mon Sep 17 00:00:00 2001 From: feifei-111 <2364819892@qq.com> Date: Sat, 2 Mar 2024 13:45:09 +0000 Subject: [PATCH] update --- .../hlir/framework/pir/op_lowering_impl.cc | 133 ++++++++++++++---- 1 file changed, 108 insertions(+), 25 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 90eeafa14ee2a6..4423b153a5ef65 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #include +#include #include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ast_gen_ius/tensor_group.h" @@ -737,6 +738,109 @@ ir::Expr TrivalFusion(ir::Expr upper, ir::Expr down) { return fused.GetFuncBody(); } +template +std::vector> ZipVectors(const std::vector& first_vector, + const std::vector& second_vector) { + PADDLE_ENFORCE_EQ(first_vector.size(), + second_vector.size(), + "can not zip vector with different length"); + + std::vector> output_vector; + for (int i = 0; i < first_vector.size(); i++) { + output_vector.emplace_back( + std::make_pair(first_vector[i], second_vector[i])); + } + return output_vector; +} + +template +Type RecursiveCallUnit(const Process& process, + const Condition& condition, + const Type& init_data) { + Type acc = init_data; + while (condition(acc)) { + acc = process(acc); + } + return acc; +} + +bool IsAdjecentInjectiveBetween( + const std::pair& upstream_func, + const std::pair& downstream_func) { + return upstream_func.second <= OpPatternKind::kInjective && + downstream_func.second <= OpPatternKind::kInjective && + IsAdjecent(upstream_func.first, downstream_func.first); +} + +std::optional> FindNonRoot( + const std::vector>& input) { + for (int i = 0; i < input.size(); i++) { + const auto& upstream_expr = input[i].first; + const auto& upstream_op_kind = input[i].second; + + if (upstream_op_kind <= OpPatternKind::kInjective) { + for (int j = i + 1; j < input.size(); j++) { + const auto& downstream_expr = input[j].first; + const auto& downstream_op_kind = input[j].second; + if (downstream_op_kind <= OpPatternKind::kInjective && + IsAdjecent(upstream_expr, downstream_expr)) { + return input[i]; + } + } + } + } + return {}; +} + +bool CanFindNonRoot( + const std::vector>& input) { + const auto& result = FindNonRoot(input); + return result.has_value(); +} + +std::vector> InlineAllUsedBody( + const std::vector>& origin_funcs, + const std::pair& upstream_func) { + std::vector> inlined_funcs; + std::transform( + origin_funcs.begin(), + origin_funcs.end(), + std::back_inserter(inlined_funcs), + [&](const std::pair downstream_func) { + if (upstream_func.first != downstream_func.first && + IsAdjecentInjectiveBetween(upstream_func, downstream_func)) { + return std::make_pair( + TrivalFusion(upstream_func.first, downstream_func.first), + OpPatternKind::kInjective); + } else { + return downstream_func; + } + }); + return inlined_funcs; +} + +std::vector> FusionStep( + const std::vector>& input) { + const auto& upstream_func = FindNonRoot(input).value(); + auto inlined_funcs = InlineAllUsedBody(input, upstream_func); + inlined_funcs.erase( + std::remove_if(inlined_funcs.begin(), + inlined_funcs.end(), + [&](const std::pair& func_pair) { + return func_pair.first == upstream_func.first; + })); + return inlined_funcs; +} + +std::vector FetchExprFromPairs( + const std::vector>& input) { + std::vector output_expr; + for (const auto& func_pair : input) { + output_expr.emplace_back(func_pair.first); + } + return output_expr; +} + std::vector OpInlineFusion(const GroupPtr& group, const std::vector<::pir::Operation*> ops, std::vector funcs) { @@ -745,37 +849,16 @@ std::vector OpInlineFusion(const GroupPtr& group, } auto op_patterns = GetOpPatternVector(ops); + auto pattern_func_pairs = ZipVectors(funcs, op_patterns); VLOG(4) << "op_patterns.size() = " << op_patterns.size(); VLOG(4) << " funcs.size() = " << funcs.size(); PADDLE_ENFORCE_EQ( op_patterns.size(), funcs.size(), "ops and funcs size not equal"); - while (true) { - VLOG(4) << "Start search for Injective + Injective"; - std::optional> opt_idx_pair = - SearchAdjacentInjectives(op_patterns, funcs); - if (!opt_idx_pair.has_value()) { - VLOG(4) << "Not found Injective + Injective, break."; - break; - } - int upper_stream = opt_idx_pair.value().first, - down_stream = opt_idx_pair.value().second; - VLOG(4) << "Find Injective + Injective" << upper_stream << " " - << down_stream; - ir::Expr func_body = TrivalFusion(funcs[upper_stream], funcs[down_stream]); - - // update - auto update_funcs_and_op_patterns = [&]() { - funcs[down_stream] = func_body; - op_patterns[down_stream] = OpPatternKind::kInjective; - - funcs.erase(funcs.begin() + upper_stream); - op_patterns.erase(op_patterns.begin() + upper_stream); - }; - update_funcs_and_op_patterns(); - } - return funcs; + const auto& inlined_func_pairs = + RecursiveCallUnit(FusionStep, CanFindNonRoot, pattern_func_pairs); + return FetchExprFromPairs(inlined_func_pairs); } std::vector OpLowererImpl::LowerGroup(