Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【BUAA】【Infer Symbolic Shape No.152,153】Add max_pool2d_with_index and max_pool3d_with_index #67390

Merged
merged 10 commits into from
Aug 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,155 @@ bool MinOpInferSymbolicShape(pir::Operation *op,
return MaxOpInferSymbolicShape(op, infer_context);
}

bool MaxPoolWithIndexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();

std::vector<int> paddings_ =
paddle::dialect::details::GetVectorAttr<int>(op, "paddings");
std::vector<int> strides =
paddle::dialect::details::GetVectorAttr<int>(op, "strides");
std::vector<int> kernel_sizes_ =
paddle::dialect::details::GetVectorAttr<int>(op, "kernel_sizes");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kernel_size


std::vector<symbol::DimExpr> kernel_size_;
int rank_kernel = kernel_sizes_.size();
for (int i = 0; i < rank_kernel; ++i) {
kernel_size_.push_back(kernel_sizes_[i]);
}

bool adaptive = op->attribute<pir::BoolAttribute>("adaptive").data();
bool ceil_mode = op->attribute<pir::BoolAttribute>("ceil_mode").data();
bool global_pooling =
op->attribute<pir::BoolAttribute>("global_pooling").data();

int rank_x = x_shape.size();
int rank = kernel_size_.size();

if (global_pooling) {
kernel_size_.resize(rank_x - 2);
for (int i = 0; i < rank; ++i) {
paddings_[i] = 0;
kernel_size_[i] = x_shape[i + 2];
}
}

PADDLE_ENFORCE_EQ(
x_shape.size() - kernel_size_.size(),
2U,
phi::errors::InvalidArgument(
"The input size %d minus the kernel size %d should equal to 2.",
x_shape.size(),
kernel_size_.size()));

std::vector<symbol::DimExpr> out_shape = {x_shape[0], x_shape[1]};

if (adaptive) {
out_shape.insert(out_shape.end(), kernel_size_.begin(), kernel_size_.end());
} else {
for (int i = 0; i < rank; ++i) {
PADDLE_ENFORCE_NE(
strides[i],
0,
phi::errors::InvalidArgument(
"The stride of MaxPool shall not be 0, but received %d.",
strides[i]));
if (ceil_mode) {
out_shape.push_back(
symbol::DimExpr((x_shape[i + 2] - kernel_size_[i] +
2 * paddings_[i] + strides[i] - 1) /
strides[i] +
1));
} else {
out_shape.push_back(symbol::DimExpr(
(x_shape[i + 2] - kernel_size_[i] + 2 * paddings_[i]) / strides[i] +
1));
}
}
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}

bool MaxPool2dWithIndexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();

PADDLE_ENFORCE_EQ(
x_shape.size(),
4,
phi::errors::InvalidArgument("Pooling intput should be 4-D Tensor"
"but received %dD-Tensor",
x_shape.size()));

std::vector<int> paddings_ =
paddle::dialect::details::GetVectorAttr<int>(op, "paddings");
std::vector<int> strides =
paddle::dialect::details::GetVectorAttr<int>(op, "strides");

PADDLE_ENFORCE_EQ(
paddings_.size(),
2,
phi::errors::InvalidArgument(
"It is expected paddings_size equals to 2, but got size %d",
paddings_.size()));
PADDLE_ENFORCE_EQ(
strides.size(),
2,
phi::errors::InvalidArgument(
"It is expected strides_size equals to 2, but got size %d",
strides.size()));

return MaxPoolWithIndexOpInferSymbolicShape(op, infer_context);
}

bool MaxPool3dWithIndexOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Comment on lines +1532 to +1533
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

两个函数有很多相同逻辑,建议抽象点共同操作

const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();

PADDLE_ENFORCE_EQ(
x_shape.size(),
5,
phi::errors::InvalidArgument("Pooling intput should be 5-D Tensor"
"but received %dD-Tensor",
x_shape.size()));

std::vector<int> paddings_ =
paddle::dialect::details::GetVectorAttr<int>(op, "paddings");
std::vector<int> strides =
paddle::dialect::details::GetVectorAttr<int>(op, "strides");

PADDLE_ENFORCE_EQ(
paddings_.size(),
3,
phi::errors::InvalidArgument(
"It is expected paddings_size equals to 3, but got size %d",
paddings_.size()));
PADDLE_ENFORCE_EQ(
strides.size(),
3,
phi::errors::InvalidArgument(
"It is expected strides_size equals to 3, but got size %d",
strides.size()));

return MaxPoolWithIndexOpInferSymbolicShape(op, infer_context);
}

bool MeanAllOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixPower)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRank)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2DWithIndex)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3dWithIndex)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3120,6 +3120,7 @@
kernel :
func : max_pool2d_with_index
backward : max_pool2d_with_index_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : max_pool3d_with_index
args : (Tensor x, int[] kernel_size, int[] strides = {1, 1, 1}, int[] paddings = {0, 0, 0}, bool global_pooling = false, bool adaptive = false, bool ceil_mode = false)
Expand All @@ -3129,6 +3130,7 @@
kernel :
func : max_pool3d_with_index
backward : max_pool3d_with_index_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : maxout
args : (Tensor x, int groups, int axis = 1)
Expand Down