Skip to content

Commit

Permalink
Merge pull request #56 from feifei-111/cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
update op lower
  • Loading branch information
feifei-111 authored Mar 12, 2024
2 parents c0dd054 + e75a6bf commit e96c0fd
Showing 1 changed file with 45 additions and 9 deletions.
54 changes: 45 additions & 9 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ std::set<Expr> GetStoreFromBody(const ir::Expr& body) {
return store_tensor_exprs;
}

bool CheckReduceIterEq(std::vector<ir::Var> up_iter, std::vector<ir::Var> down_iter){}
ir::Expr TransformComputeExpr(ir::Expr up_compute_expr, ir::Expr downstream){}
ir::Expr CreateReduceExpr(std::vector<ir::Var> out_iter, std::vector<ir::Var> reduce_iter, ir::Expr comput_expr, ir::Tensor replaced_tensor){}
}

struct TrivialOp {
Expand Down Expand Up @@ -257,6 +260,8 @@ struct TrivialOp {
return std::vector(load_exprs.begin(), load_exprs.end());
}

ir::Expr GetComputeExpr() const {}

private:
ir::Expr func_body;

Expand Down Expand Up @@ -314,6 +319,10 @@ struct ReduceOp {
return std::vector(load_exprs.begin(), load_exprs.end());
}

std::vector<ir::Var> GetReduceIters() const {}
ir::Expr GetComputeExpr() const {}
ir::Expr GetInitExpr() const {}

private:
ir::Expr func_body;

Expand Down Expand Up @@ -380,11 +389,37 @@ ir::Expr TRFusion(ir::Expr upper, ir::Expr down) {
return fused.GetFuncBody();
}

ir::Expr TransformT2R(ir::Expr body){

ir::Expr TransformT2R(ir::Expr reduce_upper, ir::Expr trivial_down){
ReduceOp upstream(reduce_upper);
TrivialOp downstream(trivial_down);
const auto& replaced_tensor = upstream.GetOutputTensor();
ir::Expr result = ComposeUtils::CreateReduceExpr(
downstream.GetOutputIters(), upstream.GetReduceIters(), downstream.GetComputeExpr(), replaced_tensor);
VLOG(4) << "T2Rransform end" << result;
return result;
}

ir::Expr TransformReduceLoopRange(ir::Expr upper, ir::Expr down){}
ir::Expr TransformReduceLoopRange(ir::Expr upper, ir::Expr down){
VLOG(4) << "RRTransform begin";
ReduceOp upstream(upper);
ReduceOp downstream(down);

const auto& down_out_iter = downstream.GetOutputIters();
const auto& up_reduce_iter = upstream.GetReduceIters();
const auto& down_reduce_iter = downstream.GetReduceIters();

// we just support fuse reduce when reduce iter eq
CHECK(ComposeUtils::CheckReduceIterEq(up_reduce_iter, down_reduce_iter));

// TODO modify up_expr, replace out iter of up_expr i => f(i)
ir::Expr new_expr = ComposeUtils::TransformComputeExpr(upstream.GetComputeExpr(), down);

const auto& replaced_tensor = upstream.GetOutputTensor();

ir::Expr result = ComposeUtils::CreateReduceExpr(down_out_iter, up_reduce_iter, new_expr, replaced_tensor);
VLOG(4) << "RRTransform end" << result;
return result;
}

struct FusionNode {
// Function bodies losses the kind information which needed in trivialop
Expand Down Expand Up @@ -553,9 +588,10 @@ struct FusionGraph {
}

void TransformExitTrivialOpToReduce(){
FusionNode* upstream;
for (FusionNode* exit_node: exit_nodes_){
if (IsTrivialKind(exit_node->op_pattern) && HasReduceUpstream(exit_node)){
exit_node->op_compute_body = TransformT2R(exit_node->op_compute_body);
if (IsTrivialKind(exit_node->op_pattern) && (upstream = FindReduceUpstream(exit_node)) != nullptr){
exit_node->op_compute_body = TransformT2R(exit_node->op_compute_body, upstream->op_computer_body);
exit_node->op_pattern = OpPatternKind::kReduction;
}
}
Expand Down Expand Up @@ -609,14 +645,14 @@ struct FusionGraph {
}
}

bool HasReduceUpstream(FusionNode* node){
FusionNode* FindReduceUpstream(FusionNode* node){
for (const auto& pair_data : node->upstream){
FusionNode* upstream = pair_data.first;
if (IsTrivialKind(upstream->op_pattern)){
return true;
if (!IsTrivialKind(upstream->op_pattern)){
return upstream;
}
}
return false;
return nullptr;
}

private:
Expand Down

0 comments on commit e96c0fd

Please sign in to comment.