[CINN] Add infer_symbol_shape for some ops #65880

Merged: 4 commits, Jul 24, 2024
@@ -88,6 +88,37 @@ bool AllcloseOpInferSymbolicShape(
return true;
}

bool BceLossOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &label_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

int rank = input_shape.shape().size();
PADDLE_ENFORCE_EQ(rank,
label_shape.shape().size(),
common::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
label_shape.shape().size()));

for (int i = 0; i < rank; ++i) {
infer_context->AddEqualCstr(input_shape.shape()[i], label_shape.shape()[i]);
}

infer_context->SetShapeOrDataForValue(op->result(0), input_shape);

return true;
}

bool BceLoss_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return BceLossOpInferSymbolicShape(op, infer_context);
}
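
// Editorial note, not part of the diff: a minimal concrete-shape sketch of the
// rule the function above encodes, using plain std::vector<int64_t> instead of
// the symbolic DimExpr API. The helper name is hypothetical.
//
// #include <cassert>
// #include <cstdint>
// #include <vector>
//
// // Input and label must agree rank-by-rank and dim-by-dim; the output
// // reuses the input shape. Example: BceLossOutShape({8, 4}, {8, 4}) == {8, 4}.
// std::vector<int64_t> BceLossOutShape(const std::vector<int64_t> &input,
//                                      const std::vector<int64_t> &label) {
//   assert(input.size() == label.size() && "input and label must share a rank");
//   for (size_t i = 0; i < input.size(); ++i) {
//     assert(input[i] == label[i] && "every dimension must match");
//   }
//   return input;
// }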

bool Conv2dOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const std::vector<int> strides =
@@ -19,6 +19,8 @@
namespace paddle::dialect {

OP_DECLARE_INFER_SYMBOLIC_SHAPE(Allclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding)
@@ -21,6 +21,47 @@

namespace paddle::dialect {

bool AccuracyOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &out_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const symbol::ShapeOrDataDimExprs &label_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(2));

// Assume indices has the same shape as the inference input, because
// it is the output of topk.
PADDLE_ENFORCE_EQ(
label_shape.shape().size(),
2UL,
common::errors::InvalidArgument(
"ShapeError: label's dimensions of AccuracyOp must be 2. "
"But received label's dimensions = %d",
label_shape.shape().size()));

infer_context->AddEqualCstr(label_shape.shape()[1], symbol::DimExpr{1});
infer_context->AddEqualCstr(out_shape.shape()[0], label_shape.shape()[0]);

std::vector<symbol::DimExpr> accuracy_shape = {};
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(accuracy_shape)});

std::vector<symbol::DimExpr> correct_shape = {};
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(correct_shape)});

std::vector<symbol::DimExpr> total_shape = {};
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(total_shape)});

return true;
}
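
// Editorial note, not part of the diff: a concrete-shape sketch of the
// constraints above (hypothetical helper, plain C++): label must be 2-D with a
// trailing dimension of 1, its row count must match the inference input, and
// the op's three outputs (accuracy, correct, total) are all 0-D scalars.
//
// #include <cassert>
// #include <cstdint>
// #include <vector>
//
// void CheckAccuracyShapes(const std::vector<int64_t> &inference,
//                          const std::vector<int64_t> &label) {
//   assert(label.size() == 2 && "label must be a 2-D tensor");
//   assert(label[1] == 1 && "label's second dimension must be 1");
//   assert(!inference.empty() && inference[0] == label[0] &&
//          "inference and label must have the same number of rows");
//   // accuracy, correct and total all get an empty (0-D) shape.
// }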

bool AddNOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &input_list_shape =
@@ -60,6 +101,133 @@ bool AddNOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool AddmmOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &y_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(2));

auto ndim_input = input_shape.shape().size();
auto ndim_x = x_shape.shape().size();
auto ndim_y = y_shape.shape().size();

PADDLE_ENFORCE_EQ(ndim_input == 2 || ndim_input == 1,
true,
common::errors::InvalidArgument(
"The input tensor input's dimension must be 2 or 1. "
"But received input's dimension = [%d].",
ndim_input));
PADDLE_ENFORCE_EQ(ndim_x,
2,
common::errors::InvalidArgument(
"The input tensor x's dimension must be 2. "
"But received x's dimension = [%d].",
ndim_x));
PADDLE_ENFORCE_EQ(ndim_y,
2,
common::errors::InvalidArgument(
"The input tensor y's dimension must be 2. "
"But received y's dimension = [%d].",
ndim_y));

std::vector<symbol::DimExpr> output_shape;
output_shape.push_back(x_shape.shape()[0]);
output_shape.push_back(y_shape.shape()[1]);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)});

infer_context->AddEqualCstr(x_shape.shape()[1], y_shape.shape()[0]);

if (ndim_input == 2) {
infer_context->AddBroadcastableCstr(input_shape.shape()[0],
x_shape.shape()[0]);
infer_context->AddBroadcastableCstr(input_shape.shape()[1],
y_shape.shape()[1]);
} else if (ndim_input == 1) {
infer_context->AddBroadcastableCstr(input_shape.shape()[0],
y_shape.shape()[1]);
}

return true;
}

bool Addmm_OpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
return AddmmOpInferSymbolicShape(op, infer_context);
}
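
// Editorial note, not part of the diff: a concrete-shape sketch of the addmm
// rule above (hypothetical helper, plain C++): x is [M, K], y is [K, N], input
// must be 1-D or 2-D and broadcastable to [M, N], and the result is [M, N].
//
// #include <cassert>
// #include <cstdint>
// #include <vector>
//
// std::vector<int64_t> AddmmOutShape(const std::vector<int64_t> &input,
//                                    const std::vector<int64_t> &x,
//                                    const std::vector<int64_t> &y) {
//   assert(x.size() == 2 && y.size() == 2);
//   assert(input.size() == 1 || input.size() == 2);
//   assert(x[1] == y[0] && "inner dimensions of x and y must match");
//   auto broadcastable = [](int64_t a, int64_t b) {
//     return a == b || a == 1 || b == 1;
//   };
//   if (input.size() == 2) {
//     assert(broadcastable(input[0], x[0]) && broadcastable(input[1], y[1]));
//   } else {
//     assert(broadcastable(input[0], y[1]));
//   }
//   return {x[0], y[1]};  // out is [M, N]
// }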

bool AucOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &predict_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &label_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

PADDLE_ENFORCE_GE(
predict_shape.shape().size(),
2,
common::errors::InvalidArgument(
"The Input(Predict) has not been initialized properly. The "
"shape of Input(Predict) = [%s], the shape size must be "
"greater_equal 2.",
predict_shape.shape()));

const auto &predict_height = predict_shape.shape()[0];
const auto &label_height = label_shape.shape()[0];

infer_context->AddEqualCstr(predict_height, label_height);

int num_thresholds =
op->attribute<pir::Int32Attribute>("num_thresholds").data();
int slide_steps = op->attribute<pir::Int32Attribute>("slide_steps").data();

int num_pred_buckets = num_thresholds + 1;

PADDLE_ENFORCE_GE(
num_pred_buckets,
1,
common::errors::InvalidArgument("num_thresholds must larger than 1"));
PADDLE_ENFORCE_GE(
slide_steps,
0,
common::errors::InvalidArgument("slide_steps must be natural number"));

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{})});

if (slide_steps) {
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{
(1 + slide_steps) * num_pred_buckets + 1})});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{
(1 + slide_steps) * num_pred_buckets + 1})});
} else {
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
std::vector<symbol::DimExpr>{1, num_pred_buckets})});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
std::vector<symbol::DimExpr>{1, num_pred_buckets})});
}

return true;
}
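
// Editorial note, not part of the diff: the stat output sizes above follow a
// small formula. A concrete-shape sketch (hypothetical helper, plain C++) for
// stat_pos / stat_neg (result(1) and result(2)); result(0), the AUC value, is 0-D.
//
// #include <cstdint>
// #include <vector>
//
// std::vector<int64_t> AucStatShape(int num_thresholds, int slide_steps) {
//   const int64_t num_pred_buckets = num_thresholds + 1;
//   if (slide_steps > 0) {
//     // (1 + slide_steps) blocks of buckets, plus one extra slot.
//     return {(1 + slide_steps) * num_pred_buckets + 1};
//   }
//   return {1, num_pred_buckets};
// }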

bool BicubicInterpOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x =
@@ -287,6 +455,55 @@ bool BicubicInterpOpInferSymbolicShape(
return true;
}

bool BilinearOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &y_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &weight_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(2));

PADDLE_ENFORCE_EQ(
x_shape.shape().size(),
2UL,
common::errors::InvalidArgument("The input(X) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
y_shape.shape().size(),
2UL,
common::errors::InvalidArgument("The input(Y) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_shape.shape().size(),
3UL,
common::errors::InvalidArgument(
"Expected the input(Weight) is a 3D tensor. But received %dD tensor.",
weight_shape.shape().size()));

infer_context->AddEqualCstr(x_shape.shape()[0], y_shape.shape()[0]);

infer_context->AddEqualCstr(x_shape.shape()[1], weight_shape.shape()[1]);
infer_context->AddEqualCstr(y_shape.shape()[1], weight_shape.shape()[2]);

if (op->operand_source(3)) { // has bias
const auto &bias_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(3));
PADDLE_ENFORCE_EQ(bias_shape.shape().size(),
2UL,
common::errors::InvalidArgument(
"The Input(Bias) must be a 2-D tensor with "
"the 2nd dimension fixed to 1 (a row vector)."));
infer_context->AddEqualCstr(bias_shape.shape()[0], symbol::DimExpr{1});
infer_context->AddEqualCstr(bias_shape.shape()[1], weight_shape.shape()[0]);
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
{x_shape.shape()[0], weight_shape.shape()[0]})});

return true;
}
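
// Editorial note, not part of the diff: a concrete-shape sketch of the
// bilinear constraints above (hypothetical helper, plain C++): x is [B, M],
// y is [B, N], weight is [Out, M, N], the optional bias is [1, Out], and the
// result is [B, Out].
//
// #include <cassert>
// #include <cstdint>
// #include <vector>
//
// std::vector<int64_t> BilinearOutShape(const std::vector<int64_t> &x,
//                                       const std::vector<int64_t> &y,
//                                       const std::vector<int64_t> &weight,
//                                       const std::vector<int64_t> *bias) {
//   assert(x.size() == 2 && y.size() == 2 && weight.size() == 3);
//   assert(x[0] == y[0] && "batch sizes must match");
//   assert(x[1] == weight[1] && y[1] == weight[2]);
//   if (bias != nullptr) {  // bias is optional
//     assert(bias->size() == 2 && (*bias)[0] == 1 && (*bias)[1] == weight[0]);
//   }
//   return {x[0], weight[0]};  // out is [batch, weight's first dim]
// }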

bool BilinearInterpOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return BicubicInterpOpInferSymbolicShape(op, infer_context);
@@ -18,8 +18,13 @@

namespace paddle::dialect {

OP_DECLARE_INFER_SYMBOLIC_SHAPE(Accuracy)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
5 changes: 5 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -25,6 +25,7 @@
kernel :
func : accuracy
data_type : x
interfaces : paddle::dialect::InferSymbolicShapeInterface
traits : paddle::dialect::ForwardOnlyTrait

- op : accuracy_check
@@ -145,6 +146,7 @@
data_type : x
inplace: (input -> out)
backward : addmm_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : affine_channel
args: (Tensor x, Tensor scale, Tensor bias, str data_layout = "AnyLayout")
@@ -429,6 +431,7 @@
func : auc
data_type : x
optional : ins_tag_weight
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : average_accumulates_
args : (Tensor param, Tensor in_sum_1, Tensor in_sum_2, Tensor in_sum_3, Tensor in_num_accumulates, Tensor in_old_num_accumulates, Tensor in_num_updates, float average_window = 0, int64_t max_average_window = INT64_MAX, int64_t min_average_window = 10000L)
@@ -460,6 +463,7 @@
data_type : input
inplace : (input -> out)
backward : bce_loss_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : beam_search
args: (Tensor pre_ids, Tensor pre_scores, Tensor ids, Tensor scores, int level,
@@ -505,6 +509,7 @@
func : bilinear
optional : bias
backward : bilinear_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : bilinear_interp
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_format="NCHW", int out_d=0, int out_h=0, int out_w=0, float[] scale={}, str interp_method="bilinear", bool align_corners=true, int align_mode=1)