Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR][DynamicShape] Add strategy for compatibility in select_input op #62381

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -999,19 +999,20 @@ bool SelectInputOp::InferSymbolicShape(
const auto &input1_dims = GetSymExprForValue(operand_source(0));
const auto &input2_dims = GetSymExprForValue(operand_source(1));

// for compatibility, we just return second_shape.
if (input1_dims.size() != input2_dims.size()) {
shape_analysis->SetShapeOrDataForValue(
result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(input2_dims)});
return true;
}

std::vector<symbol::DimExpr> out_dims = input1_dims;
// merge shape for input1 and input2, since we don't know which will be
// selected in compile time, the strategy is same with IfOp, see IfOp's
// comments for details and examples
if (input2_dims.size() != 0) {
// now only support input1 and input2 have same rank.
PADDLE_ENFORCE_EQ(input1_dims.size(),
input2_dims.size(),
phi::errors::PreconditionNotMet(
"The true and false block should have same rank, "
"but got true_rank(%d) and false_rank(%d)",
input1_dims.size(),
input2_dims.size()));
for (size_t i = 0; i < input1_dims.size(); i++) {
if (input1_dims[i] != input2_dims[i]) {
out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()};
Expand Down