Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape BUAA No13、66、91、122】Add 4 ops #66877

Merged
merged 26 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5976307
Update unary_infer_sym.h
Guanhuachen2003 Jul 31, 2024
d6fc344
Update unary_infer_sym.cc
Guanhuachen2003 Jul 31, 2024
b132f90
Update ops.yaml
Guanhuachen2003 Jul 31, 2024
c813a54
Update unary_infer_sym.cc
Guanhuachen2003 Jul 31, 2024
ec646ce
Update multiary_infer_sym.cc
Guanhuachen2003 Jul 31, 2024
00b4875
Update multiary_infer_sym.h
Guanhuachen2003 Jul 31, 2024
9c74456
Update ops.yaml
Guanhuachen2003 Jul 31, 2024
edfb8c0
Update unary_infer_sym.h
Guanhuachen2003 Jul 31, 2024
ddfde35
Update unary_infer_sym.cc
Guanhuachen2003 Jul 31, 2024
6bbd633
Update ops.yaml
Guanhuachen2003 Jul 31, 2024
aaf8ebe
Update unary_infer_sym.cc
Guanhuachen2003 Aug 1, 2024
261d7a3
Merge branch 'develop' into cinn-rrelu
Guanhuachen2003 Aug 1, 2024
36266b5
Update multiary_infer_sym.cc
Guanhuachen2003 Aug 1, 2024
c469ecc
Update multiary_infer_sym.h
Guanhuachen2003 Aug 1, 2024
4a67064
Update static_ops.yaml
Guanhuachen2003 Aug 1, 2024
a67bf50
Merge branch 'PaddlePaddle:develop' into cinn-rrelu
Guanhuachen2003 Aug 4, 2024
cb2994c
Update multiary_infer_sym.cc
Guanhuachen2003 Aug 4, 2024
937c558
Merge branch 'PaddlePaddle:develop' into cinn-rrelu
Guanhuachen2003 Aug 4, 2024
8d39977
Update unary_infer_sym.cc
Guanhuachen2003 Aug 4, 2024
6e77d5e
Update multiary_infer_sym.cc
Guanhuachen2003 Aug 4, 2024
afb225a
Merge branch 'PaddlePaddle:develop' into cinn-rrelu
Guanhuachen2003 Aug 5, 2024
1972a11
Update multiary_infer_sym.cc
Guanhuachen2003 Aug 5, 2024
e7553a5
Merge branch 'develop' into cinn-rrelu
Guanhuachen2003 Aug 6, 2024
7b83438
Update check_finite_and_unscale using tensorlist
Guanhuachen2003 Aug 6, 2024
4cb1903
Merge branch 'develop' into cinn-rrelu
Guanhuachen2003 Aug 7, 2024
817594c
delete unnecessary paddle_enforce_eq
Guanhuachen2003 Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,36 @@ bool BilinearInterpOpInferSymbolicShape(
return BicubicInterpOpInferSymbolicShape(op, infer_context);
}

bool CheckFiniteAndUnscaleOpInferSymbolicShape(
Guanhuachen2003 marked this conversation as resolved.
Show resolved Hide resolved
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// Retrieve the shape information of the input list
pir::Value operand_source = op->operand_source(0);
const symbol::TensorListShapeOrDataDimExprs &xs_shapes =
infer_context->GetShapeOrDataForValue(operand_source)
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();

// Set the shapes for the output tensor list
symbol::TensorListShapeOrDataDimExprs outs_shapes;
outs_shapes.reserve(xs_shapes.size());
for (const auto &input_shape : xs_shapes) {
outs_shapes.emplace_back(input_shape.shape());
}
infer_context->SetShapeOrDataForValue(
op->result(0), symbol::ShapeOrDataDimExprs{outs_shapes});

// Set the shape for the found_infinite output
symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)});
infer_context->SetShapeOrDataForValue(
op->result(1), symbol::ShapeOrDataDimExprs{found_infinite_shape});

return true;
}

bool CheckFiniteAndUnscale_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return CheckFiniteAndUnscaleOpInferSymbolicShape(op, infer_context);
}

// bool CrfDecodingOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
Expand Down Expand Up @@ -1244,6 +1274,31 @@ bool MeshgridOpInferSymbolicShape(
return true;
}

bool MovingAverageAbsMaxScaleOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// Get shapes of input tensors
const auto &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &in_state_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const auto &in_accum_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

// Set shapes for output tensors
infer_context->SetShapeOrDataForValue(op->result(0), x_shape);
infer_context->SetShapeOrDataForValue(
op->result(1), symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(1)}));
infer_context->SetShapeOrDataForValue(op->result(2), in_state_shape);
infer_context->SetShapeOrDataForValue(op->result(3), in_accum_shape);

return true;
}

bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return MovingAverageAbsMaxScaleOpInferSymbolicShape(op, infer_context);
}

// bool NceOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// // pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrfDecoding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
Expand Down Expand Up @@ -62,6 +64,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MulticlassNMS3)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,22 @@ bool KthvalueOpInferSymbolicShape(
return true;
}

bool L1NormOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
// The output is a scalar, set the output shape accordingly
std::vector<symbol::DimExpr> output_shape;
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)});
return true;
}

bool L1Norm_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return L1NormOpInferSymbolicShape(op, infer_context);
}

bool InverseOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
Expand Down Expand Up @@ -1501,6 +1517,52 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool RreluOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
float lower = op->attribute<pir::FloatAttribute>("lower").data();
float upper = op->attribute<pir::FloatAttribute>("upper").data();
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &x_shape = x_shape_or_data.shape();

// Check constraints for the attributes lower and upper
PADDLE_ENFORCE_GE(lower,
0,
phi::errors::InvalidArgument(
"The lower value should be greater than or equal to 0. "
"But received lower value = %f.",
lower));
PADDLE_ENFORCE_LE(upper,
1,
phi::errors::InvalidArgument(
"The upper value should be less than or equal to 1. "
"But received upper value = %f.",
upper));
PADDLE_ENFORCE_GE(
upper,
lower,
phi::errors::InvalidArgument(
"The upper value should be greater than or equal to lower value. "
"But received upper value = %f, lower value = %f.",
upper,
lower));

// Set the shape for the output tensor out
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});

// Set the shape for the output tensor noise if it exists
if (op->num_results() > 1 && op->result(1) != nullptr) {
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(x_shape)});
}

return true;
}

bool ShapeSrOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return ShapeOpInferSymbolicShape(op, infer_context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Inverse)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
Expand All @@ -92,6 +94,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rrelu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,7 @@
param: [x, in_accum, in_state, moving_rate, is_test]
optional : in_accum, in_state, out, out_state, out_accum
inplace : (in_accum -> out_accum), (in_state -> out_state)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op: nce
args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@
param : [x, scale]
data_type : x
inplace : (x -> out)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : check_numerics
args : (Tensor tensor, str op_type = "", str var_name = "", int check_nan_inf_level = 0, int stack_height_limit = -1, str output_dir = "")
Expand Down Expand Up @@ -2642,6 +2643,7 @@
data_type : x
inplace: (x -> out)
backward : l1_norm_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : label_smooth
args : (Tensor label, Tensor prior_dist, float epsilon = 0.0f)
Expand Down Expand Up @@ -3945,6 +3947,7 @@
data_type : x
intermediate : noise
backward : rrelu_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : rsqrt
args : (Tensor x)
Expand Down