diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 11adafeb85be4..35af8a0246594 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -33,12 +33,14 @@ #include "paddle/cinn/ir/group_schedule/config/group_tile_config.h" #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/utils/stmt_converter.h" #include "paddle/cinn/lang/placeholder.h" #include "paddle/cinn/operator_fusion/fusion_interface.h" #include "paddle/cinn/optim/check_tensor_buffer_map.h" #include "paddle/cinn/optim/eliminate_common_global_memory_read.h" -#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/schedule_block_dce_pass.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/cinn/pass/pass_manager.h" #include "paddle/common/ddim.h" #include "paddle/common/enforce.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" @@ -368,7 +370,18 @@ std::vector OpLowererImpl::PostProcess( std::vector lowered_funcs; for (int i = 0; i < func_bodies.size(); ++i) { ir::Expr func_body = func_bodies[i]; - optim::EliminateDeadScheduleBlock(&(func_body), group->output_names()); + + if (func_body.As()) { + VLOG(6) << "Before CreateEliminateDeadScheduleBlockPass: \n" << func_body; + ir::stmt::BlockRef _block = ir::ConvertExprBlockToStmtBlock(func_body); + optim::BlockPassManager pass_manager; + pass_manager.AddPass( + optim::CreateEliminateDeadScheduleBlockPass(group->output_names())); + pass_manager.Run(_block); + func_body = ir::ConvertStmtBlockToExprBlock(_block); + VLOG(6) << "After CreateEliminateDeadScheduleBlockPass: \n" << func_body; + } + if (i != func_bodies.size() - 1) { cinn::common::DefaultDeviceTarget().arch.Match( [&](std::variant { - explicit ScheduleBlockDCE(const std::vector& output_names) - : output_names_(output_names.begin(), 
output_names.end()) {} - - void operator()(ir::Expr* expr) { - UpdateDeadScheduleBlocks(*expr); - while (!dead_schedule_block_names_.empty()) { - Visit(expr); - UpdateDeadScheduleBlocks(*expr); - } - } - - private: - void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Block* op, Expr* expr) override { - auto* node = expr->As(); - PADDLE_ENFORCE_NOT_NULL(node, - ::common::errors::InvalidArgument( - "Sorry, but expr->As node is nullptr")); - for (auto& stmt : node->stmts) { - IRMutator::Visit(&stmt, &stmt); - } - - std::unordered_set need_remove_ids; - for (int i = 0; i < node->stmts.size(); ++i) { - if (IsDeadScheduleBlock(node->stmts[i]) || IsEmptyBlock(node->stmts[i])) { - need_remove_ids.insert(i); - } - } - if (!need_remove_ids.empty()) { - node->stmts = [&] { - std::vector new_stmts; - for (int i = 0; i < node->stmts.size(); ++i) { - if (need_remove_ids.count(i) == 0) { - new_stmts.push_back(node->stmts[i]); - } - } - return new_stmts; - }(); - } - } - - void Visit(const ir::IfThenElse* op, Expr* expr) override { - auto* node = expr->As(); - PADDLE_ENFORCE_NOT_NULL(node, - ::common::errors::InvalidArgument( - "Sorry, but node expr->As is nullptr")); - IRMutator::Visit(&node->true_case, &node->true_case); - if (node->false_case.defined()) { - IRMutator::Visit(&node->false_case, &node->false_case); - } - if (IsEmptyIf(op)) { - *expr = ir::Block::Make({}); - } - } - - void Visit(const ir::For* op, Expr* expr) override { - auto* node = expr->As(); - PADDLE_ENFORCE_NOT_NULL(node, - ::common::errors::InvalidArgument( - "Sorry, but node expr->As is nullptr")); - IRMutator::Visit(&(node->body), &(node->body)); - if (IsEmptyBlock(op->body)) { - *expr = ir::Block::Make({}); - } - } - - bool IsEmptyBlock(const ir::Expr& expr) { - const auto* block_node = expr.As(); - if (block_node == nullptr) return false; - for (const auto& stmt : block_node->stmts) { - if (!IsEmptyBlock(stmt)) return false; - } - return true; - } - - bool 
IsEmptyIf(const ir::IfThenElse* node) { - if (node->false_case.defined()) { - return IsEmptyBlock(node->true_case) && IsEmptyBlock(node->false_case); - } - return IsEmptyBlock(node->true_case); - } - - bool IsDeadScheduleBlock(const ir::Expr& expr) { - const auto* sbr = expr.As(); - return sbr != nullptr && - dead_schedule_block_names_.count( - sbr->schedule_block.As()->name) > 0; - } - - void UpdateDeadScheduleBlocks(const ir::Expr& expr) { - dead_schedule_block_names_.clear(); - std::unordered_set load_buffer_names; - std::unordered_set load_tensor_names; - auto InsertLoadTensorAndBufferNames = [&](const ir::Expr* x) -> bool { - if (const ir::Load* load = x->As()) { - load_buffer_names.insert(load->tensor.as_tensor()->buffer->name); - load_tensor_names.insert(load->tensor.as_tensor()->name); - } - return false; - }; - ir::ir_utils::CollectIRNodes(expr, InsertLoadTensorAndBufferNames); - - auto IsShareBufferWithLoadedTensor = - [&](const ir::_Tensor_* tensor) -> bool { - return load_buffer_names.count(tensor->buffer->name) > 0; - }; - auto IsLoadedTensor = [&](const ir::_Tensor_* tensor) -> bool { - return load_tensor_names.count(tensor->name) > 0; - }; - auto IsOutputTensor = [&](const ir::_Tensor_* tensor) -> bool { - return output_names_.count(tensor->name) > 0; - }; - auto IsDeadStore = [&](const ir::Store* store) -> bool { - const ir::_Tensor_* tensor = store->tensor.as_tensor(); - return !IsOutputTensor(tensor) && !IsLoadedTensor(tensor) && - !IsShareBufferWithLoadedTensor(tensor); - }; - auto InsertDeadStoreName = [&](const ir::Expr* x) -> bool { - const ir::Store* store = x->As(); - if (store != nullptr && IsDeadStore(store)) { - VLOG(6) << "Find dead schedule block: " - << store->tensor.as_tensor()->name; - dead_schedule_block_names_.insert(store->tensor.as_tensor()->name); - } - return false; - }; - ir::ir_utils::CollectIRNodes(expr, InsertDeadStoreName); - } - - private: - std::unordered_set dead_schedule_block_names_; - std::unordered_set 
output_names_; -}; - -void EliminateDeadScheduleBlock(Expr* e, - const std::vector& output_names) { - VLOG(6) << "Start EliminateDeadScheduleBlock" << *e; - ScheduleBlockDCE eliminator(output_names); - eliminator(e); - VLOG(6) << "End EliminateDeadScheduleBlock: " << *e; -} - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/schedule_block_dce_pass.cc b/paddle/cinn/optim/schedule_block_dce_pass.cc new file mode 100644 index 0000000000000..353ccfee23a31 --- /dev/null +++ b/paddle/cinn/optim/schedule_block_dce_pass.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/cinn/optim/schedule_block_dce_pass.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/cinn/ir/ir_mutator.h"
+#include "paddle/cinn/ir/stmt_visitors.h"
+#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
+
+namespace cinn {
+namespace optim {
+using ir::stmt::_Block_;
+using ir::stmt::Alloc;
+using ir::stmt::BlockRef;
+using ir::stmt::Evaluate;
+using ir::stmt::For;
+using ir::stmt::Free;
+using ir::stmt::IfThenElse;
+using ir::stmt::Let;
+using ir::stmt::Schedule;
+using ir::stmt::StmtRef;
+using ir::stmt::Store;
+
+/**
+ * Collects the names of tensors whose stores are dead within one block.
+ *
+ * A store is reported as dead when its tensor
+ *   - is not listed in `output_names`,
+ *   - is not the tensor of any ir::Load found in the block, and
+ *   - does not share a buffer name with the tensor of any such ir::Load.
+ *
+ * The block is walked twice: the first walk gathers every loaded tensor and
+ * buffer name, the second walk classifies the stores. Liveness therefore
+ * depends on loads anywhere in the block, not only in the statement that
+ * holds the store.
+ */
+class DSBNamesCollectorInStmt : public ir::stmt::StmtVisitor<> {
+ public:
+  // Both sets are borrowed. `dead_schedule_block_names` is overwritten by
+  // every call to operator(); `output_names` is only read.
+  explicit DSBNamesCollectorInStmt(
+      std::unordered_set<std::string>* dead_schedule_block_names,
+      std::unordered_set<std::string>* output_names)
+      : dead_schedule_block_names_(dead_schedule_block_names),
+        output_names_(output_names) {}
+
+  // Recomputes the dead-name set for `block` from scratch.
+  void operator()(const BlockRef& block) {
+    dead_schedule_block_names_->clear();
+    load_buffer_names_.clear();
+    load_tensor_names_.clear();
+
+    phase_ = Phase::kCollectLoads;
+    ir::stmt::StmtVisitor<>::VisitBlock(block);
+
+    phase_ = Phase::kCollectDeadStores;
+    ir::stmt::StmtVisitor<>::VisitBlock(block);
+  }
+
+ private:
+  enum class Phase { kCollectLoads, kCollectDeadStores };
+
+  // NOTE(review): `condition()`, `min()`, `extent()`, `tensor()` and
+  // `indices()` are assumed accessor names of the statement IR; confirm them
+  // against paddle/cinn/ir/stmt.h.
+  // NOTE(review): expressions attached to a Schedule itself (for example its
+  // iterator bindings) are not scanned for loads; confirm whether they can
+  // contain ir::Load nodes.
+
+  void VisitStmt(const IfThenElse& stmt) override {
+    // A tensor read only by the condition is still live.
+    CollectLoads(stmt->condition());
+    VisitBlock(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      VisitBlock(stmt->false_case());
+    }
+  }
+
+  void VisitStmt(const For& stmt) override {
+    CollectLoads(stmt->min());
+    CollectLoads(stmt->extent());
+    VisitBlock(stmt->body());
+  }
+
+  void VisitStmt(const Schedule& stmt) override { VisitBlock(stmt->body()); }
+
+  void VisitStmt(const Let& stmt) override { CollectLoads(stmt->body()); }
+
+  void VisitStmt(const Store& stmt) override {
+    if (phase_ == Phase::kCollectLoads) {
+      CollectLoads(stmt->value());
+      // Loads used to compute the destination index keep their tensor live.
+      for (const ir::Expr& index : stmt->indices()) {
+        CollectLoads(index);
+      }
+      return;
+    }
+
+    const ir::_Tensor_* tensor = stmt->tensor().as_tensor();
+    if (IsDeadStore(tensor)) {
+      VLOG(6) << "Find dead schedule block names: \n" << tensor->name;
+      dead_schedule_block_names_->insert(tensor->name);
+    }
+  }
+
+  void VisitStmt(const Evaluate& stmt) override { CollectLoads(stmt->value()); }
+
+  void VisitStmt(const Alloc& stmt) override {}
+
+  void VisitStmt(const Free& stmt) override {}
+
+  // Records the tensor and buffer name of every ir::Load nested anywhere in
+  // `expr`. Does nothing outside the load-collection phase.
+  void CollectLoads(const ir::Expr& expr) {
+    if (phase_ != Phase::kCollectLoads || !expr.defined()) return;
+    ir::ir_utils::CollectIRNodes(expr, [&](const ir::Expr* x) -> bool {
+      if (const ir::Load* load = x->As<ir::Load>()) {
+        load_buffer_names_.insert(load->tensor.as_tensor()->buffer->name);
+        load_tensor_names_.insert(load->tensor.as_tensor()->name);
+      }
+      // Nothing is collected; the callback is used for its side effect only.
+      return false;
+    });
+  }
+
+  bool IsDeadStore(const ir::_Tensor_* tensor) const {
+    return output_names_->count(tensor->name) == 0 &&
+           load_tensor_names_.count(tensor->name) == 0 &&
+           load_buffer_names_.count(tensor->buffer->name) == 0;
+  }
+
+  Phase phase_{Phase::kCollectLoads};
+  std::unordered_set<std::string> load_buffer_names_;
+  std::unordered_set<std::string> load_tensor_names_;
+  std::unordered_set<std::string>* dead_schedule_block_names_;
+  std::unordered_set<std::string>* output_names_;
+};
+
+/**
+ * Removes dead schedule blocks from a statement block, together with the
+ * For / IfThenElse statements that become empty as a result.
+ *
+ * Removing one schedule block can make the stores feeding it dead as well,
+ * so collection and removal alternate until no dead name is left.
+ */
+class ScheduleBlockDCE {
+ public:
+  explicit ScheduleBlockDCE(const std::vector<std::string>& output_names)
+      : output_names_(output_names.begin(), output_names.end()) {}
+
+  void operator()(BlockRef block) {
+    DSBNamesCollectorInStmt collector(&dead_schedule_block_names_,
+                                      &output_names_);
+    collector(block);
+    while (!dead_schedule_block_names_.empty()) {
+      removed_any_ = false;
+      VisitBlock(block);
+      // A dead name without a matching Schedule statement would otherwise be
+      // reported forever; stop as soon as a sweep removes nothing.
+      if (!removed_any_) break;
+      collector(block);
+    }
+  }
+
+ private:
+  // Rebuilds `block` without the statements that ShouldRemove() rejects.
+  void VisitBlock(BlockRef block) {
+    const std::vector<StmtRef>& stmts = block->stmts();
+    std::vector<StmtRef> kept_stmts;
+    kept_stmts.reserve(stmts.size());
+    for (size_t i = 0; i < stmts.size(); ++i) {
+      if (ShouldRemove(stmts[i])) {
+        VLOG(6) << "[TEST] Remove dead schedule block: \n" << i << "\n";
+        continue;
+      }
+      kept_stmts.push_back(stmts[i]);
+    }
+
+    if (kept_stmts.size() != stmts.size()) {
+      removed_any_ = true;
+      block->set_stmts(kept_stmts);
+    }
+  }
+
+  // Cleans the blocks nested in `stmt` and reports whether `stmt` itself
+  // must be dropped from its parent block.
+  bool ShouldRemove(const StmtRef& stmt) {
+    if (stmt.isa<IfThenElse>()) return VisitStmt(stmt.as<IfThenElse>());
+    if (stmt.isa<For>()) return VisitStmt(stmt.as<For>());
+    if (stmt.isa<Schedule>()) return VisitStmt(stmt.as<Schedule>());
+    return false;
+  }
+
+  bool VisitStmt(IfThenElse stmt) {
+    VisitBlock(stmt->true_case());
+    if (stmt->false_case().defined()) {
+      VisitBlock(stmt->false_case());
+    }
+    return IsEmptyIf(stmt);
+  }
+
+  bool VisitStmt(For stmt) {
+    VisitBlock(stmt->body());
+    return IsEmptyBlock(stmt->body());
+  }
+
+  bool VisitStmt(const Schedule& stmt) {
+    if (!stmt->block_fields().empty() &&
+        dead_schedule_block_names_.count(stmt->name()) > 0) {
+      return true;
+    }
+    // A live schedule block may still contain dead ones.
+    VisitBlock(stmt->body());
+    return false;
+  }
+
+  // A statement is empty when it owns at least one block and every block it
+  // owns is empty. Statements that own no block are never empty.
+  bool IsEmptyStmt(const StmtRef& stmt) {
+    if (stmt->block_fields().empty()) return false;
+    for (const BlockRef& block : stmt->block_fields()) {
+      if (!IsEmptyBlock(block)) return false;
+    }
+    return true;
+  }
+
+  // A block is empty when it has no statements or only empty statements. An
+  // undefined block counts as empty, matching how IsEmptyIf ignores an
+  // undefined false branch.
+  bool IsEmptyBlock(const BlockRef& block) {
+    if (!block.defined()) return true;
+    for (const StmtRef& stmt : block->stmts()) {
+      if (!IsEmptyStmt(stmt)) return false;
+    }
+    return true;
+  }
+
+  bool IsEmptyIf(const IfThenElse& stmt) {
+    if (stmt->false_case().defined()) {
+      return IsEmptyBlock(stmt->true_case()) &&
+             IsEmptyBlock(stmt->false_case());
+    }
+    return IsEmptyBlock(stmt->true_case());
+  }
+
+ private:
+  // Set by VisitBlock whenever a sweep drops at least one statement.
+  bool removed_any_{false};
+  std::unordered_set<std::string> dead_schedule_block_names_;
+  std::unordered_set<std::string> output_names_;
+};
+
+LogicalResult EliminateDeadScheduleBlockPass::Run(BlockRef stmt) {
+  EliminateDeadScheduleBlock(stmt);
+  return LogicalResult::success();
+}
+
+std::unique_ptr<BlockPass> CreateEliminateDeadScheduleBlockPass(
+    const std::vector<std::string>& output_names) {
+  return std::make_unique<EliminateDeadScheduleBlockPass>(output_names);
+}
+
+void EliminateDeadScheduleBlockPass::EliminateDeadScheduleBlock(
+    BlockRef block) {
+  ScheduleBlockDCE eliminator(output_names_);
+  eliminator(block);
+}
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/schedule_block_dce.h b/paddle/cinn/optim/schedule_block_dce_pass.h
similarity index 81%
rename from paddle/cinn/optim/schedule_block_dce.h
rename to paddle/cinn/optim/schedule_block_dce_pass.h
index e1341fe97e31c..f907c6a76889b 100644
--- a/paddle/cinn/optim/schedule_block_dce.h
+++ b/paddle/cinn/optim/schedule_block_dce_pass.h
@@ -16,10 +16,13 @@
  * This file implements the strategy to remove the unnecessary schedule_block.
  */
 #pragma once
+#include <memory>
 #include <string>
+#include <vector>
 
 #include "paddle/cinn/common/common.h"
 #include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/pass/pass.h"
 
 namespace cinn {
 namespace optim {
@@ -56,8 +59,24 @@ namespace optim {
  * be reduced.
  *
  */
-void EliminateDeadScheduleBlock(Expr* e,
-                                const std::vector<std::string>& output_names);
+class EliminateDeadScheduleBlockPass : public BlockPass {
+ public:
+  explicit EliminateDeadScheduleBlockPass(
+      const std::vector<std::string>& output_names)
+      : BlockPass("eliminate_dead_schedule_block"),
+        output_names_(output_names) {}
+  LogicalResult Run(ir::stmt::BlockRef block) override;
+
+ private:
+  void EliminateDeadScheduleBlock(ir::stmt::BlockRef block);
+
+  // Names of the tensors whose stores must never be treated as dead.
+  std::vector<std::string> output_names_;
+};
+
+// Builds the pass; `output_names` is copied into the returned object.
+std::unique_ptr<BlockPass> CreateEliminateDeadScheduleBlockPass(
+    const std::vector<std::string>& output_names);
 
 }  // namespace optim
 }  // namespace cinn