Skip to content

Commit

Permalink
[Dynamic Shape] Erase expand (#60525)
Browse files Browse the repository at this point in the history
* EraseExpandOp

* minor fix

* minor fix

* Code format
  • Loading branch information
jiahy0825 authored Jan 4, 2024
1 parent 1b26966 commit aacdc4d
Showing 1 changed file with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,85 @@
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

PD_DECLARE_bool(cinn_enable_map_expr);

namespace {

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;

pir::Block::ConstIterator FindFirstExpandOp(pir::Block* block) {
for (auto iter = block->begin(); iter != block->end(); ++iter) {
if (iter->isa<paddle::dialect::ExpandOp>()) {
return iter;
}
}
}

bool SameInputOutputShape(
paddle::dialect::ExpandOp expand_op,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
const auto& x = ShapeOrDataDimExprs4Value(expand_op.x());
const auto& shape = ShapeOrDataDimExprs4Value(expand_op.shape());
const auto& out = ShapeOrDataDimExprs4Value(expand_op.out());
if (x.data().has_value()) return false;
if (!shape.data().has_value()) return false;
if (out.data().has_value()) return false;
CHECK(shape.data().value() == out.shape());
return x.shape() == out.shape();
}

void ReplaceAllUsesWithInput(paddle::dialect::ExpandOp expand) {
pir::Value x = expand.x();
expand.out().ReplaceAllUsesWith(x);
}

void EraseExpandOp(pir::Block* block, pir::Block::ConstIterator expand_it) {
block->erase(expand_it);
}

void EraseUpstreamGenerateShapeOp(
pir::Block* block, cinn::dialect::GenerateShapeOp generate_shape_op) {
for (auto iter = block->begin(); iter != block->end(); ++iter) {
if (iter->isa<cinn::dialect::GenerateShapeOp>()) {
if (iter->dyn_cast<cinn::dialect::GenerateShapeOp>() ==
generate_shape_op) {
block->erase(iter);
}
}
}
}

// Returns true if success
bool EraseOneExpand(
pir::Block* block,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
for (auto expand_it = block->begin(); expand_it != block->end();
++expand_it) {
if (!expand_it->isa<paddle::dialect::ExpandOp>()) continue;
auto expand = expand_it->dyn_cast<paddle::dialect::ExpandOp>();
if (!SameInputOutputShape(expand, ShapeOrDataDimExprs4Value)) continue;
auto generate_shape_op =
expand.shape().defining_op<cinn::dialect::GenerateShapeOp>();
CHECK_NOTNULL(generate_shape_op);
ReplaceAllUsesWithInput(expand);
EraseExpandOp(block, expand_it);
EraseUpstreamGenerateShapeOp(block, generate_shape_op);
return true;
}
return false;
}

void EraseExpands(pir::Block* block,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
while (EraseOneExpand(block, ShapeOrDataDimExprs4Value)) {
// Do nothing.
}
}

std::vector<pir::Value> GetBlockOutsideInput(
const std::vector<pir::Operation*> op_list) {
std::vector<pir::Value> vec_res;
Expand Down

0 comments on commit aacdc4d

Please sign in to comment.