Skip to content

Commit

Permalink
Update pass for extern_call_process; remove pass for TupleGet
Browse files Browse the repository at this point in the history
  • Loading branch information
Albresky committed Dec 24, 2024
1 parent 570026d commit 6e526f5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 43 deletions.
41 changes: 9 additions & 32 deletions paddle/cinn/optim/extern_call_process_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
// limitations under the License.

#include "paddle/cinn/optim/extern_call_process_pass.h"
#include "paddle/cinn/ir/utils/ir_compare.h"

namespace cinn {
namespace optim {
using ir::stmt::BlockRef;
using ir::stmt::Evaluate;
using ir::stmt::StmtRef;
using ir::stmt::Store;

namespace {

Expand All @@ -25,11 +28,11 @@ void ProcessMultiOutputStore(BlockRef block) {
std::vector<StmtRef> new_stmts;

for (const auto& stmt : stmts) {
if (stmt.isa<ir::Store>()) {
auto* store_op = stmt.as<ir::Store>();
auto* call = store_op->value.As<ir::Call>();
if (stmt.isa<Store>()) {
const auto& store_op = stmt.as<Store>()->value();
const auto& call = store_op.As<ir::Call>();
if (call && call->is_extern_call() && !call->write_args.empty()) {
new_stmts.emplace_back(store_op->value);
new_stmts.emplace_back(Evaluate(store_op));
} else {
new_stmts.emplace_back(stmt);
}
Expand All @@ -41,42 +44,16 @@ void ProcessMultiOutputStore(BlockRef block) {
block->set_stmts(new_stmts);
}

void RemoveTupleGetStatements(BlockRef block) {
const auto& stmts = block->stmts();
std::vector<StmtRef> new_stmts;

for (const auto& stmt : stmts) {
if (stmt.isa<ir::Call>()) {
auto* call = stmt.as<ir::Call>();
if (call && call->is_extern_call() && call->is_tuple_get()) {
continue;
}
}
new_stmts.emplace_back(stmt);
}

block->set_stmts(new_stmts);
}

} // namespace

LogicalResult ExternCallMultiOutputShallowStorePass::Run(ir::stmt::BlockRef block) {
LogicalResult ExternCallMultiOutputShallowStorePass::Run(BlockRef block) {
ProcessMultiOutputStore(block);
return LogicalResult::success();
}

LogicalResult ExternCallRemoveTupleGetStatementsPass::Run(ir::stmt::BlockRef block) {
RemoveTupleGetStatements(block);
return LogicalResult::success();
}

std::unique_ptr<BlockPass> CreateExternCallMultiOutputShallowStorePass() {
return std::make_unique<ExternCallMultiOutputShallowStorePass>();
}

std::unique_ptr<BlockPass> CreateExternCallRemoveTupleGetStatementsPass() {
return std::make_unique<ExternCallRemoveTupleGetStatementsPass>();
}

} // namespace optim
} // namespace cinn
8 changes: 3 additions & 5 deletions paddle/cinn/optim/extern_call_process_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,21 @@

#pragma once

#include "paddle/cinn/ir/stmt.h"
#include "paddle/cinn/pass/pass.h"

namespace cinn {
namespace optim {

class ExternCallMultiOutputShallowStorePass : public BlockPass {
public:
LogicalResult Run(ir::stmt::BlockRef block) override;
};
ExternCallMultiOutputShallowStorePass()
: BlockPass("extern_call_multi_output_shallow_store") {}

class ExternCallRemoveTupleGetStatementsPass : public BlockPass {
public:
LogicalResult Run(ir::stmt::BlockRef block) override;
};

std::unique_ptr<BlockPass> CreateExternCallMultiOutputShallowStorePass();
std::unique_ptr<BlockPass> CreateExternCallRemoveTupleGetStatementsPass();

} // namespace optim
} // namespace cinn
13 changes: 7 additions & 6 deletions paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
VLOG(10) << "After Optimize MapExternCall:" << copied;

// ExternCallMultiOutputShallowStore(&copied->body);
BlockPassManager pass_manager;
pass_manager.AddPass(CreateExternCallMultiOutputShallowStorePass());
pass_manager.AddPass(CreateExternCallRemoveTupleGetStatementsPass());
pass_manager.Run(copied);
VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore and ExternCallRemoveTupleGetStatements:" << copied;

BlockPassManager pass_manager0;
pass_manager0.AddPass(CreateExternCallMultiOutputShallowStorePass());
pass_manager0.Run(copied);
VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore and "
"ExternCallRemoveTupleGetStatements:"
<< copied;

// Simplify already contains CastSimplify
Simplify(&copied->body);
VLOG(10) << "After Optimize Simplify:" << copied;
Expand Down

0 comments on commit 6e526f5

Please sign in to comment.