Skip to content

Commit

Permalink
[PIR][DynamicShape] Add strategy for compatibility in select_input op (
Browse files Browse the repository at this point in the history
…#62381)

Add strategy for compatibility in select_input op
  • Loading branch information
lanxianghit authored Mar 5, 2024
1 parent 2ab2994 commit dfb0f89
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -999,19 +999,20 @@ bool SelectInputOp::InferSymbolicShape(
const auto &input1_dims = GetSymExprForValue(operand_source(0));
const auto &input2_dims = GetSymExprForValue(operand_source(1));

// for compatibility, we just return second_shape.
if (input1_dims.size() != input2_dims.size()) {
shape_analysis->SetShapeOrDataForValue(
result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(input2_dims)});
return true;
}

std::vector<symbol::DimExpr> out_dims = input1_dims;
// merge shape for input1 and input2, since we don't know which will be
// selected in compile time, the strategy is same with IfOp, see IfOp's
// comments for details and examples
if (input2_dims.size() != 0) {
// now only support input1 and input2 have same rank.
PADDLE_ENFORCE_EQ(input1_dims.size(),
input2_dims.size(),
phi::errors::PreconditionNotMet(
"The true and false block should have same rank, "
"but got true_rank(%d) and false_rank(%d)",
input1_dims.size(),
input2_dims.size()));
for (size_t i = 0; i < input1_dims.size(); i++) {
if (input1_dims[i] != input2_dims[i]) {
out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()};
Expand Down

0 comments on commit dfb0f89

Please sign in to comment.