diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index cd21ae8178c64..1f5320f518524 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1414,6 +1414,155 @@ bool MinOpInferSymbolicShape(pir::Operation *op, return MaxOpInferSymbolicShape(op, infer_context); } +bool MaxPoolWithIndexOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + + std::vector paddings_ = + paddle::dialect::details::GetVectorAttr(op, "paddings"); + std::vector strides = + paddle::dialect::details::GetVectorAttr(op, "strides"); + std::vector kernel_sizes_ = + paddle::dialect::details::GetVectorAttr(op, "kernel_sizes"); + + std::vector kernel_size_; + int rank_kernel = kernel_sizes_.size(); + for (int i = 0; i < rank_kernel; ++i) { + kernel_size_.push_back(kernel_sizes_[i]); + } + + bool adaptive = op->attribute("adaptive").data(); + bool ceil_mode = op->attribute("ceil_mode").data(); + bool global_pooling = + op->attribute("global_pooling").data(); + + int rank_x = x_shape.size(); + int rank = kernel_size_.size(); + + if (global_pooling) { + kernel_size_.resize(rank_x - 2); + for (int i = 0; i < rank; ++i) { + paddings_[i] = 0; + kernel_size_[i] = x_shape[i + 2]; + } + } + + PADDLE_ENFORCE_EQ( + x_shape.size() - kernel_size_.size(), + 2U, + phi::errors::InvalidArgument( + "The input size %d minus the kernel size %d should equal to 2.", + x_shape.size(), + kernel_size_.size())); + + std::vector out_shape = {x_shape[0], x_shape[1]}; + + if (adaptive) { + out_shape.insert(out_shape.end(), kernel_size_.begin(), kernel_size_.end()); + } else { + for (int i = 0; i < rank; ++i) { + PADDLE_ENFORCE_NE( + strides[i], + 0, + phi::errors::InvalidArgument( + "The stride of MaxPool shall not be 0, but received %d.", + strides[i])); + if (ceil_mode) { + out_shape.push_back( + symbol::DimExpr((x_shape[i + 2] - kernel_size_[i] + + 2 * paddings_[i] + strides[i] - 1) / + strides[i] + + 1)); + } else { + out_shape.push_back(symbol::DimExpr( + (x_shape[i + 2] - kernel_size_[i] + 2 * paddings_[i]) / strides[i] + + 1)); + } + } + } + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + +bool MaxPool2dWithIndexOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + x_shape.size(), + 4, + phi::errors::InvalidArgument("Pooling intput should be 4-D Tensor" + "but received %dD-Tensor", + x_shape.size())); + + std::vector paddings_ = + paddle::dialect::details::GetVectorAttr(op, "paddings"); + std::vector strides = + paddle::dialect::details::GetVectorAttr(op, "strides"); + + PADDLE_ENFORCE_EQ( + paddings_.size(), + 2, + phi::errors::InvalidArgument( + "It is expected paddings_size equals to 2, but got size %d", + paddings_.size())); + PADDLE_ENFORCE_EQ( + strides.size(), + 2, + phi::errors::InvalidArgument( + "It is expected strides_size equals to 2, but got size %d", + strides.size())); + + return MaxPoolWithIndexOpInferSymbolicShape(op, infer_context); +} + +bool MaxPool3dWithIndexOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + x_shape.size(), + 5, + phi::errors::InvalidArgument("Pooling intput should be 5-D Tensor" + "but received %dD-Tensor", + x_shape.size())); + + std::vector paddings_ = + paddle::dialect::details::GetVectorAttr(op, "paddings"); + std::vector strides = + paddle::dialect::details::GetVectorAttr(op, "strides"); + + PADDLE_ENFORCE_EQ( + paddings_.size(), + 3, + phi::errors::InvalidArgument( + "It is expected paddings_size equals to 3, but got size %d", + paddings_.size())); + PADDLE_ENFORCE_EQ( + strides.size(), + 3, + phi::errors::InvalidArgument( + "It is expected strides_size equals to 3, but got size %d", + strides.size())); + + return MaxPoolWithIndexOpInferSymbolicShape(op, infer_context); +} + bool MeanAllOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 073f5b0049eb4..4bbd53bb7927d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -88,8 +88,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRank) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2DWithIndex) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithIndex) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3dWithIndex) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 6fd6bcc83a7be..871d5b9fc1cb9 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3120,6 +3120,7 @@ kernel : func : max_pool2d_with_index backward : max_pool2d_with_index_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : max_pool3d_with_index args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false) @@ -3129,6 +3130,7 @@ kernel : func : max_pool3d_with_index backward : max_pool3d_with_index_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : maxout args : (Tensor x, int groups, int axis = 1)