diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 2c9c08e789b23..5982c0a723ce6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -651,12 +651,43 @@ bool BilinearOpInferSymbolicShape( // return true; // } -// bool BroadcastTensorsOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool BroadcastTensorsOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape_or_data_list = + infer_context->GetShapeOrDataForValue(op->operand_source(0)) + .dyn_cast(); + + // 1. Find Output rank = max(Inputs rank) + size_t target_rank = 0; + for (const auto &input_shape_or_data : input_shape_or_data_list) { + size_t tmp_rank = input_shape_or_data.shape().size(); + target_rank = std::max(int64_t(target_rank), int64_t(tmp_rank)); + } + + // 2. Output dim(axis=x) = max(Inputs dim(axis=x)) + std::vector out_shape; + symbol::DimExprBuilder builder; + for (size_t i = 0; i < target_rank; i++) { + auto tmp_dim = symbol::DimExpr{1}; + for (const auto &input_shape_or_data : input_shape_or_data_list) { + size_t axis = i - target_rank + input_shape_or_data.size(); + if (axis >= 0) { + infer_context->AddBroadcastableCstr(input_shape_or_data.shape()[axis], + tmp_dim); + tmp_dim = builder.Broadcast(input_shape_or_data.shape()[axis], tmp_dim); + } + } + out_shape.emplace_back(tmp_dim); + } + + symbol::TensorListShapeOrDataDimExprs out_shapes; + for (size_t i = 0; i < input_shape_or_data_list.size(); i++) { + out_shapes.emplace_back(out_shape); + } + infer_context->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs{out_shapes}); + return true; +} bool BilinearInterpOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 3ae4ccf4dace5..d95a433c08c09 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -24,7 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 3c7fb22594626..f5d3197a8e296 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -679,6 +679,7 @@ func: broadcast_tensors data_type : input backward: broadcast_tensors_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : c_allgather args : (Tensor x, int ring_id, int nranks, bool use_calc_stream)