diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index ea942648685eda..8f6310eab537ec 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -125,6 +125,7 @@ 'add_n_', 'all_reduce', 'all_reduce_', + 'assign_pos', 'batch_fc', 'barrier', 'c_allgather', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index de64ca2f98a95a..3c245d1678a78b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -107,6 +107,14 @@ inplace : (output -> out) backward : assign_out__grad +- op : assign_pos + args : (Tensor x, Tensor cum_count, Tensor eff_num_len) + output : Tensor(out) + infer_meta : + func : AssignPosInferMeta + kernel : + func : assign_pos + - op : assign_value args : (int[] shape, DataType dtype, Scalar[] values, Place place = {}) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 0c3f7488362ebe..d8c19cb77447d1 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -296,6 +296,12 @@ get_expected_kernel_type : assign : GetAssignExpectedKernelType +- op : assign_pos + inputs : + {x : X} + outputs : + out : Out + - op : assign_value outputs : out : Out diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index c5e5cb61a4a40a..fda7420e8ea66a 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -146,6 +146,25 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void AssignPosInferMeta(const MetaTensor& x, + const MetaTensor& cum_count, + const MetaTensor& eff_num_len, + MetaTensor* out) { + phi::DataType X_dtype = x.dtype(); + phi::DataType cum_count_dtype = cum_count.dtype(); + + PADDLE_ENFORCE_EQ(cum_count_dtype, + X_dtype, + phi::errors::InvalidArgument( + "The dtype of the cum_count and X should be same")); + PADDLE_ENFORCE_EQ(cum_count_dtype, + phi::DataType::INT64, + phi::errors::InvalidArgument( + "The dtype of the cum_count_dtype, eff_num_len and " + "X should be same as int64")); + out->set_dtype(X_dtype); +} + void BatchFCInferMeta(const MetaTensor& input, const MetaTensor& w, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 7a8fa648d434ed..ca8e19e3b5dee2 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& step, MetaTensor* out); +void AssignPosInferMeta(const MetaTensor& x, + const MetaTensor& cum_count, + const MetaTensor& eff_num_len, + MetaTensor* out); + void BatchFCInferMeta(const MetaTensor& input, const MetaTensor& w, const MetaTensor& bias, diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list index 7bed213c8278f3..52b03b7aaf6852 100644 --- a/test/white_list/pir_op_test_no_check_list +++ b/test/white_list/pir_op_test_no_check_list @@ -1,3 +1,4 @@ +test_assign_pos_op test_bernoulli_op test_dirichlet_op test_empty_op diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 6df2ded8bc02fc..e7b20180b51a69 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -16,6 +16,7 @@ test_arg_min_max_op_static_build test_arg_min_max_v2_op test_argsort_op test_assign_op +test_assign_pos_op test_assign_value_op test_atan2_op test_auc_op