diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index d6e4d4ba01368..eb263a950ff09 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2376,6 +2376,71 @@ bool Where_OpInferSymbolicShape(pir::Operation *op, return WhereOpInferSymbolicShape(op, infer_context); } +bool MultiplexOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &inputs_shape_or_data_list = + infer_context->GetShapeOrDataForValue(op->operand_source(0)) + .dyn_cast(); + + PADDLE_ENFORCE_NE( + inputs_shape_or_data_list.empty(), + true, + common::errors::InvalidArgument("MultiInput(X) shouldn't be empty.")); + + const auto &ids_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + PADDLE_ENFORCE_EQ(ids_shape_or_data.shape().size(), + 2, + common::errors::PreconditionNotMet( + "The index tensor must be a vector with 2 dimensions")); + + infer_context->AddEqualCstr(ids_shape_or_data.shape()[1], symbol::DimExpr(1)); + + PADDLE_ENFORCE_GT( + inputs_shape_or_data_list.size(), + 1, + common::errors::InvalidArgument("multiplex operator should have more " + "than one candidate input tensors.")); + + size_t num_inputs = inputs_shape_or_data_list.size(); + std::vector first_input_shape = + inputs_shape_or_data_list[0].shape(); + PADDLE_ENFORCE_GE( + first_input_shape.size(), + 2, + common::errors::InvalidArgument( + "The rank of candidate tensors must be not less than 2.")); + + for (size_t i = 1; i < num_inputs; ++i) { + std::vector element_shape = + inputs_shape_or_data_list[i].shape(); + + PADDLE_ENFORCE_EQ(first_input_shape.size(), + element_shape.size(), + common::errors::PreconditionNotMet( + "All the candidate tensors must have the same dim.")); + + for (size_t j = 0; j < first_input_shape.size(); ++j) + infer_context->AddEqualCstr(first_input_shape[j], element_shape[j]); + // all of the input Tensors should have the same shape + } + + if (first_input_shape[0].isa() && + ids_shape_or_data.shape()[0].isa()) { + PADDLE_ENFORCE_GE(first_input_shape[0].dyn_cast(), + ids_shape_or_data.shape()[0].dyn_cast(), + common::errors::InvalidArgument( + "The 2nd-dim of input cannot be smaller than " + "batchSize of the index tensor.")); + } + std::vector &output_shape = first_input_shape; + output_shape[0] = ids_shape_or_data.shape()[0]; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrData{symbol::TensorShapeOrDataDimExprs(output_shape)}); + return true; +} bool YoloLossOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 2336b0d0abbb9..b3f921736e254 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -81,6 +81,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedMomentum) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MergedMomentum_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MulticlassNms3) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multiplex) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Moe) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc index e803f86112d50..2818068e6c8ce 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc @@ -119,6 +119,7 @@ OP_SAME_OPERANDS_AND_RESULT(Logsigmoid) OP_SAME_OPERANDS_AND_RESULT(Logsigmoid_) OP_SAME_OPERANDS_AND_RESULT(Memcpy) OP_SAME_OPERANDS_AND_RESULT(Mish) +OP_SAME_OPERANDS_AND_RESULT(NumberCount) OP_SAME_OPERANDS_AND_RESULT(Pow) OP_SAME_OPERANDS_AND_RESULT(Poisson) OP_SAME_OPERANDS_AND_RESULT(Pow_) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h index 970c24e6cd17b..bdb4c15c6b53f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h @@ -110,6 +110,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Memcpy) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mish) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NumberCount) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index fef1932c63329..f371131f672a2 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3346,6 +3346,7 @@ backward : multiplex_grad data_transform : skip_transform : index + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : mv args : (Tensor x, Tensor vec) @@ -5211,3 +5212,4 @@ kernel: func: number_count data_type: numbers + interfaces : paddle::dialect::InferSymbolicShapeInterface