From aacdc4d9ac1573fd65e6170afd0335d8e666fba2 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 4 Jan 2024 11:16:57 +0800 Subject: [PATCH] [Dynamic Shape] Erase expand (#60525) * EraseExpandOp * minor fix * minor fix * Code format --- .../group_merge/cinn_group_lowering_pass.cc | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc index f4aa34bbc7263..db2dd030ba702 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc @@ -34,10 +34,85 @@ #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + PD_DECLARE_bool(cinn_enable_map_expr); namespace { +using ShapeOrDataDimExprs4ValueT = + std::function; + +pir::Block::ConstIterator FindFirstExpandOp(pir::Block* block) { + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (iter->isa()) { + return iter; + } + } +} + +bool SameInputOutputShape( + paddle::dialect::ExpandOp expand_op, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + const auto& x = ShapeOrDataDimExprs4Value(expand_op.x()); + const auto& shape = ShapeOrDataDimExprs4Value(expand_op.shape()); + const auto& out = ShapeOrDataDimExprs4Value(expand_op.out()); + if (x.data().has_value()) return false; + if (!shape.data().has_value()) return false; + if (out.data().has_value()) return false; + CHECK(shape.data().value() == out.shape()); + return x.shape() == out.shape(); +} + +void ReplaceAllUsesWithInput(paddle::dialect::ExpandOp expand) { + pir::Value x = expand.x(); + expand.out().ReplaceAllUsesWith(x); +} + +void EraseExpandOp(pir::Block* block, pir::Block::ConstIterator expand_it) { + block->erase(expand_it); +} + +void EraseUpstreamGenerateShapeOp( + pir::Block* block, cinn::dialect::GenerateShapeOp generate_shape_op) { + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (iter->isa()) { + if (iter->dyn_cast() == + generate_shape_op) { + block->erase(iter); + } + } + } +} + +// Returns true if success +bool EraseOneExpand( + pir::Block* block, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + for (auto expand_it = block->begin(); expand_it != block->end(); + ++expand_it) { + if (!expand_it->isa()) continue; + auto expand = expand_it->dyn_cast(); + if (!SameInputOutputShape(expand, ShapeOrDataDimExprs4Value)) continue; + auto generate_shape_op = + expand.shape().defining_op(); + CHECK_NOTNULL(generate_shape_op); + ReplaceAllUsesWithInput(expand); + EraseExpandOp(block, expand_it); + EraseUpstreamGenerateShapeOp(block, generate_shape_op); + return true; + } + return false; +} + +void EraseExpands(pir::Block* block, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + while (EraseOneExpand(block, ShapeOrDataDimExprs4Value)) { + // Do nothing. + } +} + std::vector GetBlockOutsideInput( const std::vector op_list) { std::vector vec_res;