From 5976307c6be38db242a43566ac0596908bf926bf Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:14:38 +0800 Subject: [PATCH 01/20] Update unary_infer_sym.h --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 2488b579929b2..7c341530da6a3 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -58,6 +58,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rrelu) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) From d6fc344440a9fc4b80992f240639c339a22de985 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:18:28 +0800 Subject: [PATCH 02/20] Update unary_infer_sym.cc --- .../infer_symbolic_shape/unary_infer_sym.cc | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 8a44b8bade20f..8ac134da81214 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -928,6 +928,20 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, return true; } +bool RreluOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand(0)); + + infer_context->SetShapeOrDataForValue(op->result(0), x_shape_or_data); + + if (op->num_results() > 1) { + infer_context->SetShapeOrDataForValue(op->result(1), x_shape_or_data); + } + + return true; +} + bool ShapeSrOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return ShapeOpInferSymbolicShape(op, infer_context); From b132f90b1bb78639da3357f0597a8d7747582a29 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:19:06 +0800 Subject: [PATCH 03/20] Update ops.yaml --- paddle/phi/ops/yaml/ops.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 983635dfb09a8..bc87137b7a4bf 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3884,6 +3884,7 @@ data_type : x intermediate : noise backward : rrelu_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rsqrt args : (Tensor x) From c813a545fb0929312117e83aa7d8a163a327c1fb Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:21:38 +0800 Subject: [PATCH 04/20] Update unary_infer_sym.cc --- .../infer_symbolic_shape/unary_infer_sym.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 8ac134da81214..ac5dc8f9b0d2b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -928,15 +928,14 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, return true; } -bool RreluOpInferSymbolicShape(pir::Operation *op, - pir::InferSymbolicShapeContext *infer_context) { - const auto &x_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand(0)); +bool RreluOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); - infer_context->SetShapeOrDataForValue(op->result(0), x_shape_or_data); + infer_context->SetShapeOrDataForValue(op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); - if (op->num_results() > 1) { - infer_context->SetShapeOrDataForValue(op->result(1), x_shape_or_data); + if (op->num_results() > 1 && op->result(1) != nullptr) { + infer_context->SetShapeOrDataForValue(op->result(1), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); } return true; From ec646ce678e1fb7c56407a9e26da4d983fa60edd Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:51:58 +0800 Subject: [PATCH 05/20] Update multiary_infer_sym.cc --- .../multiary_infer_sym.cc | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index e399ead76e7e8..48eb7f3fffa9c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -509,6 +509,33 @@ bool BilinearInterpOpInferSymbolicShape( return BicubicInterpOpInferSymbolicShape(op, infer_context); } +bool CheckFiniteAndUnscaleOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + std::vector xs_shapes; + for (size_t i = 0; i < op->num_operands() - 1; ++i) { + xs_shapes.push_back( + infer_context->GetShapeOrDataForValue(op->operand_source(i))); + } + + for (size_t i = 0; i < xs_shapes.size(); ++i) { + symbol::TensorShapeOrDataDimExprs output_shape(xs_shapes[i].shape()); + infer_context->SetShapeOrDataForValue( + op->result(i), symbol::ShapeOrDataDimExprs{output_shape}); + } + + symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)}); + infer_context->SetShapeOrDataForValue( + op->result(op->num_results() - 1), + symbol::ShapeOrDataDimExprs{found_infinite_shape}); + + return true; +} + +bool CheckFiniteAndUnscale_OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return CheckFiniteAndUnscaleOpInferSymbolicShape(op, infer_context); +} + bool CrossEntropyWithSoftmaxOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &input_shape = From 00b4875fdbf561fd48b7cb5ea0d9361d7d0c3820 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:52:31 +0800 Subject: [PATCH 06/20] Update multiary_infer_sym.h --- .../interface/infer_symbolic_shape/multiary_infer_sym.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 2d0eab8d15c74..3f3ac2258dc13 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -26,6 +26,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_) From 9c744560ece8959af946517a0589ce240fd9710e Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:53:08 +0800 Subject: [PATCH 07/20] Update ops.yaml --- paddle/phi/ops/yaml/ops.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index bc87137b7a4bf..48ea0de075271 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -844,6 +844,7 @@ param : [x, scale] data_type : x inplace : (x -> out) + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : check_numerics args : (Tensor tensor, str op_type = "", str var_name = "", int check_nan_inf_level = 0, int stack_height_limit = -1, str output_dir = "") From edfb8c0601e4fd0f3f0c93250e095e494fcc3a65 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:56:02 +0800 Subject: [PATCH 08/20] Update unary_infer_sym.h --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 7c341530da6a3..1959089785d37 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -45,6 +45,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) From ddfde357c12eff45e90e77a1dc0712d4ee3cc992 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:59:34 +0800 Subject: [PATCH 09/20] Update unary_infer_sym.cc --- .../infer_symbolic_shape/unary_infer_sym.cc | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index ac5dc8f9b0d2b..4a4a643c8ab0b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -564,6 +564,21 @@ bool KthvalueOpInferSymbolicShape( return true; } +bool L1NormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + std::vector output_shape; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(output_shape)}); + return true; +} + +bool L1Norm_OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return L1Norm_OpInferSymbolicShape(op, infer_context); +} + bool LogcumsumexpOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // same as CumsumOpInferSymbolicShape @@ -928,14 +943,21 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, return true; } -bool RreluOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); +bool RreluOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); - infer_context->SetShapeOrDataForValue(op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); if (op->num_results() > 1 && op->result(1) != nullptr) { - infer_context->SetShapeOrDataForValue(op->result(1), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(x_shape)}); } return true; From 6bbd6332b812f24934c4ae63573e147b9142f5bd Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 31 Jul 2024 18:00:07 +0800 Subject: [PATCH 10/20] Update ops.yaml --- paddle/phi/ops/yaml/ops.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 48ea0de075271..77e81356c1324 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -2604,6 +2604,7 @@ data_type : x inplace: (x -> out) backward : l1_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : label_smooth args : (Tensor label, Tensor prior_dist, float epsilon = 0.0f) From aaf8ebe00cbeb5d9ea64e36ad89bcda435276ad6 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Thu, 1 Aug 2024 09:37:25 +0800 Subject: [PATCH 11/20] Update unary_infer_sym.cc --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 4a4a643c8ab0b..e39c057a5f6be 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -576,7 +576,7 @@ bool L1NormOpInferSymbolicShape(pir::Operation *op, bool L1Norm_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return L1Norm_OpInferSymbolicShape(op, infer_context); + return L1NormOpInferSymbolicShape(op, infer_context); } bool LogcumsumexpOpInferSymbolicShape( From 36266b5d1d2130bdec62f4af9cd1d3e32ea70d2c Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:41:01 +0800 Subject: [PATCH 12/20] Update multiary_infer_sym.cc --- .../multiary_infer_sym.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index d53b97f4f74b1..99d2c9f3bde2e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -1077,6 +1077,41 @@ bool MeshgridOpInferSymbolicShape( return true; } +bool MovingAverageAbsMaxScaleOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &x_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const symbol::ShapeOrDataDimExprs &in_state_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const symbol::ShapeOrDataDimExprs &in_accum_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + if (op->num_results() > 0 && op->result(0) != nullptr) { + infer_context->SetShapeOrDataForValue(op->result(0), x_shape); + } + + if (op->num_results() > 1 && op->result(1) != nullptr) { + symbol::TensorShapeOrDataDimExprs scalar_shape( + std::vector{symbol::DimExpr(1)}); + infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape); + } + + if (op->num_results() > 2 && op->result(2) != nullptr) { + infer_context->SetShapeOrDataForValue(op->result(2), in_state_shape); + } + + if (op->num_results() > 3 && op->result(3) != nullptr) { + infer_context->SetShapeOrDataForValue(op->result(3), in_accum_shape); + } + + return true; +} + +bool MovingAverageAbsMaxScale_OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return MovingAverageAbsMaxScaleOpInferSymbolicShape(op, infer_context); +} + bool StackOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); From c469ecc0b32bd92034f9dbfd0ec957118f58e3dc Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:41:48 +0800 Subject: [PATCH 13/20] Update multiary_infer_sym.h --- .../interface/infer_symbolic_shape/multiary_infer_sym.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 82e67739c9e5a..5013086e6cf46 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -42,6 +42,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack) From 4a67064c52a01eacb36617d5989fb2af665b68c8 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:42:48 +0800 Subject: [PATCH 14/20] Update static_ops.yaml --- paddle/phi/ops/yaml/inconsistent/static_ops.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 5afaa90709d2e..1088dcda3b9aa 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -960,6 +960,7 @@ param: [x, in_accum, in_state, moving_rate, is_test] optional : in_accum, in_state, out, out_state, out_accum inplace : (in_accum -> out_accum), (in_state -> out_state) + interfaces : paddle::dialect::InferSymbolicShapeInterface - op: nce args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) From cb2994c716ea443184d190cf2ff3158ebfe2989d Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Sun, 4 Aug 2024 22:28:39 +0800 Subject: [PATCH 15/20] Update multiary_infer_sym.cc --- .../multiary_infer_sym.cc | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 2bd8f5c5a54f4..6d53197e6714f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -615,22 +615,29 @@ bool BilinearInterpOpInferSymbolicShape( bool CheckFiniteAndUnscaleOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + // Retrieve shapes for all input tensors except the last one (scale) std::vector xs_shapes; for (size_t i = 0; i < op->num_operands() - 1; ++i) { xs_shapes.push_back( infer_context->GetShapeOrDataForValue(op->operand_source(i))); } + // Ensure the number of inputs (xs) matches the number of outputs (outs) + infer_context->AddEqualCstr( + symbol::DimExpr(static_cast(xs_shapes.size())), + symbol::DimExpr(static_cast(op->num_results() - 1))); + + // Set shapes for output tensors corresponding to each input tensor for (size_t i = 0; i < xs_shapes.size(); ++i) { - symbol::TensorShapeOrDataDimExprs output_shape(xs_shapes[i].shape()); infer_context->SetShapeOrDataForValue( - op->result(i), symbol::ShapeOrDataDimExprs{output_shape}); + op->result(i), symbol::ShapeOrDataDimExprs{xs_shapes[i].shape()}); } - symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)}); + // Set shape for the found_infinite output tensor infer_context->SetShapeOrDataForValue( op->result(op->num_results() - 1), - symbol::ShapeOrDataDimExprs{found_infinite_shape}); + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(1)})}); return true; } @@ -1080,27 +1087,34 @@ bool MeshgridOpInferSymbolicShape( bool MovingAverageAbsMaxScaleOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + // Get the symbolic shape of the input tensor x const symbol::ShapeOrDataDimExprs &x_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)); + // Get the symbolic shape of the input tensor in_state const symbol::ShapeOrDataDimExprs &in_state_shape = infer_context->GetShapeOrDataForValue(op->operand_source(2)); + // Get the symbolic shape of the input tensor in_accum const symbol::ShapeOrDataDimExprs &in_accum_shape = infer_context->GetShapeOrDataForValue(op->operand_source(1)); + // Set the shape for the output tensor out if it exists if (op->num_results() > 0 && op->result(0) != nullptr) { infer_context->SetShapeOrDataForValue(op->result(0), x_shape); } + // Set the shape for the output tensor out_scale as a scalar if it exists if (op->num_results() > 1 && op->result(1) != nullptr) { symbol::TensorShapeOrDataDimExprs scalar_shape( std::vector{symbol::DimExpr(1)}); infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape); } + // Set the shape for the output tensor out_state if it exists if (op->num_results() > 2 && op->result(2) != nullptr) { infer_context->SetShapeOrDataForValue(op->result(2), in_state_shape); } + // Set the shape for the output tensor out_accum if it exists if (op->num_results() > 3 && op->result(3) != nullptr) { infer_context->SetShapeOrDataForValue(op->result(3), in_accum_shape); } From 8d3997730f76ef0563f17f91605b0a1f65103f3c Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Sun, 4 Aug 2024 22:42:40 +0800 Subject: [PATCH 16/20] Update unary_infer_sym.cc --- .../infer_symbolic_shape/unary_infer_sym.cc | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index d0d35ba729f2a..9a4c3dd9e7af7 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -851,6 +851,7 @@ bool KthvalueOpInferSymbolicShape( bool L1NormOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + // The output is a scalar, set the output shape accordingly std::vector output_shape; infer_context->SetShapeOrDataForValue( op->result(0), @@ -1279,14 +1280,40 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, bool RreluOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + float lower = op->attribute("lower").data(); + float upper = op->attribute("upper").data(); const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); + // Check constraints for the attributes lower and upper + PADDLE_ENFORCE_GE(lower, + 0, + phi::errors::InvalidArgument( + "The lower value should be greater than or equal to 0. " + "But received lower value = %f.", + lower)); + PADDLE_ENFORCE_LE(upper, + 1, + phi::errors::InvalidArgument( + "The upper value should be less than or equal to 1. " + "But received upper value = %f.", + upper)); + PADDLE_ENFORCE_GE( + upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than or equal to lower value. " + "But received upper value = %f, lower value = %f.", + upper, + lower)); + + // Set the shape for the output tensor out infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); + // Set the shape for the output tensor noise if it exists if (op->num_results() > 1 && op->result(1) != nullptr) { infer_context->SetShapeOrDataForValue( op->result(1), From 6e77d5ef09708cd8171dd2a52b3b2c020b2fd5e6 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Sun, 4 Aug 2024 23:05:28 +0800 Subject: [PATCH 17/20] Update multiary_infer_sym.cc --- .../infer_symbolic_shape/multiary_infer_sym.cc | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 6d53197e6714f..e24f65d7c4eea 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -615,29 +615,25 @@ bool BilinearInterpOpInferSymbolicShape( bool CheckFiniteAndUnscaleOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - // Retrieve shapes for all input tensors except the last one (scale) std::vector xs_shapes; for (size_t i = 0; i < op->num_operands() - 1; ++i) { xs_shapes.push_back( infer_context->GetShapeOrDataForValue(op->operand_source(i))); } - // Ensure the number of inputs (xs) matches the number of outputs (outs) infer_context->AddEqualCstr( symbol::DimExpr(static_cast(xs_shapes.size())), symbol::DimExpr(static_cast(op->num_results() - 1))); - - // Set shapes for output tensors corresponding to each input tensor for (size_t i = 0; i < xs_shapes.size(); ++i) { + symbol::TensorShapeOrDataDimExprs output_shape(xs_shapes[i].shape()); infer_context->SetShapeOrDataForValue( - op->result(i), symbol::ShapeOrDataDimExprs{xs_shapes[i].shape()}); + op->result(i), symbol::ShapeOrDataDimExprs{output_shape}); } - // Set shape for the found_infinite output tensor + symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)}); infer_context->SetShapeOrDataForValue( op->result(op->num_results() - 1), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(1)})}); + symbol::ShapeOrDataDimExprs{found_infinite_shape}); return true; } From 1972a1144759edbf7fd37cd73d014d49482255f9 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Tue, 6 Aug 2024 00:27:11 +0800 Subject: [PATCH 18/20] Update multiary_infer_sym.cc --- .../multiary_infer_sym.cc | 63 ++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 5866e53f15103..e5fccd29430c8 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -615,21 +615,29 @@ bool BilinearInterpOpInferSymbolicShape( bool CheckFiniteAndUnscaleOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - std::vector xs_shapes; - for (size_t i = 0; i < op->num_operands() - 1; ++i) { - xs_shapes.push_back( - infer_context->GetShapeOrDataForValue(op->operand_source(i))); - } - // Ensure the number of inputs (xs) matches the number of outputs (outs) - infer_context->AddEqualCstr( - symbol::DimExpr(static_cast(xs_shapes.size())), - symbol::DimExpr(static_cast(op->num_results() - 1))); + // Retrieve the shape information of the input list + pir::Value operand_source = op->operand_source(0); + const symbol::TensorListShapeOrDataDimExprs &xs_shapes = + infer_context->GetShapeOrDataForValue(operand_source) + .dyn_cast(); + + PADDLE_ENFORCE_EQ( + xs_shapes.size(), + op->num_results() - 1, + phi::errors::InvalidArgument("The number of inputs (xs) should match the " + "number of outputs (outs), " + "but got %d inputs and %d outputs.", + xs_shapes.size(), + op->num_results() - 1)); + + // Set the shape for each output for (size_t i = 0; i < xs_shapes.size(); ++i) { symbol::TensorShapeOrDataDimExprs output_shape(xs_shapes[i].shape()); infer_context->SetShapeOrDataForValue( op->result(i), symbol::ShapeOrDataDimExprs{output_shape}); } + // Set the shape for the found_infinite output symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)}); infer_context->SetShapeOrDataForValue( op->result(op->num_results() - 1), @@ -1134,37 +1142,20 @@ bool MeshgridOpInferSymbolicShape( bool MovingAverageAbsMaxScaleOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - // Get the symbolic shape of the input tensor x - const symbol::ShapeOrDataDimExprs &x_shape = + // Get shapes of input tensors + const auto &x_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)); - // Get the symbolic shape of the input tensor in_state - const symbol::ShapeOrDataDimExprs &in_state_shape = + const auto &in_state_shape = infer_context->GetShapeOrDataForValue(op->operand_source(2)); - // Get the symbolic shape of the input tensor in_accum - const symbol::ShapeOrDataDimExprs &in_accum_shape = + const auto &in_accum_shape = infer_context->GetShapeOrDataForValue(op->operand_source(1)); - // Set the shape for the output tensor out if it exists - if (op->num_results() > 0 && op->result(0) != nullptr) { - infer_context->SetShapeOrDataForValue(op->result(0), x_shape); - } - - // Set the shape for the output tensor out_scale as a scalar if it exists - if (op->num_results() > 1 && op->result(1) != nullptr) { - symbol::TensorShapeOrDataDimExprs scalar_shape( - std::vector{symbol::DimExpr(1)}); - infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape); - } - - // Set the shape for the output tensor out_state if it exists - if (op->num_results() > 2 && op->result(2) != nullptr) { - infer_context->SetShapeOrDataForValue(op->result(2), in_state_shape); - } - - // Set the shape for the output tensor out_accum if it exists - if (op->num_results() > 3 && op->result(3) != nullptr) { - infer_context->SetShapeOrDataForValue(op->result(3), in_accum_shape); - } + // Set shapes for output tensors + infer_context->SetShapeOrDataForValue(op->result(0), x_shape); + infer_context->SetShapeOrDataForValue( + op->result(1), symbol::TensorShapeOrDataDimExprs({symbol::DimExpr(1)})); + infer_context->SetShapeOrDataForValue(op->result(2), in_state_shape); + infer_context->SetShapeOrDataForValue(op->result(3), in_accum_shape); return true; } From 7b83438486fba7aee355b058ceff79854f27a041 Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:08:56 +0800 Subject: [PATCH 19/20] Update check_finite_and_unscale using tensorlist --- .../infer_symbolic_shape/multiary_infer_sym.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 6d02ecb3c10c1..48932a4cad784 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -651,18 +651,19 @@ bool CheckFiniteAndUnscaleOpInferSymbolicShape( xs_shapes.size(), op->num_results() - 1)); - // Set the shape for each output - for (size_t i = 0; i < xs_shapes.size(); ++i) { - symbol::TensorShapeOrDataDimExprs output_shape(xs_shapes[i].shape()); - infer_context->SetShapeOrDataForValue( - op->result(i), symbol::ShapeOrDataDimExprs{output_shape}); + // Set the shapes for the output tensor list + symbol::TensorListShapeOrDataDimExprs outs_shapes; + outs_shapes.reserve(xs_shapes.size()); + for (const auto &input_shape : xs_shapes) { + outs_shapes.emplace_back(input_shape.shape()); } + infer_context->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs{outs_shapes}); // Set the shape for the found_infinite output symbol::TensorShapeOrDataDimExprs found_infinite_shape({symbol::DimExpr(1)}); infer_context->SetShapeOrDataForValue( - op->result(op->num_results() - 1), - symbol::ShapeOrDataDimExprs{found_infinite_shape}); + op->result(1), symbol::ShapeOrDataDimExprs{found_infinite_shape}); return true; } @@ -1334,7 +1335,6 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape( // return true; // } - bool StackOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); From 817594cd4c8420a96c588e49dea8041acd04d9fa Mon Sep 17 00:00:00 2001 From: Guanhuachen2003 <166631022+Guanhuachen2003@users.noreply.github.com> Date: Wed, 7 Aug 2024 11:15:34 +0800 Subject: [PATCH 20/20] delete unnecessary paddle_enforce_eq --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 9 --------- 1 file changed, 9 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 48932a4cad784..e74f4931d4a3b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -642,15 +642,6 @@ bool CheckFiniteAndUnscaleOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(operand_source) .dyn_cast(); - PADDLE_ENFORCE_EQ( - xs_shapes.size(), - op->num_results() - 1, - phi::errors::InvalidArgument("The number of inputs (xs) should match the " - "number of outputs (outs), " - "but got %d inputs and %d outputs.", - xs_shapes.size(), - op->num_results() - 1)); - // Set the shapes for the output tensor list symbol::TensorListShapeOrDataDimExprs outs_shapes; outs_shapes.reserve(xs_shapes.size());