diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 38e43450301ef..e3743e0bf660e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -807,6 +807,46 @@ bool SearchsortedOpInferSymbolicShape( return true; } +bool SegmentPoolOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &input_shape = input_shape_or_data.shape(); + const auto &ids_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const std::vector &ids_shape = ids_shape_or_data.shape(); + const std::string pool_type = + op->attribute("pooltype").AsString(); + + std::vector out_shape; + if (ids_shape_or_data.data().has_value()) { + const auto &ids_data = ids_shape_or_data.data(); + out_shape.push_back(ids_data.value()[ids_shape.size() - 1] + + symbol::DimExpr{1}); + } else { + symbol::DimExpr out_unknown = + infer_context->GetNextSymName(); // unknown until runtime + out_shape.push_back(out_unknown); + } + int axis = input_shape.size(); + for (int i = 1; i < axis; ++i) { + out_shape.push_back(input_shape[i]); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_shape)}; + infer_context->SetShapeOrDataForValue(op->result(0), shape_data); + if (pool_type == "MEAN") { + std::vector summed_shape; + summed_shape.push_back(out_shape[0]); // same as before + summed_shape.push_back(symbol::DimExpr{1}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(summed_shape)}); + } + return true; +} + // bool SequenceMaskOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext // *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index d7d2f032fbded..08326589fd966 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -61,6 +61,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullSparseV2) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SegmentPool) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(SequenceMask) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Swiglu) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc index 5350ef4b9672a..e0c3077cd42b9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -331,6 +331,18 @@ bool GaussianOpInferSymbolicShape( } } +bool RandpermOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + int64_t n = op->attribute("n").data(); + std::vector out_shape = {n}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + bool RandintOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &shape_gen_op = op->operand_source(0).defining_op(); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h index d618508e3f171..3f1de5242a492 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Full) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullIntArray) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randint) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices) diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 4a8e4cd429287..dd0627edd26df 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -788,6 +788,7 @@ kernel : func : set_value backward: set_value_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : shadow_feed args : (Tensor x) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 4fe05dab20954..b398cc895de22 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3697,6 +3697,7 @@ data_type : dtype backend : place traits : pir::SideEffectTrait + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rank_attention args : (Tensor x, Tensor rank_offset, Tensor rank_param, int max_rank = 3, int max_size = 0) @@ -4038,6 +4039,7 @@ data_type : x intermediate : summed_ids backward : segment_pool_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : selu args : (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) diff --git a/test/legacy_test/test_segment_ops.py b/test/legacy_test/test_segment_ops.py index 8278cd984d1d6..afa3645d22878 100644 --- a/test/legacy_test/test_segment_ops.py +++ b/test/legacy_test/test_segment_ops.py @@ -124,7 +124,7 @@ def setUp(self): self.convert_bf16() def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_symbol_infer=False) def test_check_grad(self): self.check_grad(["X"], "Out", check_pir=True) @@ -223,11 +223,15 @@ def setUp(self): def test_check_output(self): if core.is_compiled_with_cuda(): - self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + self.check_output_with_place( + core.CUDAPlace(0), check_pir=True, check_symbol_infer=False + ) # due to CPU kernel not implement calculate 'SummedIds' # so cannot check 'SummedIds' del self.outputs['SummedIds'] - self.check_output_with_place(core.CPUPlace(), check_pir=True) + self.check_output_with_place( + core.CPUPlace(), check_pir=True, check_symbol_infer=False + ) class TestSegmentMean2(TestSegmentMean): @@ -274,7 +278,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True) @@ -292,7 +298,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place( @@ -316,7 +324,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place( @@ -340,7 +350,9 @@ def prepare(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place(self.place, check_pir=True) + self.check_output_with_place( + self.place, check_pir=True, check_symbol_infer=False + ) def test_check_grad(self): self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True)