Skip to content

Commit

Permalink
[CINN] Add infer_symbolic_shape for group_norm and pad3d (#63879)
Browse files Browse the repository at this point in the history
* add infer_symbolic_shape of group_norm and pad3d

* modify yaml

* udpate yaml
  • Loading branch information
zyfncg authored Apr 26, 2024
1 parent 481f32d commit 3289d25
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,26 @@ bool FlashAttnOpInferSymbolicShape(
return true;
}

bool GroupNormOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const symbol::ShapeOrDataDimExprs &x_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));

shape_analysis->SetShapeOrDataForValue(op->result(0), x_shape);

const symbol::DimExpr &batch_size = x_shape.shape()[0];
int groups = op->attribute<pir::Int32Attribute>("groups").data();
symbol::TensorShapeOrDataDimExprs mean_shape(
std::vector<symbol::DimExpr>{batch_size, groups});
if (op->result(1)) {
shape_analysis->SetShapeOrDataForValue(op->result(1), mean_shape);
}
if (op->result(2)) {
shape_analysis->SetShapeOrDataForValue(op->result(2), mean_shape);
}
return true;
}

bool LinspaceOpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
const auto &num_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,58 @@ bool PadOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool Pad3dOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
const auto &x_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape();
PADDLE_ENFORCE_EQ(x_shape.size(),
5,
common::errors::InvalidArgument(
"The size of Input(X)'s dimension should be equal to "
"5, but received %d. ",
x_shape.size()));
const auto &paddings_shape =
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
if (!paddings_shape.data().has_value()) {
std::stringstream ss;
ss << paddings_shape;
PADDLE_THROW(
common::errors::InvalidArgument("The data of paddings's symbol shape "
"should have value, but now got [%s].",
ss.str()));
}
const std::string &data_format =
op->attribute<pir::StrAttribute>("data_format").AsString();

const std::vector<symbol::DimExpr> &out_dims = [&] {
std::vector<symbol::DimExpr> out_dims = x_shape;
const auto &paddings = paddings_shape.data().value();
PADDLE_ENFORCE_EQ(paddings.size(),
6,
common::errors::InvalidArgument(
"Shape of Input(Paddings) should be equal to "
"[6], but received [%d].",
paddings.size()));
if (data_format == "NCDHW") {
out_dims[1] = x_shape[1];
out_dims[2] = x_shape[2] + paddings[4] + paddings[5];
out_dims[3] = x_shape[3] + paddings[2] + paddings[3];
out_dims[4] = x_shape[4] + paddings[0] + paddings[1];
} else {
out_dims[1] = x_shape[1] + paddings[4] + paddings[5];
out_dims[2] = x_shape[2] + paddings[2] + paddings[3];
out_dims[3] = x_shape[3] + paddings[0] + paddings[1];
out_dims[4] = x_shape[4];
}
return out_dims;
}();

shape_analysis->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs(out_dims));

return true;
}

bool ProdOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis) {
bool keepdim = GetBoolAttr(op, "keep_dim");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,7 @@
optional : scale, bias
intermediate : mean, variance
backward : group_norm_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : gumbel_softmax
args : (Tensor x, float temperature = 1.0, bool hard = false, int axis = -1)
Expand Down Expand Up @@ -2204,6 +2205,7 @@
kernel :
func : pad3d
backward : pad3d_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : pixel_shuffle
args : (Tensor x, int upscale_factor=1, str data_format="NCHW")
Expand Down

0 comments on commit 3289d25

Please sign in to comment.