Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN][Backend Pass Update No.3] Update extern_call_process pass #70191

Merged
2 changes: 1 addition & 1 deletion paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ gather_srcs(
call_arg_list_to_pod_value.cc
insert_debug_log_callee.cc
lower_function_call_bind_vars.cc
extern_call_process.cc
extern_call_process_pass.cc
map_extern_call.cc
compute_inline_expand.cc
buffer_assign.cc
Expand Down
82 changes: 82 additions & 0 deletions paddle/cinn/optim/extern_call_process_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/extern_call_process_pass.h"
#include "paddle/cinn/ir/utils/ir_compare.h"

namespace cinn {
namespace optim {

namespace {

void ProcessMultiOutputStore(BlockRef block) {
const auto& stmts = block->stmts();
std::vector<StmtRef> new_stmts;

for (const auto& stmt : stmts) {
if (stmt.isa<ir::Store>()) {
auto* store_op = stmt.as<ir::Store>();
auto* call = store_op->value.As<ir::Call>();
if (call && call->is_extern_call() && !call->write_args.empty()) {
new_stmts.emplace_back(store_op->value);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call现在是个Expr,如果要它脱离store成为一条stmt需要使用Evaluate stmt包装一下

new_stmts.emplace_back(stmt);
}
} else {
new_stmts.emplace_back(stmt);
}
}

block->set_stmts(new_stmts);
}

void RemoveTupleGetStatements(BlockRef block) {
const auto& stmts = block->stmts();
std::vector<StmtRef> new_stmts;

for (const auto& stmt : stmts) {
if (stmt.isa<ir::Call>()) {
auto* call = stmt.as<ir::Call>();
if (call && call->is_extern_call() && call->is_tuple_get()) {
continue;
}
}
new_stmts.emplace_back(stmt);
}

block->set_stmts(new_stmts);
}

} // namespace

LogicalResult ExternCallMultiOutputShallowStorePass::Run(ir::stmt::BlockRef block) {
ProcessMultiOutputStore(block);
return LogicalResult::success();
}

LogicalResult ExternCallRemoveTupleGetStatementsPass::Run(ir::stmt::BlockRef block) {
RemoveTupleGetStatements(block);
return LogicalResult::success();
}

std::unique_ptr<BlockPass> CreateExternCallMultiOutputShallowStorePass() {
return std::make_unique<ExternCallMultiOutputShallowStorePass>();
}

std::unique_ptr<BlockPass> CreateExternCallRemoveTupleGetStatementsPass() {
return std::make_unique<ExternCallRemoveTupleGetStatementsPass>();
}

} // namespace optim
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/optim/extern_call_process_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/pass/pass.h"

namespace cinn {
namespace optim {

class ExternCallMultiOutputShallowStorePass : public BlockPass {
public:
LogicalResult Run(ir::stmt::BlockRef block) override;
};

class ExternCallRemoveTupleGetStatementsPass : public BlockPass {
public:
LogicalResult Run(ir::stmt::BlockRef block) override;
};

std::unique_ptr<BlockPass> CreateExternCallMultiOutputShallowStorePass();
std::unique_ptr<BlockPass> CreateExternCallRemoveTupleGetStatementsPass();

} // namespace optim
} // namespace cinn
11 changes: 8 additions & 3 deletions paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/eliminate_invariant_loop.h"
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/extern_call_process_pass.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/if_fusion_pass.h"
#include "paddle/cinn/optim/insert_debug_log_callee.h"
Expand Down Expand Up @@ -99,8 +99,13 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
MapExternCall(&copied->body, target);
VLOG(10) << "After Optimize MapExternCall:" << copied;

ExternCallMultiOutputShallowStore(&copied->body);
VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore:" << copied;
// ExternCallMultiOutputShallowStore(&copied->body);
BlockPassManager pass_manager;
pass_manager.AddPass(CreateExternCallMultiOutputShallowStorePass());
pass_manager.AddPass(CreateExternCallRemoveTupleGetStatementsPass());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExternCallRemoveTupleGetStatements没有使用,且之前没有实现,可以不实现并删除

pass_manager.Run(copied);
VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore and ExternCallRemoveTupleGetStatements:" << copied;

// Simplify already contains CastSimplify
Simplify(&copied->body);
VLOG(10) << "After Optimize Simplify:" << copied;
Expand Down