Skip to content

Commit

Permalink
Add constrains between the shapes of x and mask (PaddlePaddle#65927)
Browse files Browse the repository at this point in the history
  • Loading branch information
gongshaotian authored Jul 11, 2024
1 parent 9aa806b commit 92989e9
Showing 1 changed file with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,25 @@ bool MaskedSelectOpInferSymbolicShape(
out_dims.push_back(out_shape);
return out_dims;
}();
// TODO(fty1777): Add constrains between the shapes of x and mask
// Add constrains between the shapes of x and mask
const std::vector<symbol::DimExpr> &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
const std::vector<symbol::DimExpr> &mask_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();
size_t ndims_x = x_shape.size();
size_t ndims_mask = mask_shape.size();
if (ndims_x >= ndims_mask) {
size_t diff = ndims_x - ndims_mask;
for (size_t i = 0; i < ndims_mask; i++) {
infer_context->AddBroadcastableCstr(x_shape[i + diff], mask_shape[i]);
}
} else {
size_t diff = ndims_mask - ndims_x;
for (size_t i = 0; i < ndims_x; i++) {
infer_context->AddBroadcastableCstr(x_shape[i], mask_shape[i + diff]);
}
}

infer_context->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs{out_dims});
return true;
Expand Down

0 comments on commit 92989e9

Please sign in to comment.