From ae6a9c563951baabe26341ab7c699bd3763b450f Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Thu, 1 Aug 2024 11:27:18 +0800 Subject: [PATCH 01/21] add cinn --- .../infer_symbolic_shape/binary_infer_sym.cc | 25 +++++++++++++++++++ .../infer_symbolic_shape/binary_infer_sym.h | 1 + .../infer_symbolic_shape/nullary_infer_sym.cc | 12 +++++++++ .../infer_symbolic_shape/nullary_infer_sym.h | 1 + .../infer_symbolic_shape/unary_infer_sym.cc | 15 +++++++++++ .../infer_symbolic_shape/unary_infer_sym.h | 1 + .../phi/ops/yaml/inconsistent/static_ops.yaml | 1 + paddle/phi/ops/yaml/ops.yaml | 2 ++ 8 files changed, 58 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index ab60df5859e9f..20733d01458fe 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -599,6 +599,31 @@ bool SearchsortedOpInferSymbolicShape( return true; } +bool SegmentPoolOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + std::vector out_shape; + out_shape.push_back(symbol::DimExpr{-1}); + auto axis = input_shape.size(); + for (int i = 1; i < axis; ++i) { + out_shape.push_back(input_shape[i]); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_shape)}; + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); + + const auto pooltype = op->attribute("pooltype").AsString(); + if (pooltype == "MEAN") { + std::vector out_shape = {-1, 1}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + } + return true; +} + bool IscloseOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // The shape of output is the same as input `values` (op->operand_source(1)) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index f3ac5ba69c5d9..80eaa3fe90a50 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -39,6 +39,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SegmentPool) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index b9198dbb8a899..8bb96c5c2e572 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -322,6 +322,18 @@ bool GaussianOpInferSymbolicShape( } } +bool RandpermOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + int n = op->attribute("n").data(); + std::vector out_shape = {n}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + bool RandintOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &shape_gen_op = op->operand_source(0).defining_op(); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h index a221eec936528..0a9325644546a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Full) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullIntArray) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randint) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Uniform) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 7e23e0ca730d7..8fdf964975b4c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -950,6 +950,21 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, return true; } +bool SetValueOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + PADDLE_ENFORCE_LT( + input_shape.shape().size(), + 7, + phi::errors::InvalidArgument( + "Input(x) of SetValueOp must have rank less than 7, but received %d.", + input_shape.shape().size())); + infer_context->SetShapeOrDataForValue(op->result(0), input_shape); + + return true; +} + bool ShapeSrOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return ShapeOpInferSymbolicShape(op, infer_context); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 56866817477dc..cfef18f3cc184 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -59,6 +59,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 148f572379367..e812cb037a31b 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -809,6 +809,7 @@ kernel : func : set_value backward: set_value_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : shadow_feed args : (Tensor x) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 1126966ac60a1..68806130316a9 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3647,6 +3647,7 @@ data_type : dtype backend : place traits : pir::SideEffectTrait + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rank_attention args : (Tensor x, Tensor rank_offset, Tensor rank_param, int max_rank = 3, int max_size = 0) @@ -3961,6 +3962,7 @@ data_type : x intermediate : summed_ids backward : segment_pool_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : selu args : (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) From 73966fd8e1a98fb272fe3c53b18d079a7e05cc6c Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Thu, 1 Aug 2024 12:10:48 +0800 Subject: [PATCH 02/21] add cinn --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 20733d01458fe..54987cbbc69ae 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -605,7 +605,7 @@ bool SegmentPoolOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); std::vector out_shape; out_shape.push_back(symbol::DimExpr{-1}); - auto axis = input_shape.size(); + int axis = input_shape.size(); for (int i = 1; i < axis; ++i) { out_shape.push_back(input_shape[i]); } From a23074d502b4f799801499de0977b81222f36473 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Thu, 1 Aug 2024 13:18:39 +0800 Subject: [PATCH 03/21] fix setvalue --- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 5 +++++ .../interface/infer_symbolic_shape/unary_infer_sym.h | 1 + 2 files changed, 6 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 8fdf964975b4c..b82570143ddb3 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -965,6 +965,11 @@ bool SetValueOpInferSymbolicShape( return true; } +bool SetValue_OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return SetValueOpInferSymbolicShape(op, infer_context); +} + bool ShapeSrOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return ShapeOpInferSymbolicShape(op, infer_context); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index cfef18f3cc184..2490d5da4671d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -60,6 +60,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) From c78837e3778f4445f5173dd315e9a463bc34bd59 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Fri, 2 Aug 2024 15:39:05 +0800 Subject: [PATCH 04/21] fix segmentpoll --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 54987cbbc69ae..ef90dbd054b0c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -613,13 +613,13 @@ bool SegmentPoolOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_shape)}; infer_context->SetShapeOrDataForValue(op->result(0), shape_data); - const auto pooltype = op->attribute("pooltype").AsString(); + string pooltype = op->attribute("pooltype").AsString(); if (pooltype == "MEAN") { - std::vector out_shape = {-1, 1}; + std::vector summed_shape = {-1, 1}; infer_context->SetShapeOrDataForValue( - op->result(0), + op->result(1), symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_shape)}); + symbol::TensorShapeOrDataDimExprs(summed_shape)}); } return true; } From 43067b48cc2b44ea59cfa2b860fbd84442b651ab Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Fri, 2 Aug 2024 16:21:20 +0800 Subject: [PATCH 05/21] fix --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index ef90dbd054b0c..247965faaefb1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -613,8 +613,8 @@ bool SegmentPoolOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_shape)}; infer_context->SetShapeOrDataForValue(op->result(0), shape_data); - string pooltype = op->attribute("pooltype").AsString(); - if (pooltype == "MEAN") { + auto pool_type = op->attribute("pooltype").AsString(); + if (pool_type == "MEAN") { std::vector summed_shape = {-1, 1}; infer_context->SetShapeOrDataForValue( op->result(1), From 039d4d25a8a81829870bae0ac6d9b01a6b4813f2 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Mon, 5 Aug 2024 13:19:28 +0800 Subject: [PATCH 06/21] fix --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 247965faaefb1..592ca3f897097 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -613,7 +613,8 @@ bool SegmentPoolOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_shape)}; infer_context->SetShapeOrDataForValue(op->result(0), shape_data); - auto pool_type = op->attribute("pooltype").AsString(); + const std::string pool_type = + op->attribute("pooltype").AsString(); if (pool_type == "MEAN") { std::vector summed_shape = {-1, 1}; infer_context->SetShapeOrDataForValue( From 8bc41d628f48c8aa9707404baf917793fbdcbb14 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Mon, 5 Aug 2024 22:38:25 +0800 Subject: [PATCH 07/21] fix segment --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 592ca3f897097..2b63978d4c16b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -603,9 +603,14 @@ bool SegmentPoolOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &input_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &segment_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector out_shape; - out_shape.push_back(symbol::DimExpr{-1}); + int axis = input_shape.size(); + int row_shape = static_cast( + segment_shape_or_data.data().value()[input_shape[0] - 1].Get()); + out_shape.push_back(symbol::DimExpr{row_shape}); for (int i = 1; i < axis; ++i) { out_shape.push_back(input_shape[i]); } @@ -616,7 +621,7 @@ bool SegmentPoolOpInferSymbolicShape( const std::string pool_type = op->attribute("pooltype").AsString(); if (pool_type == "MEAN") { - std::vector summed_shape = {-1, 1}; + std::vector summed_shape = {row_shape, 1}; infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ From a6cb5f04b42123adfebcf3e1c23d7515151937fc Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Mon, 5 Aug 2024 22:44:32 +0800 Subject: [PATCH 08/21] fix segment --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 2b63978d4c16b..7b2c266cd6c1d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -610,6 +610,7 @@ bool SegmentPoolOpInferSymbolicShape( int axis = input_shape.size(); int row_shape = static_cast( segment_shape_or_data.data().value()[input_shape[0] - 1].Get()); + row_shape += 1; out_shape.push_back(symbol::DimExpr{row_shape}); for (int i = 1; i < axis; ++i) { out_shape.push_back(input_shape[i]); From f05d832549a94c028d4ba3aca87b86175bcac5b2 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Tue, 6 Aug 2024 10:14:31 +0800 Subject: [PATCH 09/21] fix segment --- .../infer_symbolic_shape/binary_infer_sym.cc | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 7b2c266cd6c1d..ec6700dfa3842 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -603,15 +603,11 @@ bool SegmentPoolOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &input_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); - const auto &segment_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(1)); std::vector out_shape; - + symbol::DimExpr out_unknown = + infer_context->GetNextSymName(); // unknown until runtime + out_shape.push_back(out_unknown); int axis = input_shape.size(); - int row_shape = static_cast( - segment_shape_or_data.data().value()[input_shape[0] - 1].Get()); - row_shape += 1; - out_shape.push_back(symbol::DimExpr{row_shape}); for (int i = 1; i < axis; ++i) { out_shape.push_back(input_shape[i]); } @@ -622,7 +618,9 @@ bool SegmentPoolOpInferSymbolicShape( const std::string pool_type = op->attribute("pooltype").AsString(); if (pool_type == "MEAN") { - std::vector summed_shape = {row_shape, 1}; + std::vector summed_shape; + summed_shape.push_back(out_unknown); + summed_shape.push_back(symbol::DimExpr{1}); infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ From b2dfb5846df3532def3a7ed42a1f741917ccedc9 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Tue, 6 Aug 2024 17:37:15 +0800 Subject: [PATCH 10/21] close ctest --- test/legacy_test/test_segment_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 8278cd984d1d6..929a37151e0ab 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -124,7 +124,7 @@ def setUp(self): self.convert_bf16() def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=False) def test_check_grad(self): self.check_grad(["X"], "Out", check_pir=True) From ea28e959a91c24735afb9ee39d3d1102a65d1780 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Wed, 7 Aug 2024 09:55:32 +0800 Subject: [PATCH 11/21] fix type --- .../interface/infer_symbolic_shape/nullary_infer_sym.cc | 2 +- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index b6e9cf3dc4931..e0c3077cd42b9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -333,7 +333,7 @@ bool GaussianOpInferSymbolicShape( bool RandpermOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - int n = op->attribute("n").data(); + int64_t n = op->attribute("n").data(); std::vector out_shape = {n}; infer_context->SetShapeOrDataForValue( op->result(0), diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 09e75c7736252..25988f45a619d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1484,15 +1484,15 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, bool SetValueOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &input_shape = + const auto &input_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); PADDLE_ENFORCE_LT( - input_shape.shape().size(), + input_shape_or_data.shape().size(), 7, phi::errors::InvalidArgument( "Input(x) of SetValueOp must have rank less than 7, but received %d.", - input_shape.shape().size())); - infer_context->SetShapeOrDataForValue(op->result(0), input_shape); + input_shape_or_data.shape().size())); + infer_context->SetShapeOrDataForValue(op->result(0), input_shape_or_data); return true; } From 2f68e94aff6bbe8f58e503320b6147dce9815a34 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Wed, 7 Aug 2024 16:53:49 +0800 Subject: [PATCH 12/21] fix segment --- .../infer_symbolic_shape/binary_infer_sym.cc | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index d519e2a8b06f9..43d67933499a6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -754,12 +754,38 @@ bool SearchsortedOpInferSymbolicShape( bool SegmentPoolOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &input_shape = - infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &input_shape = input_shape_or_data.shape(); + const auto &ids_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &ids_shape = ids_shape_or_data.shape(); + const std::string pool_type = + op->attribute("pooltype").AsString(); + + int ndims_ids = ids_shape.size(); + int last_dim = static_cast(ids_shape[ndims_ids - 1].Get()); + std::vector out_shape; - symbol::DimExpr out_unknown = - infer_context->GetNextSymName(); // unknown until runtime - out_shape.push_back(out_unknown); + if (pool_type == "MEAN") { + std::vector summed_shape; + } + if (ids_shape_or_data.data().has_value()) { + const auto &ids_data = ids_shape_or_data.data(); + int out_known = + static_cast(ids_data.value()[last_dim - 1].Get()); + out_shape.push_back(symbol::DimExpr{out_known + 1}); + if (pool_type == "MEAN") { + summed_shape.push_back(out_shape[0]); + } + } else { + symbol::DimExpr out_unknown = + infer_context->GetNextSymName(); // unknown until runtime + out_shape.push_back(out_unknown); + if (pool_type == "MEAN") { + summed_shape.push_back(out_unknown); // same as before + } + } int axis = input_shape.size(); for (int i = 1; i < axis; ++i) { out_shape.push_back(input_shape[i]); @@ -767,12 +793,7 @@ bool SegmentPoolOpInferSymbolicShape( symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_shape)}; infer_context->SetShapeOrDataForValue(op->result(0), shape_data); - - const std::string pool_type = - op->attribute("pooltype").AsString(); if (pool_type == "MEAN") { - std::vector summed_shape; - summed_shape.push_back(out_unknown); summed_shape.push_back(symbol::DimExpr{1}); infer_context->SetShapeOrDataForValue( op->result(1), From c1dc74a554fd59a94e6c54e3854435e991d88280 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Wed, 7 Aug 2024 16:57:06 +0800 Subject: [PATCH 13/21] fix segment --- test/legacy_test/test_segment_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 929a37151e0ab..60fdecf09f80b 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -124,7 +124,7 @@ def setUp(self): self.convert_bf16() def test_check_output(self): - self.check_output(check_pir=False) + self.check_output(check_pir=True, check_symbol_infer=False) def test_check_grad(self): self.check_grad(["X"], "Out", check_pir=True) From 15707aced0250b813616586230a9b940d13b0447 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Wed, 7 Aug 2024 17:21:14 +0800 Subject: [PATCH 14/21] fix --- .../infer_symbolic_shape/binary_infer_sym.cc | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 43d67933499a6..d487fef1b405f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -767,24 +767,15 @@ bool SegmentPoolOpInferSymbolicShape( int last_dim = static_cast(ids_shape[ndims_ids - 1].Get()); std::vector out_shape; - if (pool_type == "MEAN") { - std::vector summed_shape; - } if (ids_shape_or_data.data().has_value()) { const auto &ids_data = ids_shape_or_data.data(); int out_known = static_cast(ids_data.value()[last_dim - 1].Get()); out_shape.push_back(symbol::DimExpr{out_known + 1}); - if (pool_type == "MEAN") { - summed_shape.push_back(out_shape[0]); - } } else { symbol::DimExpr out_unknown = infer_context->GetNextSymName(); // unknown until runtime out_shape.push_back(out_unknown); - if (pool_type == "MEAN") { - summed_shape.push_back(out_unknown); // same as before - } } int axis = input_shape.size(); for (int i = 1; i < axis; ++i) { @@ -794,6 +785,8 @@ bool SegmentPoolOpInferSymbolicShape( symbol::TensorShapeOrDataDimExprs(out_shape)}; infer_context->SetShapeOrDataForValue(op->result(0), shape_data); if (pool_type == "MEAN") { + std::vector summed_shape; + summed_shape.push_back(out_shape[0]); // same as before summed_shape.push_back(symbol::DimExpr{1}); infer_context->SetShapeOrDataForValue( op->result(1), From c03d5944c6ca60f877893499a978e6fb3917a36c Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Thu, 8 Aug 2024 09:52:23 +0800 Subject: [PATCH 15/21] fix --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index d487fef1b405f..e074d334ea874 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -758,7 +758,7 @@ bool SegmentPoolOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &input_shape = input_shape_or_data.shape(); const auto &ids_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(0)); + infer_context->GetShapeOrDataForValue(op->operand_source(1)); const std::vector &ids_shape = ids_shape_or_data.shape(); const std::string pool_type = op->attribute("pooltype").AsString(); From 674aaa378e589bd9bced98864ff1cf8357897330 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Fri, 9 Aug 2024 11:54:46 +0800 Subject: [PATCH 16/21] fix --- test/legacy_test/test_segment_ops.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 60fdecf09f80b..7d800e3d7e2a3 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -223,11 +223,15 @@ def setUp(self): def test_check_output(self): if core.is_compiled_with_cuda(): - self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + self.check_output_with_place( + core.CUDAPlace(0), check_pir=True, check_symbol_infer=False + ) # due to CPU kernel not implement calculate 'SummedIds' # so cannot check 'SummedIds' del self.outputs['SummedIds'] - self.check_output_with_place(core.CPUPlace(), check_pir=True) + self.check_output_with_place( + core.CPUPlace(), check_pir=True, check_symbol_infer=False + ) class TestSegmentMean2(TestSegmentMean): @@ -274,10 +278,14 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): - self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True) + self.check_grad_with_place( + self.place, ["X"], "Out", check_pir=True, check_symbol_infer=False + ) @unittest.skipIf( @@ -292,7 +300,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place( @@ -316,7 +326,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place( @@ -340,7 +352,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True) From 8c73d31f359ab7a2864f63cd7ed8bb44353ec7a5 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Fri, 9 Aug 2024 14:30:16 +0800 Subject: [PATCH 17/21] fix codestyle --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 25dd46c3bf301..6215717cbaf27 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1561,7 +1561,7 @@ bool SetValue_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { return SetValueOpInferSymbolicShape(op, infer_context); } - + bool RreluOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { float lower = op->attribute("lower").data(); From 736a6f46c908c6f36ede519daf41d865329b31d7 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Sat, 10 Aug 2024 11:21:06 +0800 Subject: [PATCH 18/21] fix segment --- test/legacy_test/test_segment_ops.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 7d800e3d7e2a3..afa3645d22878 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -283,9 +283,7 @@ def test_check_output(self): ) def test_check_grad(self): - self.check_grad_with_place( - self.place, ["X"], "Out", check_pir=True, check_symbol_infer=False - ) + self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True) @unittest.skipIf( From 6f68ce0464a6c87970a5b5fba1d34d943dbe440c Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Mon, 12 Aug 2024 10:33:45 +0800 Subject: [PATCH 19/21] fix --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 37d552dea3030..2158edeba411e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -818,15 +818,10 @@ bool SegmentPoolOpInferSymbolicShape( const std::string pool_type = op->attribute("pooltype").AsString(); - int ndims_ids = ids_shape.size(); - int last_dim = static_cast(ids_shape[ndims_ids - 1].Get()); - std::vector out_shape; if (ids_shape_or_data.data().has_value()) { const auto &ids_data = ids_shape_or_data.data(); - int out_known = - static_cast(ids_data.value()[last_dim - 1].Get()); - out_shape.push_back(symbol::DimExpr{out_known + 1}); + out_shape.push_back(ids_data.value()[ids_shape.size() - 1] + DimExpr{1}); } else { symbol::DimExpr out_unknown = infer_context->GetNextSymName(); // unknown until runtime From f21d0473635969e1e8662e5c456f75b42bfe6276 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Mon, 12 Aug 2024 11:08:20 +0800 Subject: [PATCH 20/21] fix --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 2158edeba411e..e3743e0bf660e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -821,7 +821,8 @@ bool SegmentPoolOpInferSymbolicShape( std::vector out_shape; if (ids_shape_or_data.data().has_value()) { const auto &ids_data = ids_shape_or_data.data(); - out_shape.push_back(ids_data.value()[ids_shape.size() - 1] + DimExpr{1}); + out_shape.push_back(ids_data.value()[ids_shape.size() - 1] + + symbol::DimExpr{1}); } else { symbol::DimExpr out_unknown = infer_context->GetNextSymName(); // unknown until runtime From b00eb5607c0d5b8d4f0b968952c6b2863c959b93 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Tue, 13 Aug 2024 11:22:57 +0800 Subject: [PATCH 21/21] fix --- .../infer_symbolic_shape/unary_infer_sym.cc | 20 ------------------- .../infer_symbolic_shape/unary_infer_sym.h | 2 -- 2 files changed, 22 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 6215717cbaf27..edf30e36ae5f9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1542,26 +1542,6 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, return true; } -bool SetValueOpInferSymbolicShape( - pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &input_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(0)); - PADDLE_ENFORCE_LT( - input_shape_or_data.shape().size(), - 7, - phi::errors::InvalidArgument( - "Input(x) of SetValueOp must have rank less than 7, but received %d.", - input_shape_or_data.shape().size())); - infer_context->SetShapeOrDataForValue(op->result(0), input_shape_or_data); - - return true; -} - -bool SetValue_OpInferSymbolicShape( - pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return SetValueOpInferSymbolicShape(op, infer_context); -} - bool RreluOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { float lower = op->attribute("lower").data(); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 6b14cb331040e..3e3a32e490e7e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -95,8 +95,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rrelu) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr)