diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 6d2ae9b159df89..a8dd2ec197027f 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -16,7 +16,7 @@ gather_srcs( call_arg_list_to_pod_value.cc insert_debug_log_callee.cc lower_function_call_bind_vars.cc - extern_call_process.cc + extern_call_process_pass.cc map_extern_call.cc compute_inline_expand.cc buffer_assign.cc diff --git a/paddle/cinn/optim/extern_call_process.cc b/paddle/cinn/optim/extern_call_process.cc deleted file mode 100644 index be3636c81982e3..00000000000000 --- a/paddle/cinn/optim/extern_call_process.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/extern_call_process.h" - -#include "paddle/cinn/ir/ir_mutator.h" - -namespace cinn { -namespace optim { - -namespace { - -struct ExternCallMultiOutputShallowStoreMutator : public ir::IRMutator<> { - void operator()(Expr* e) { ir::IRMutator<>::Visit(e, e); } - - private: - void Visit(const ir::Store* op, Expr* expr) override { - auto* call = op->value.As(); - if (call && call->is_extern_call() && !call->write_args.empty()) { - *expr = op->value; - } - } -}; - -} // namespace - -void ExternCallMultiOutputShallowStore(Expr* e) { - ExternCallMultiOutputShallowStoreMutator()(e); -} - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/extern_call_process_pass.cc b/paddle/cinn/optim/extern_call_process_pass.cc new file mode 100644 index 00000000000000..047668afc6b23d --- /dev/null +++ b/paddle/cinn/optim/extern_call_process_pass.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/extern_call_process_pass.h" + +namespace cinn { +namespace optim { +using ir::stmt::BlockRef; +using ir::stmt::Evaluate; +using ir::stmt::StmtRef; +using ir::stmt::Store; + +namespace { + +void ProcessMultiOutputStore(BlockRef block) { + const auto& stmts = block->stmts(); + std::vector new_stmts; + + for (const auto& stmt : stmts) { + if (stmt.isa()) { + const auto& store_op = stmt.as()->value(); + const auto& call = store_op.As(); + if (call && call->is_extern_call() && !call->write_args.empty()) { + new_stmts.emplace_back(Evaluate(store_op)); + } else { + new_stmts.emplace_back(stmt); + } + } else { + new_stmts.emplace_back(stmt); + } + } + + block->set_stmts(new_stmts); +} + +} // namespace + +LogicalResult ExternCallMultiOutputShallowStorePass::Run(BlockRef block) { + ProcessMultiOutputStore(block); + return LogicalResult::success(); +} + +std::unique_ptr CreateExternCallMultiOutputShallowStorePass() { + return std::make_unique(); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/extern_call_process.h b/paddle/cinn/optim/extern_call_process_pass.h similarity index 77% rename from paddle/cinn/optim/extern_call_process.h rename to paddle/cinn/optim/extern_call_process_pass.h index d526db7da78d55..6ed1d52afb046e 100644 --- a/paddle/cinn/optim/extern_call_process.h +++ b/paddle/cinn/optim/extern_call_process_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,8 @@ #pragma once -#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/stmt.h" +#include "paddle/cinn/pass/pass.h" namespace cinn { namespace optim { @@ -47,18 +48,15 @@ namespace optim { * Output IR: * Store(target, Call(extern_func, args, {})) */ -void ExternCallMultiOutputShallowStore(Expr* e); +std::unique_ptr CreateExternCallMultiOutputShallowStorePass(); -/* - * Remove external call statements that are TupleGet. - * - * This pass is applicable in scenarios where the external call statements are - * TupleGet. - * - * When applied, this pass will traverse the external call statements in the - * block and remove the statements that are TupleGet. - */ -void ExternCallRemoveTupleGetStatements(Expr* e); +class ExternCallMultiOutputShallowStorePass : public BlockPass { + public: + ExternCallMultiOutputShallowStorePass() + : BlockPass("extern_call_multi_output_shallow_store") {} + + LogicalResult Run(ir::stmt::BlockRef block) override; +}; } // namespace optim } // namespace cinn diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index cb552d1f95c44e..561b11a5a815af 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -22,7 +22,7 @@ #include "paddle/cinn/optim/cast_bool_to_int8.h" #include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h" #include "paddle/cinn/optim/eliminate_invariant_loop.h" -#include "paddle/cinn/optim/extern_call_process.h" +#include "paddle/cinn/optim/extern_call_process_pass.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h" #include "paddle/cinn/optim/if_fusion_pass.h" #include "paddle/cinn/optim/insert_debug_log_callee.h" @@ -99,8 +99,14 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, MapExternCall(&copied->body, target); VLOG(10) << "After Optimize MapExternCall:" << copied; - ExternCallMultiOutputShallowStore(&copied->body); - VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore:" << copied; + // ExternCallMultiOutputShallowStore(&copied->body); + BlockPassManager pass_manager0; + pass_manager0.AddPass(CreateExternCallMultiOutputShallowStorePass()); + pass_manager0.Run(copied); + VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore and " + "ExternCallRemoveTupleGetStatements:" + << copied; + // Simplify already contains CastSimplify Simplify(&copied->body); VLOG(10) << "After Optimize Simplify:" << copied;