diff --git a/paddle/cinn/hlir/framework/pir/trivial_op.cc b/paddle/cinn/hlir/framework/pir/trivial_op.cc index 9df157c7164e7..10e64f7de805e 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op.cc @@ -464,8 +464,24 @@ struct TrivialOp { } std::vector GetOutputIters() const { - return ComposeUtils::GetOutputIters( - GetSingleStoreExpr(func_body).As()->indices); + const ir::Expr& output_schedule_block_realize = + (SearchUtils::ChildScheduleBlocks * + SearchUtils::FindFather(GetFuncBody()) * + SearchUtils::FilterMaker([](const ir::Expr& e) -> bool { + return e.As(); + })).GetSingle(GetFuncBody()); + + const std::vector& output_iter_expr = + output_schedule_block_realize.As() + ->iter_values; + std::vector output_iter_vars; + + std::transform(output_iter_expr.begin(), + output_iter_expr.end(), + output_iter_vars.begin(), + [](const Expr& expr) { return expr.as_var_ref(); }); + + return output_iter_vars; } ir::Expr* GetStoreValuePointer() const { @@ -569,7 +585,7 @@ struct ReduceOp { } std::vector GetAllIterVars() const { - ir::Expr compute_schedule_block_realize = + const ir::Expr& compute_schedule_block_realize = (SearchUtils::ChildScheduleBlocks * SearchUtils::ScheduleBlockIsNotInit * SearchUtils::FindFather(GetFuncBody()) * @@ -591,7 +607,7 @@ struct ReduceOp { } std::vector GetOuterIterVars() const { - ir::Expr init_schedule_block_realize = + const ir::Expr& init_schedule_block_realize = (SearchUtils::ChildScheduleBlocks * SearchUtils::ScheduleBlockIsInit * SearchUtils::FindFather(GetFuncBody()) * SearchUtils::FilterMaker([](const ir::Expr& e) -> bool {