diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index b5bb10f4f173e9..8d1b90db03ceb7 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -290,6 +290,26 @@ bool FlashAttnOpInferSymbolicShape( return true; } +bool GroupNormOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &x_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + + shape_analysis->SetShapeOrDataForValue(op->result(0), x_shape); + + const symbol::DimExpr &batch_size = x_shape.shape()[0]; + int groups = op->attribute("groups").data(); + symbol::TensorShapeOrDataDimExprs mean_shape( + std::vector{batch_size, groups}); + if (op->result(1)) { + shape_analysis->SetShapeOrDataForValue(op->result(1), mean_shape); + } + if (op->result(2)) { + shape_analysis->SetShapeOrDataForValue(op->result(2), mean_shape); + } + return true; +} + bool LinspaceOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { const auto &num_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index be528d31139cf3..2e5f506b373ac4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -23,6 +23,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace) OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 7a8c69464a5351..1adc5767862474 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -391,6 +391,58 @@ bool PadOpInferSymbolicShape(pir::Operation *op, return true; } +bool Pad3dOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + PADDLE_ENFORCE_EQ(x_shape.size(), + 5, + common::errors::InvalidArgument( + "The size of Input(X)'s dimension should be equal to " + "5, but received %d. ", + x_shape.size())); + const auto &paddings_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + if (!paddings_shape.data().has_value()) { + std::stringstream ss; + ss << paddings_shape; + PADDLE_THROW( + common::errors::InvalidArgument("The data of paddings's symbol shape " + "should have value, but now got [%s].", + ss.str())); + } + const std::string &data_format = + op->attribute("data_format").AsString(); + + const std::vector &out_dims = [&] { + std::vector out_dims = x_shape; + const auto &paddings = paddings_shape.data().value(); + PADDLE_ENFORCE_EQ(paddings.size(), + 6, + common::errors::InvalidArgument( + "Shape of Input(Paddings) should be equal to " + "[6], but received [%d].", + paddings.size())); + if (data_format == "NCDHW") { + out_dims[1] = x_shape[1]; + out_dims[2] = x_shape[2] + paddings[4] + paddings[5]; + out_dims[3] = x_shape[3] + paddings[2] + paddings[3]; + out_dims[4] = x_shape[4] + paddings[0] + paddings[1]; + } else { + out_dims[1] = x_shape[1] + paddings[4] + paddings[5]; + out_dims[2] = x_shape[2] + paddings[2] + paddings[3]; + out_dims[3] = x_shape[3] + paddings[0] + paddings[1]; + out_dims[4] = x_shape[4]; + } + return out_dims; + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims)); + + return true; +} + bool ProdOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { bool keepdim = GetBoolAttr(op, "keep_dim"); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index e52b9aabc15689..8155b4f0ac5052 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -37,6 +37,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 38c65b98930bd8..56205cc06d8078 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1324,6 +1324,7 @@ optional : scale, bias intermediate : mean, variance backward : group_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : gumbel_softmax args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1) @@ -2204,6 +2205,7 @@ kernel : func : pad3d backward : pad3d_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : pixel_shuffle args : (Tensor x, int upscale_factor=1, str data_format="NCHW")