Skip to content

Commit

Permalink
[CINN]add more InferSymbolicShape (#66732)
Browse files Browse the repository at this point in the history
* [CINN]add more OpInferSymbolicShape

* fix warning

* fix

* fix data_format
  • Loading branch information
Hongqing-work authored Aug 1, 2024
1 parent be7e429 commit 1a1404f
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,109 @@ bool AucOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool BatchNormOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &scale_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));
const auto &bias_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(4));

std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();

std::string data_layout_str =
op->attribute<pir::StrAttribute>("data_format").AsString();
const DataLayout data_layout = common::StringToDataLayout(data_layout_str);

PADDLE_ENFORCE_GE(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: the dimension of input "
"X must greater than or equal to 2. But received: the shape of input "
"X = [%s], the dimension of input X =[%d]",
x_dims,
x_dims.size()));
PADDLE_ENFORCE_LE(
x_dims.size(),
5,
phi::errors::InvalidArgument(
"ShapeError: the dimension of input X "
"must smaller than or equal to 5. But received: the shape of input X "
"= [%s], the dimension of input X = [%d]",
x_dims,
x_dims.size()));

symbol::DimExpr C = (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1];

if (!scale_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
std::vector<symbol::DimExpr> scale_dims = scale_shape_or_data.shape();
PADDLE_ENFORCE_EQ(scale_dims.size(),
1UL,
phi::errors::InvalidArgument(
"ShapeError: the dimension of scale must equal to 1."
"But received: the dimension of scale is [%d]",
scale_dims.size()));
infer_context->AddEqualCstr(scale_dims[0], C);
}

if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
std::vector<symbol::DimExpr> bias_dims = bias_shape_or_data.shape();
PADDLE_ENFORCE_EQ(bias_dims.size(),
1UL,
phi::errors::InvalidArgument(
"ShapeError: the dimension of bias must equal to 1."
"But received: the dimension of bias is [%d]",
bias_dims.size()));
infer_context->AddEqualCstr(bias_dims[0], C);
}

// Set output shapes
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)});

std::vector<symbol::DimExpr> param_dims = {C};
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(param_dims)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(param_dims)});

if (op->result(3) && op->result(3).type()) {
infer_context->SetShapeOrDataForValue(
op->result(3),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(param_dims)});
}
if (op->result(4) && op->result(4).type()) {
infer_context->SetShapeOrDataForValue(
op->result(4),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(param_dims)});
}
if (op->result(5) && op->result(5).type()) {
std::vector<symbol::DimExpr> reserve_space_dims = {symbol::DimExpr{-1}};
infer_context->SetShapeOrDataForValue(
op->result(5),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(reserve_space_dims)});
}

return true;
}

bool BatchNorm_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return BatchNormOpInferSymbolicShape(op, infer_context);
}

bool BicubicInterpOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x =
Expand Down Expand Up @@ -742,6 +845,55 @@ bool GroupNormOpInferSymbolicShape(
return true;
}

bool LayerNormOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// Get the shapes of input tensors
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &scale_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &bias_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));

std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape();
int begin_norm_axis =
op->attribute<pir::Int32Attribute>("begin_norm_axis").data();

// Flatten x_dims to 2D and get dim[1]
symbol::DimExpr matrix_dim_1 = x_dims[begin_norm_axis];
for (std::size_t i = begin_norm_axis + 1; i < x_dims.size(); ++i) {
matrix_dim_1 = matrix_dim_1 * x_dims[i];
}

if (!scale_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
std::vector<symbol::DimExpr> scale_dims = scale_shape_or_data.shape();
infer_context->AddEqualCstr(scale_dims[0], matrix_dim_1);
}
if (!bias_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
std::vector<symbol::DimExpr> bias_dims = bias_shape_or_data.shape();
infer_context->AddEqualCstr(bias_dims[0], matrix_dim_1);
}

// Set output shapes
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)});

// Set mean and variance shapes
std::vector<symbol::DimExpr> before_norm_dims(
x_dims.begin(), x_dims.begin() + begin_norm_axis);
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(before_norm_dims)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(before_norm_dims)});

return true;
}

bool LinspaceOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &num_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
Expand All @@ -32,6 +34,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LayerNorm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ OP_SAME_OPERANDS_AND_RESULT(LogicalNot)
OP_SAME_OPERANDS_AND_RESULT(LogicalNot_)
OP_SAME_OPERANDS_AND_RESULT(Logit)
OP_SAME_OPERANDS_AND_RESULT(Logit_)
OP_SAME_OPERANDS_AND_RESULT(Logsigmoid)
OP_SAME_OPERANDS_AND_RESULT(Logsigmoid_)
OP_SAME_OPERANDS_AND_RESULT(Pow)
OP_SAME_OPERANDS_AND_RESULT(Poisson)
OP_SAME_OPERANDS_AND_RESULT(Pow_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsigmoid_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_)
Expand Down
Loading

0 comments on commit 1a1404f

Please sign in to comment.