From 5fb15fbdfd22fdf0bf1e2ee00bc5ee7a6863c873 Mon Sep 17 00:00:00 2001 From: lanxianghit Date: Mon, 4 Mar 2024 12:05:11 +0000 Subject: [PATCH] Add strategy for compatibility in select_input op --- .../pir/dialect/operator/ir/control_flow_op.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 7f490cdd24f8a7..60d589773d5bb2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -999,19 +999,20 @@ bool SelectInputOp::InferSymbolicShape( const auto &input1_dims = GetSymExprForValue(operand_source(0)); const auto &input2_dims = GetSymExprForValue(operand_source(1)); + // for compatibility, we just return second_shape. + if (input1_dims.size() != input2_dims.size()) { + shape_analysis->SetShapeOrDataForValue( + result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(input2_dims)}); + return true; + } + std::vector out_dims = input1_dims; // merge shape for input1 and input2, since we don't know which will be // selected in compile time, the strategy is same with IfOp, see IfOp's // comments for details and examples if (input2_dims.size() != 0) { - // now only support input1 and input2 have same rank. - PADDLE_ENFORCE_EQ(input1_dims.size(), - input2_dims.size(), - phi::errors::PreconditionNotMet( - "The true and false block should have same rank, " - "but got true_rank(%d) and false_rank(%d)", - input1_dims.size(), - input2_dims.size())); for (size_t i = 0; i < input1_dims.size(); i++) { if (input1_dims[i] != input2_dims[i]) { out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()};