Skip to content

Commit

Permalink
Cinn trivalop fuse (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Mar 12, 2024
1 parent 8662273 commit 483edae
Showing 1 changed file with 38 additions and 49 deletions.
87 changes: 38 additions & 49 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,19 @@ std::set<Expr> GetStoreFromBody(const ir::Expr& body) {
return store_tensor_exprs;
}

bool CheckIterEq(std::vector<ir::Var> up_iter,
std::vector<ir::Var> down_iter){TODO}
std::vector<ir::Var> GetOutputIters(const std::vector<ir::Expr>& indices) {
std::vector<ir::Var> vars;
std::transform(indices.begin(),
indices.end(),
std::back_inserter(vars),
[](const ir::Expr& expr) { return expr.as_var_ref(); });
return vars;
}


ir::Expr TransformComputeExpr(ir::Expr up_compute_expr, ir::Expr downstream) {
bool CheckIterEq(std::vector<ir::Var> up_iter, std::vector<ir::Var> down_iter){
TODO
}
}

} // namespace ComposeUtils

Expand All @@ -243,23 +250,16 @@ struct TrivialOp {
return GetSingleStoreExpr(func_body).As<ir::Store>()->value;
}

std::vector<ir::Var> GetOutputIters() const {
return ComposeUtils::GetOutputIters(GetSingleStoreExpr(func_body).As<ir::Store>()->indices);
}

std::vector<ir::Var> GetAllIterVar() const { return GetOutputIters(); }

ir::Expr* GetStoreValuePointer() const {
return &GetSingleStoreExpr(func_body).As<ir::Store>()->value;
}

std::vector<ir::Var> GetOutputIters() const {
std::vector<ir::Var> vars;
const auto& indices =
GetSingleStoreExpr(func_body).As<ir::Store>()->indices;
std::transform(indices.begin(),
indices.end(),
std::back_inserter(vars),
[](const ir::Expr& expr) { return expr.as_var_ref(); });
return vars;
}

ir::Expr GetFuncBody() const { return func_body; }

ir::Tensor GetOutputTensor() const {
Expand Down Expand Up @@ -313,14 +313,7 @@ struct ReduceOp {
}

std::vector<ir::Var> GetOutputIters() const {
std::vector<ir::Var> vars;
const auto& indices =
GetSingleStoreExpr(func_body).As<ir::Store>()->indices;
std::transform(indices.begin(),
indices.end(),
std::back_inserter(vars),
[](const ir::Expr& expr) { return expr.as_var_ref(); });
return vars;
return ComposeUtils::GetOutputIters(GetSingleStoreExpr(func_body).As<ir::Store>()->indices);
}

ir::Expr GetFuncBody() const { return func_body; }
Expand Down Expand Up @@ -520,31 +513,24 @@ TrivialOp TransformT2R(ReduceOp reduce_upper, TrivialOp trivial_down) {}

bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {}

ir::Expr ReplaceReduceComputeBody(const ir::Expr& body,
const ir::Expr& new_body) {
TODO;
}

ReduceOp TransformReduceLoopRange(ReduceOp upstream, ReduceOp downstream) {
std::vector<ReduceOp> TransformReduceLoopRange(ReduceOp upstream, ReduceOp downstream) {
VLOG(4) << "RRTransform begin";

const auto& down_out_iter = downstream.GetOutputIters();
const auto& up_reduce_iter = upstream.GetReduceIters();
const auto& down_reduce_iter = downstream.GetReduceIters();

// we just support fuse reduce when reduce iter eq
CHECK(ComposeUtils::CheckIterEq(up_reduce_iter, down_reduce_iter));

// TODO modify up_expr, replace out iter of up_expr i => f(i)
ir::Expr new_reduce_body = ir::ir_utils::IRCopy(downstream.GetFuncBody());
ir::Expr reduce_op_expr = ComposeUtils::TransformComputeExpr(
new_reduce_body.GetComputeExpr(), down);
ir::Expr const auto& replaced_tensor = upstream.GetOutputTensor();

ir::Expr result = ComposeUtils::CreateReduceExpr(downstream, reduce_op_expr);
CHECK(ComposeUtils::CheckIterEq(upstream.GetReduceIters(), downstream.GetReduceIters()));
const auto& load_upstream_expr = downstream.GetEachTensorLoadExpr(upstream.GetOutputTensor());
std::vector<ReduceOp> results;
for (const auto& load_tensor : load_upstream_expr){
ir::Expr new_reduce = CreateReduceExpr(
downstream,
ComposeUtils::CopyedReplaceExpr(upstream.GetFuncBody(), upstream.GetOutputIters(), load_tensor.As<ir::Load>()->indices),
upstream.GetInitExpr(),
new_tensor);
ComposeUtils::MappingTargetExprToDestExprMutator(load_tensor.As<ir::Load>()->tensor, new_tensor)(downstream.GetFuncBody());
results.emplace_back(new_reduce);
}

VLOG(4) << "RRTransform end" << result;
return ReduceOp(result);
return results;
}

FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) {
Expand All @@ -558,14 +544,14 @@ FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) {
}
}

FusibleOp ReduceTransform(FusionNode* upstream, FusionNode* downstream) {
std::vector<FusibleOp> ReduceTransform(FusionNode* upstream, FusionNode* downstream) {
if (downstream->IsTrivial()) {
CHECK(CheckAllLoopRangeEq(std::get<ReduceOp>(upstream->fusible_op),
std::get<TrivialOp>(upstream->fusible_op)));
return upstream->fusible_op;
std::get<TrivialOp>(downstream->fusible_op)));
return {upstream->fusible_op};
} else {
return TransformReduceLoopRange(std::get<ReduceOp>(upstream->fusible_op),
std::get<ReduceOp>(upstream->fusible_op));
std::get<ReduceOp>(downstream->fusible_op));
}
}

Expand Down Expand Up @@ -702,7 +688,10 @@ struct FusionGraph {
bfs_candidate.pop();
for (const auto& pair_data : downstream->upstream) {
FusionNode* upstream = pair_data.first;
upstream->fusible_op = ReduceTransform(upstream, downstream);
const auto& new_fusible_ops = ReduceTransform(upstream, downstream);

{TODO: update topo structure with multi upstream nodes}

bfs_candidate.push(upstream);
}
}
Expand Down

0 comments on commit 483edae

Please sign in to comment.