diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 6f0a552b529d9..9f0cef3241b87 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -123,6 +123,7 @@ 'seed', 'send_v2', 'shadow_feed', + 'shuffle_batch', 'sparse_momentum', 'uniform_random_batch_size_like', ] diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 460ca5ad373ce..2d6fe9bf03eb0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1123,6 +1123,16 @@ func: share_data param: [x] +- op : shuffle_batch + args : (Tensor x, Tensor seed, int startup_seed=0) + output : Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) + infer_meta: + func: ShuffleBatchInferMeta + kernel: + func: shuffle_batch + data_type: x + backward : shuffle_batch_grad + - op : slice args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) output : Tensor diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index beba440b5b6de..e011efe752a6b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -895,6 +895,16 @@ func: fused_elemwise_add_activation_grad optional : x, intermediate_out +- backward_op: shuffle_batch_grad + forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) + args: (Tensor shuffle_idx, Tensor out_grad,int startup_seed=0) + output : Tensor(x_grad) + infer_meta: + func: ShuffleBatchGradInferMeta + kernel: + func: shuffle_batch_grad + data_type : out_grad + - backward_op: unpool_grad forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out) args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 553df312fdec7..b1ac12a417feb 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2747,6 +2747,13 @@ out : Out xout : XOut +- op : shuffle_batch + backward: shuffle_batch_grad + inputs: + {x : X, seed : Seed} + outputs: + {out : Out, shuffle_idx : ShuffleIdx, seed_out : SeedOut} + - op : shuffle_channel backward : shuffle_channel_grad extra : diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a3eb7ce8c906b..5485517ab339e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1014,6 +1014,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } } +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad) { + x_grad->share_dims(out_grad); + x_grad->share_lod(out_grad); + x_grad->set_dtype(out_grad.dtype()); +} + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c1d79f2378926..756fd93086396 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -413,6 +413,11 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* updates_grad); +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad); + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 53b3b00286a58..4a7bf8d35b44b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2735,6 +2735,21 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, } } +void ShuffleBatchInferMeta(const MetaTensor& x, + const MetaTensor& seed, + int startup_seed, + MetaTensor* out, + MetaTensor* shuffle_idx, + MetaTensor* seed_out + +) { + out->share_dims(x); + out->share_lod(x); + seed_out->share_dims(seed); + seed_out->share_lod(seed); + shuffle_idx->set_dims(phi::make_ddim({-1})); +} + void SequenceMaskInferMeta(const MetaTensor& x, const MetaTensor& max_len_tensor, int maxlen, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index e7b646d664530..fcc407a7c93f9 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -430,6 +430,15 @@ void SequenceMaskInferMeta(const MetaTensor& x, int out_dtype, MetaTensor* y); +void ShuffleBatchInferMeta(const MetaTensor& x, + const MetaTensor& seed, + int startup_seed, + MetaTensor* out, + MetaTensor* shuffle_idx, + MetaTensor* seed_out + +); + void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out); diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list index 99e67a1dc2d02..00d8f054df131 100644 --- a/test/white_list/pir_op_test_no_check_list +++ b/test/white_list/pir_op_test_no_check_list @@ -9,3 +9,4 @@ test_randperm_op test_seed_op test_uniform_random_bf16_op test_uniform_random_inplace_op +test_shuffle_batch_op diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index d1c3605ca618f..14e30381cd28c 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -269,6 +269,7 @@ test_sgd_op test_shape_mkldnn_op test_shape_op test_shard_index_op +test_shuffle_batch_op test_sigmoid_cross_entropy_with_logits_op test_sign_op test_size_op