// Review notes for the patch below (a unified diff against
// paddle/cinn/hlir/framework/pir/trivial_op.cc). Every hunk is kept verbatim;
// only these leading notes are added.
//
// NOTE(review): the hunks run together on three physical lines, and template
// argument lists look like they were dropped (`std::vector` with no element
// type, `As()`, `std::get(`). Presumably an extraction artifact -- confirm
// against the original patch before applying.
//
// What the hunks change:
//  * Adds ComposeUtils::GetOutputIters(indices), which maps every index
//    expression through as_var_ref() and collects the results in a vector.
//    TrivialOp::GetOutputIters and ReduceOp::GetOutputIters now pass the
//    indices of GetSingleStoreExpr(func_body) to it instead of each repeating
//    the same std::transform.
//  * Removes the TransformComputeExpr and ReplaceReduceComputeBody stubs.
//    CheckIterEq is re-declared after the new helper; its body is still the
//    TODO placeholder.
//  * TransformReduceLoopRange now returns a std::vector instead of a single
//    ReduceOp. After the CHECK on CheckIterEq(upstream.GetReduceIters(),
//    downstream.GetReduceIters()), it loops over the result of
//    downstream.GetEachTensorLoadExpr(upstream.GetOutputTensor()); for each
//    element it builds one CreateReduceExpr result, applies
//    MappingTargetExprToDestExprMutator to downstream.GetFuncBody(), and
//    appends the new expression to the returned vector.
//    NOTE(review): `new_tensor` is used in the added lines but never declared
//    in them, and CreateReduceExpr is called without the ComposeUtils::
//    qualifier that the removed code used -- confirm both resolve.
//  * ReduceTransform now returns a std::vector instead of FusibleOp, and the
//    second std::get argument in both branches reads downstream->fusible_op
//    instead of upstream->fusible_op.
//  * The FusionGraph loop no longer assigns upstream->fusible_op; it binds the
//    ReduceTransform result to new_fusible_ops. The `{TODO: ...}` line that
//    follows is placeholder text rather than a C++ statement.
diff --git a/paddle/cinn/hlir/framework/pir/trivial_op.cc b/paddle/cinn/hlir/framework/pir/trivial_op.cc index 9bcf35a9987fd..c27a588ed1fe7 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op.cc @@ -224,12 +224,19 @@ std::set GetStoreFromBody(const ir::Expr& body) { return store_tensor_exprs; } -bool CheckIterEq(std::vector up_iter, - std::vector down_iter){TODO} +std::vector GetOutputIters(const std::vector& indices) { + std::vector vars; + std::transform(indices.begin(), + indices.end(), + std::back_inserter(vars), + [](const ir::Expr& expr) { return expr.as_var_ref(); }); + return vars; +} + -ir::Expr TransformComputeExpr(ir::Expr up_compute_expr, ir::Expr downstream) { +bool CheckIterEq(std::vector up_iter, std::vector down_iter){ TODO -} +} } // namespace ComposeUtils @@ -243,23 +250,16 @@ struct TrivialOp { return GetSingleStoreExpr(func_body).As()->value; } + std::vector GetOutputIters() const { + return ComposeUtils::GetOutputIters(GetSingleStoreExpr(func_body).As()->indices); + } + std::vector GetAllIterVar() const { return GetOutputIters(); } ir::Expr* GetStoreValuePointer() const { return &GetSingleStoreExpr(func_body).As()->value; } - std::vector GetOutputIters() const { - std::vector vars; - const auto& indices = - GetSingleStoreExpr(func_body).As()->indices; - std::transform(indices.begin(), - indices.end(), - std::back_inserter(vars), - [](const ir::Expr& expr) { return expr.as_var_ref(); }); - return vars; - } - ir::Expr GetFuncBody() const { return func_body; } ir::Tensor GetOutputTensor() const { @@ -313,14 +313,7 @@ struct ReduceOp { } std::vector GetOutputIters() const { - std::vector vars; - const auto& indices = - GetSingleStoreExpr(func_body).As()->indices; - std::transform(indices.begin(), - indices.end(), - std::back_inserter(vars), - [](const ir::Expr& expr) { return expr.as_var_ref(); }); - return vars; + return ComposeUtils::GetOutputIters(GetSingleStoreExpr(func_body).As()->indices); 
} ir::Expr GetFuncBody() const { return func_body; } @@ -520,31 +513,24 @@ TrivialOp TransformT2R(ReduceOp reduce_upper, TrivialOp trivial_down) {} bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {} -ir::Expr ReplaceReduceComputeBody(const ir::Expr& body, - const ir::Expr& new_body) { - TODO; -} -ReduceOp TransformReduceLoopRange(ReduceOp upstream, ReduceOp downstream) { +std::vector TransformReduceLoopRange(ReduceOp upstream, ReduceOp downstream) { VLOG(4) << "RRTransform begin"; - const auto& down_out_iter = downstream.GetOutputIters(); - const auto& up_reduce_iter = upstream.GetReduceIters(); - const auto& down_reduce_iter = downstream.GetReduceIters(); - - // we just support fuse reduce when reduce iter eq - CHECK(ComposeUtils::CheckIterEq(up_reduce_iter, down_reduce_iter)); - - // TODO modify up_expr, replace out iter of up_expr i => f(i) - ir::Expr new_reduce_body = ir::ir_utils::IRCopy(downstream.GetFuncBody()); - ir::Expr reduce_op_expr = ComposeUtils::TransformComputeExpr( - new_reduce_body.GetComputeExpr(), down); - ir::Expr const auto& replaced_tensor = upstream.GetOutputTensor(); - - ir::Expr result = ComposeUtils::CreateReduceExpr(downstream, reduce_op_expr); + CHECK(ComposeUtils::CheckIterEq(upstream.GetReduceIters(), downstream.GetReduceIters())); + const auto& load_upstream_expr = downstream.GetEachTensorLoadExpr(upstream.GetOutputTensor()); + std::vector results; + for (const auto& load_tensor : load_upstream_expr){ + ir::Expr new_reduce = CreateReduceExpr( + downstream, + ComposeUtils::CopyedReplaceExpr(upstream.GetFuncBody(), upstream.GetOutputIters(), load_tensor.As()->indices), + upstream.GetInitExpr(), + new_tensor); + ComposeUtils::MappingTargetExprToDestExprMutator(load_tensor.As()->tensor, new_tensor)(downstream.GetFuncBody()); + results.emplace_back(new_reduce); + } - VLOG(4) << "RRTransform end" << result; - return ReduceOp(result); + return results; } FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* 
downstream) { @@ -558,14 +544,14 @@ FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) { } } -FusibleOp ReduceTransform(FusionNode* upstream, FusionNode* downstream) { +std::vector ReduceTransform(FusionNode* upstream, FusionNode* downstream) { if (downstream->IsTrivial()) { CHECK(CheckAllLoopRangeEq(std::get(upstream->fusible_op), - std::get(upstream->fusible_op))); - return upstream->fusible_op; + std::get(downstream->fusible_op))); + return {upstream->fusible_op}; } else { return TransformReduceLoopRange(std::get(upstream->fusible_op), - std::get(upstream->fusible_op)); + std::get(downstream->fusible_op)); } } @@ -702,7 +688,10 @@ struct FusionGraph { bfs_candidate.pop(); for (const auto& pair_data : downstream->upstream) { FusionNode* upstream = pair_data.first; - upstream->fusible_op = ReduceTransform(upstream, downstream); + const auto& new_fusible_ops = ReduceTransform(upstream, downstream); + + {TODO: update topo structure with multi upstream nodes} + bfs_candidate.push(upstream); } }