From 02234fea41eb848fd8a93d646c6dc166b85072d4 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Tue, 19 Mar 2024 13:02:16 +0000 Subject: [PATCH 1/5] fix assign_pos --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 8 ++++++++ paddle/phi/api/yaml/op_compat.yaml | 6 ++++++ paddle/phi/infermeta/ternary.cc | 19 +++++++++++++++++++ paddle/phi/infermeta/ternary.h | 5 +++++ 5 files changed, 39 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 6a86600c1f5f2..813c725ee97bb 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -122,6 +122,7 @@ 'add_n_', 'all_reduce', 'all_reduce_', + 'assign_pos', 'c_allgather', 'c_allreduce_avg', 'c_allreduce_avg_', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 60f1bf2f03941..56fbc6526798f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -107,6 +107,14 @@ inplace : (output -> out) backward : assign_out__grad +- op : assign_pos + args : (Tensor x, Tensor cum_count, Tensor eff_num_len) + output : Tensor(out) + infer_meta : + func : AssignPosInferMeta + kernel : + func : assign_pos + - op : assign_value args : (int[] shape, DataType dtype, Scalar[] values, Place place = {}) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 2fbd73623f0a0..c92058c71702f 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -296,6 +296,12 @@ get_expected_kernel_type : assign : GetAssignExpectedKernelType +- op : assign_pos + inputs : + {x : X, cum_count : cum_count, eff_num_len : eff_num_len} + outputs : + out : Out + - op : assign_value outputs : out : Out diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 0551859ed3789..00d89a4fab422 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -146,6 +146,25 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void AssignPosInferMeta(const MetaTensor& x, + const MetaTensor& cum_count, + const MetaTensor& eff_num_len, + MetaTensor* out) { + phi::DataType X_dtype = x.dtype(); + phi::DataType cum_count_dtype = cum_count.dtype(); + + PADDLE_ENFORCE_EQ(cum_count_dtype, + X_dtype, + phi::errors::InvalidArgument( + "The dtype of the cum_count and X should be same")); + PADDLE_ENFORCE_EQ(cum_count_dtype, + phi::DataType::INT64, + phi::errors::InvalidArgument( + "The dtype of the cum_count_dtype, eff_num_len and " + "X should be same as int64")); + out->set_dtype(X_dtype); +} + void BoxCoderInferMeta(const MetaTensor& prior_box, const MetaTensor& prior_box_var, const MetaTensor& target_box, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index c331f7198de7a..c4eca0f19a47a 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& step, MetaTensor* out); +void AssignPosInferMeta(const MetaTensor& x, + const MetaTensor& cum_count, + const MetaTensor& eff_num_len, + MetaTensor* out); + void BoxCoderInferMeta(const MetaTensor& prior_box, const MetaTensor& prior_box_var, const MetaTensor& target_box, From f44a0095340cb2786e5de849b26ce31d8ba5c13a Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 21 Mar 2024 13:46:26 +0000 Subject: [PATCH 2/5] update fix --- test/white_list/pir_op_test_no_check_list | 1 + test/white_list/pir_op_test_white_list | 1 + 2 files changed, 2 insertions(+) diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list index 7bed213c8278f..52b03b7aaf685 100644 --- a/test/white_list/pir_op_test_no_check_list +++ b/test/white_list/pir_op_test_no_check_list @@ -1,3 +1,4 @@ +test_assign_pos_op test_bernoulli_op test_dirichlet_op test_empty_op diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index dfa901c0ca126..816b602eb2142 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -16,6 +16,7 @@ test_arg_min_max_op_static_build test_arg_min_max_v2_op test_argsort_op test_assign_op +test_assign_pos_op test_assign_value_op test_atan2_op test_auc_op From 8b0c41691a0bc35d59a929d3b5ee3371769d82b9 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 21 Mar 2024 14:10:48 +0000 Subject: [PATCH 3/5] fix bug --- paddle/phi/infermeta/ternary.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 5f8cf98f739a3..fda7420e8ea66 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -163,7 +163,8 @@ void AssignPosInferMeta(const MetaTensor& x, "The dtype of the cum_count_dtype, eff_num_len and " "X should be same as int64")); out->set_dtype(X_dtype); - +} + void BatchFCInferMeta(const MetaTensor& input, const MetaTensor& w, const MetaTensor& bias, From 2ace60537b7261a975f91dd1e2acd394a8e6a762 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 21 Mar 2024 15:18:19 +0000 Subject: [PATCH 4/5] fix codestyle --- paddle/phi/infermeta/ternary.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index f0fd09216112d..ca8e19e3b5dee 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -57,7 +57,7 @@ void AssignPosInferMeta(const MetaTensor& x, const MetaTensor& cum_count, const MetaTensor& eff_num_len, MetaTensor* out); - + void BatchFCInferMeta(const MetaTensor& input, const MetaTensor& w, const MetaTensor& bias, From edc986cdca7fe792b31f22aca7e5078d9425c2f6 Mon Sep 17 00:00:00 2001 From: Eddie-Wang Date: Tue, 26 Mar 2024 22:00:16 +0800 Subject: [PATCH 5/5] Update paddle/phi/api/yaml/op_compat.yaml Co-authored-by: kangguangli --- paddle/phi/api/yaml/op_compat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index f72a2e1b9f0a5..d8c19cb77447d 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -298,7 +298,7 @@ - op : assign_pos inputs : - {x : X, cum_count : cum_count, eff_num_len : eff_num_len} + {x : X} outputs : out : Out