Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#20 from feifei-111/trival_fuse
Browse files Browse the repository at this point in the history
refine codes of OpInlineFusion
  • Loading branch information
2742195759 authored Mar 4, 2024
2 parents 0a1d2bf + 264145d commit 42f33b7
Showing 1 changed file with 108 additions and 23 deletions.
131 changes: 108 additions & 23 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h"

#include <string>
#include <utility>

#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
Expand Down Expand Up @@ -737,6 +738,109 @@ ir::Expr TrivalFusion(ir::Expr upper, ir::Expr down) {
return fused.GetFuncBody();
}

template <typename A, typename B>
std::vector<std::pair<A, B>> ZipVectors(const std::vector<A>& first_vector,
const std::vector<B>& second_vector) {
PADDLE_ENFORCE_EQ(first_vector.size(),
second_vector.size(),
"can not zip vector with different length");

std::vector<std::pair<A, B>> output_vector;
for (int i = 0; i < first_vector.size(); i++) {
output_vector.emplace_back(
std::make_pair(first_vector[i], second_vector[i]));
}
return output_vector;
}

template <typename Type, typename Process, typename Condition>
Type RecursiveCallUnit(const Process& process,
const Condition& condition,
const Type& init_data) {
Type acc = init_data;
while (condition(acc)) {
acc = process(acc);
}
return acc;
}

bool IsAdjecentInjectiveBetween(
const std::pair<ir::Expr, OpPatternKind>& upstream_func,
const std::pair<ir::Expr, OpPatternKind>& downstream_func) {
return upstream_func.second <= OpPatternKind::kInjective &&
downstream_func.second <= OpPatternKind::kInjective &&
IsAdjecent(upstream_func.first, downstream_func.first);
}

std::optional<std::pair<ir::Expr, OpPatternKind>> FindNonRoot(
const std::vector<std::pair<ir::Expr, OpPatternKind>>& input) {
for (int i = 0; i < input.size(); i++) {
const auto& upstream_expr = input[i].first;
const auto& upstream_op_kind = input[i].second;

if (upstream_op_kind <= OpPatternKind::kInjective) {
for (int j = i + 1; j < input.size(); j++) {
const auto& downstream_expr = input[j].first;
const auto& downstream_op_kind = input[j].second;
if (downstream_op_kind <= OpPatternKind::kInjective &&
IsAdjecent(upstream_expr, downstream_expr)) {
return input[i];
}
}
}
}
return {};
}

bool CanFindNonRoot(
const std::vector<std::pair<ir::Expr, OpPatternKind>>& input) {
const auto& result = FindNonRoot(input);
return result.has_value();
}

std::vector<std::pair<ir::Expr, OpPatternKind>> InlineAllUsedBody(
const std::vector<std::pair<ir::Expr, OpPatternKind>>& origin_funcs,
const std::pair<ir::Expr, OpPatternKind>& upstream_func) {
std::vector<std::pair<ir::Expr, OpPatternKind>> inlined_funcs;
std::transform(
origin_funcs.begin(),
origin_funcs.end(),
std::back_inserter(inlined_funcs),
[&](const std::pair<ir::Expr, OpPatternKind> downstream_func) {
if (upstream_func.first != downstream_func.first &&
IsAdjecentInjectiveBetween(upstream_func, downstream_func)) {
return std::make_pair(
TrivalFusion(upstream_func.first, downstream_func.first),
OpPatternKind::kInjective);
} else {
return downstream_func;
}
});
return inlined_funcs;
}

std::vector<std::pair<ir::Expr, OpPatternKind>> FusionStep(
const std::vector<std::pair<ir::Expr, OpPatternKind>>& input) {
const auto& upstream_func = FindNonRoot(input).value();
auto inlined_funcs = InlineAllUsedBody(input, upstream_func);
inlined_funcs.erase(
std::remove_if(inlined_funcs.begin(),
inlined_funcs.end(),
[&](const std::pair<ir::Expr, OpPatternKind>& func_pair) {
return func_pair.first == upstream_func.first;
}));
return inlined_funcs;
}

std::vector<ir::Expr> FetchExprFromPairs(
const std::vector<std::pair<ir::Expr, OpPatternKind>>& input) {
std::vector<ir::Expr> output_expr;
for (const auto& func_pair : input) {
output_expr.emplace_back(func_pair.first);
}
return output_expr;
}

std::vector<ir::Expr> OpInlineFusion(const GroupPtr& group,
const std::vector<::pir::Operation*> ops,
std::vector<ir::Expr> funcs) {
Expand All @@ -745,35 +849,16 @@ std::vector<ir::Expr> OpInlineFusion(const GroupPtr& group,
}

auto op_patterns = GetOpPatternVector(ops);
auto pattern_func_pairs = ZipVectors(funcs, op_patterns);

VLOG(4) << "op_patterns.size() = " << op_patterns.size();
VLOG(4) << " funcs.size() = " << funcs.size();
PADDLE_ENFORCE_EQ(
op_patterns.size(), funcs.size(), "ops and funcs size not equal");

while (true) {
VLOG(4) << "Start search for Injective + Injective";
std::optional<std::pair<int, int>> opt_idx_pair =
SearchAdjacentInjectives(op_patterns, funcs);
if (!opt_idx_pair.has_value()) {
VLOG(4) << "Not found Injective + Injective, break.";
break;
}
int upper_stream = opt_idx_pair.value().first,
down_stream = opt_idx_pair.value().second;
VLOG(4) << "Find Injective + Injective" << upper_stream << " "
<< down_stream;
ir::Expr func_body = TrivalFusion(funcs[upper_stream], funcs[down_stream]);

// update
auto update_funcs_and_op_patterns = [&]() {
funcs[down_stream] = func_body;
op_patterns[down_stream] = OpPatternKind::kInjective;
};
update_funcs_and_op_patterns();
RemoveUseless(upper_stream, &op_patterns, &funcs);
}
return funcs;
const auto& inlined_func_pairs =
RecursiveCallUnit(FusionStep, CanFindNonRoot, pattern_func_pairs);
return FetchExprFromPairs(inlined_func_pairs);
}

std::vector<ir::LoweredFunc> OpLowererImpl::LowerGroup(
Expand Down

0 comments on commit 42f33b7

Please sign in to comment.