Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【BUAA】【Infer Symbolic Shape No.3,46,107】Add as_strided, fold and yolo_loss #67037

Merged
merged 9 commits into from
Aug 15, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,79 @@ bool Where_OpInferSymbolicShape(pir::Operation *op,
return WhereOpInferSymbolicShape(op, infer_context);
}

bool YoloLossOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  // Infers symbolic shapes for yolo_loss.
  // Inputs:  operand(0) X [N, C, H, W], operand(1) GTBox [B, M, 4],
  //          operand(2) GTLabel [B, M], operand(3) GTScore [B, M] (optional).
  // Outputs: result(0) loss [N], result(1) objectness_mask
  //          [N, mask_num, H, W], result(2) gt_match_mask [B, M].
  const auto &x_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
  const auto &box_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();
  const auto &label_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(2)).shape();
  const std::vector<int> &anchors_mask =
      paddle::dialect::details::GetVectorAttr<int>(op, "anchor_mask");
  int mask_num = anchors_mask.size();
  int class_num = op->attribute<pir::Int32Attribute>("class_num").data();

  PADDLE_ENFORCE_EQ(x_shape.size(),
                    4,
                    phi::errors::InvalidArgument(
                        "Input(X) should be a 4-D tensor. But received "
                        "X dimension size(%s)",
                        x_shape.size()));
  PADDLE_ENFORCE_EQ(
      box_shape.size(),
      3,
      phi::errors::InvalidArgument("Input(GTBox) should be a 3-D tensor, but "
                                   "received gtbox dimension size(%s)",
                                   box_shape.size()));
  PADDLE_ENFORCE_EQ(label_shape.size(),
                    2,
                    phi::errors::InvalidArgument(
                        "Input(GTLabel) should be a 2-D tensor. "
                        "But received Input(GTLabel) dimension size(%s) != 2.",
                        label_shape.size()));
  // GTBox stores [x, y, w, h] per box; the channel dim of X must carry
  // mask_num * (5 + class_num) predictions and the feature map is square.
  infer_context->AddEqualCstr(box_shape[2], symbol::DimExpr(4));
  infer_context->AddEqualCstr(x_shape[2], x_shape[3]);
  infer_context->AddEqualCstr(x_shape[1],
                              symbol::DimExpr(mask_num * (5 + class_num)));
  infer_context->AddEqualCstr(label_shape[0], box_shape[0]);
  infer_context->AddEqualCstr(label_shape[1], box_shape[1]);

  if (op->operand_source(3) != nullptr) {
    const auto &score_shape =
        infer_context->GetShapeOrDataForValue(op->operand_source(3)).shape();
    PADDLE_ENFORCE_EQ(
        score_shape.size(),
        2,
        phi::errors::InvalidArgument("Input(GTScore) should be a 2-D tensor. "
                                     "But received GTScore dimension(%s)",
                                     // Report the score rank, not the box rank.
                                     score_shape.size()));
    infer_context->AddEqualCstr(score_shape[0], box_shape[0]);
    infer_context->AddEqualCstr(score_shape[1], box_shape[1]);
  }

  std::vector<symbol::DimExpr> out_shape = {x_shape[0]};
  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(out_shape)});

  std::vector<symbol::DimExpr> obj_mask_shape = {
      x_shape[0], symbol::DimExpr(mask_num), x_shape[2], x_shape[3]};
  infer_context->SetShapeOrDataForValue(
      op->result(1),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(obj_mask_shape)});

  std::vector<symbol::DimExpr> match_mask_shape = {box_shape[0], box_shape[1]};
  infer_context->SetShapeOrDataForValue(
      op->result(2),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(match_mask_shape)});

  return true;
}

bool FakeChannelWiseDequantizeMaxAbsOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HsigmoidLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeChannelWiseDequantizeMaxAbs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UpdateLossScaling_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBoxPost)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,30 @@ bool Assign_OpInferSymbolicShape(
return AssignOpInferSymbolicShape(op, infer_context);
}

bool AsStridedOpInferSymbolicShape(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改一下AsStride的InferMeta接口

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const std::vector<int> &shape =
paddle::dialect::details::GetVectorAttr<int>(op, "dims");

int rank = shape.size();
std::vector<symbol::DimExpr> out_shape;
for (int i = 0; i < rank; ++i) {
symbol::DimExpr out_unknown = infer_context->GetNextSymName();
if (shape[i] == -1) {
out_shape.push_back(out_unknown);
} else {
out_shape.push_back(symbol::DimExpr(shape[i]));
}
}
Comment on lines +398 to +405
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

symbol::DimExpr out_unknown = infer_context->GetNextSymName();最好放在if分支里,虽然这些新符号都没用到但是每个循环里符号计数都会增加


infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}

bool BipartiteMatchOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &dist_mat_shape_or_data =
Expand Down Expand Up @@ -1060,6 +1084,83 @@ bool Flatten_OpInferSymbolicShape(
return FlattenOpInferSymbolicShape(op, infer_context);
}

bool FoldOpInferSymbolicShape(pir::Operation *op,
                              pir::InferSymbolicShapeContext *infer_context) {
  // Infers the symbolic output shape of fold (col2im).
  // Input X is [N, C_in, L]; output is
  // [N, C_in / (kernel_h * kernel_w), output_h, output_w].
  const auto &x_shape =
      infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();

  std::vector<symbol::DimExpr> out_shape;
  out_shape.push_back(x_shape[0]);

  const std::vector<int> &output_sizes =
      paddle::dialect::details::GetVectorAttr<int>(op, "output_sizes");
  PADDLE_ENFORCE_EQ(
      output_sizes.size(),
      2,
      common::errors::InvalidArgument(
          "It is expected output_size equals to 2, but got size %d",
          output_sizes.size()));
  infer_context->AddGreatThanOneCstr(symbol::DimExpr(output_sizes[0]));
  infer_context->AddGreatThanOneCstr(symbol::DimExpr(output_sizes[1]));

  const std::vector<int> &kernel_sizes =
      paddle::dialect::details::GetVectorAttr<int>(op, "kernel_sizes");
  const std::vector<int> &dilations =
      paddle::dialect::details::GetVectorAttr<int>(op, "dilations");
  const std::vector<int> &strides =
      paddle::dialect::details::GetVectorAttr<int>(op, "strides");
  const std::vector<int> &paddings =
      paddle::dialect::details::GetVectorAttr<int>(op, "paddings");

  PADDLE_ENFORCE_EQ(
      kernel_sizes.size(),
      2,
      common::errors::InvalidArgument(
          "It is expected kernel_size equals to 2, but got size %d",
          kernel_sizes.size()));
  PADDLE_ENFORCE_EQ(
      strides.size(),
      2,
      common::errors::InvalidArgument(
          "It is expected strides_size equals to 2, but got size %d",
          strides.size()));
  PADDLE_ENFORCE_EQ(
      paddings.size(),
      4,
      common::errors::InvalidArgument(
          "It is expected paddings_size equals to 4, but got size %d",
          paddings.size()));
  PADDLE_ENFORCE_EQ(
      dilations.size(),
      2,
      common::errors::InvalidArgument(
          "It is expected dilations_size equals to 2, but got size %d",
          dilations.size()));

  // Number of sliding blocks along each spatial axis (standard conv-style
  // arithmetic); their product must equal L, the last dim of X.
  int blocks_height = (output_sizes[0] + 2 * paddings[0] -
                       (dilations[0] * (kernel_sizes[0] - 1) + 1)) /
                          strides[0] +
                      1;
  int blocks_width = (output_sizes[1] + 2 * paddings[1] -
                      (dilations[1] * (kernel_sizes[1] - 1) + 1)) /
                         strides[1] +
                     1;

  infer_context->AddEqualCstr(symbol::DimExpr(blocks_height * blocks_width),
                              x_shape[2]);

  // Output channels: C_in must be divisible by the kernel area. Mirror the
  // divisibility requirement of FoldInferMeta as a symbolic constraint
  // (C_in == out_channels * kernel_area) so downstream passes can rely on it.
  const symbol::DimExpr kernel_area(kernel_sizes[0] * kernel_sizes[1]);
  const symbol::DimExpr out_channels = x_shape[1] / kernel_area;
  infer_context->AddEqualCstr(out_channels * kernel_area, x_shape[1]);
  out_shape.push_back(out_channels);

  out_shape.push_back(symbol::DimExpr(output_sizes[0]));
  out_shape.push_back(symbol::DimExpr(output_sizes[1]));

  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{
          symbol::TensorShapeOrDataDimExprs(out_shape)});

  return true;
}

bool IdentityLossOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsComplex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsReal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsStrided)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(AllReduce)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(AllReduce_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Barrier)
Expand Down Expand Up @@ -67,6 +68,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fold)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeAbsMax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Inverse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GumbelSoftmax)
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@
func : as_strided
backward : as_strided_grad
no_need_buffer : input
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : asgd_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor d, Tensor y, Tensor n, Tensor master_param, bool multi_precision=false)
Expand Down Expand Up @@ -1959,6 +1960,7 @@
kernel:
func: fold
backward: fold_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fractional_max_pool2d
args : (Tensor x, int[] output_size, int[] kernel_size = {0, 0}, float random_u = 0.0, bool return_mask = true)
Expand Down Expand Up @@ -5085,6 +5087,7 @@
optional : gt_score
intermediate : objectness_mask, gt_match_mask
backward : yolo_loss_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : zeros
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
Expand Down