【PIR OpTest Fix No.8】 fix test_shuffle_batch_op #59631

Merged
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -123,6 +123,7 @@
     'seed',
     'send_v2',
     'shadow_feed',
+    'shuffle_batch',
     'sparse_momentum',
     'uniform_random_batch_size_like',
 ]
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -1123,6 +1123,16 @@
     func: share_data
     param: [x]

+- op : shuffle_batch
+  args : (Tensor x, Tensor seed, int startup_seed=0)
+  output : Tensor(out), Tensor(shuffle_idx), Tensor(seed_out)
+  infer_meta:
+    func: ShuffleBatchInferMeta
+  kernel:
+    func: shuffle_batch
+    data_type: x
+  backward : shuffle_batch_grad

- op : slice
args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
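For readers following the series: the ops.yaml entry above is what lets the PIR op builder construct shuffle_batch when the legacy test program is translated to the new IR. A rough sketch of driving the op from a static-graph program, assuming paddle.incubate.layers.shuffle_batch is still the user-facing wrapper for this operator (not something this PR adds), would be:

import numpy as np
import paddle
from paddle import static

paddle.enable_static()

main, startup = static.Program(), static.Program()
with static.program_guard(main, startup):
    # rows of x are permuted along the batch dimension; shuffle_idx records the permutation
    x = static.data(name="x", shape=[None, 4], dtype="float32")
    out = paddle.incubate.layers.shuffle_batch(x)  # assumed wrapper around the shuffle_batch op

exe = static.Executor(paddle.CPUPlace())
exe.run(startup)
(res,) = exe.run(main,
                 feed={"x": np.arange(12, dtype="float32").reshape(3, 4)},
                 fetch_list=[out])
print(res)  # the same three rows as the input, in shuffled order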
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -895,6 +895,16 @@
     func: fused_elemwise_add_activation_grad
   optional : x, intermediate_out

+- backward_op: shuffle_batch_grad
+  forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out)
+  args: (Tensor shuffle_idx, Tensor out_grad, int startup_seed=0)
+  output : Tensor(x_grad)
+  infer_meta:
+    func: ShuffleBatchGradInferMeta
+  kernel:
+    func: shuffle_batch_grad
+    data_type : out_grad

- backward_op: unpool_grad
forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
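As a reading aid for the grad signature above: the backward op needs only shuffle_idx and out_grad because un-shuffling is pure index bookkeeping. A NumPy sketch of the intended relationship, assuming the kernel's convention is out[i] = x[shuffle_idx[i]]:

import numpy as np

def shuffle_batch_ref(x, rng):
    # forward: permute the rows and remember the permutation
    shuffle_idx = rng.permutation(x.shape[0])
    return x[shuffle_idx], shuffle_idx

def shuffle_batch_grad_ref(shuffle_idx, out_grad):
    # backward: route each out_grad row back to the input row it came from
    x_grad = np.empty_like(out_grad)
    x_grad[shuffle_idx] = out_grad
    return x_grad

rng = np.random.default_rng(2023)
x = np.arange(12, dtype=np.float32).reshape(4, 3)
out, idx = shuffle_batch_ref(x, rng)
# feeding out back in as out_grad must reproduce x, row for row
assert np.array_equal(shuffle_batch_grad_ref(idx, out), x)

This also matches ShuffleBatchGradInferMeta below, which simply gives x_grad the dims, lod, and dtype of out_grad.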
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2747,6 +2747,13 @@
     out : Out
     xout : XOut

+- op : shuffle_batch
+  backward: shuffle_batch_grad
+  inputs:
+    {x : X, seed : Seed}
+  outputs:
+    {out : Out, shuffle_idx : ShuffleIdx, seed_out : SeedOut}

- op : shuffle_channel
backward : shuffle_channel_grad
extra :
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1014,6 +1014,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
   }
 }

+void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx,
+                               const MetaTensor& out_grad,
+                               int startup_seed,
+                               MetaTensor* x_grad) {
+  x_grad->share_dims(out_grad);
+  x_grad->share_lod(out_grad);
+  x_grad->set_dtype(out_grad.dtype());
+}

void SpectralNormGradInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -413,6 +413,11 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index,
                                MetaTensor* x_grad,
                                MetaTensor* updates_grad);

+void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx,
+                               const MetaTensor& out_grad,
+                               int startup_seed,
+                               MetaTensor* x_grad);

void SpectralNormGradInferMeta(const MetaTensor& weight,
const MetaTensor& u,
const MetaTensor& v,
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2735,6 +2735,21 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
   }
 }

+void ShuffleBatchInferMeta(const MetaTensor& x,
+                           const MetaTensor& seed,
+                           int startup_seed,
+                           MetaTensor* out,
+                           MetaTensor* shuffle_idx,
+                           MetaTensor* seed_out) {
+  out->share_dims(x);
+  out->share_lod(x);
+  seed_out->share_dims(seed);
+  seed_out->share_lod(seed);
+  shuffle_idx->set_dims(phi::make_ddim({-1}));
+}

void SequenceMaskInferMeta(const MetaTensor& x,
const MetaTensor& max_len_tensor,
int maxlen,
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -430,6 +430,15 @@ void SequenceMaskInferMeta(const MetaTensor& x,
                           int out_dtype,
                           MetaTensor* y);

+void ShuffleBatchInferMeta(const MetaTensor& x,
+                           const MetaTensor& seed,
+                           int startup_seed,
+                           MetaTensor* out,
+                           MetaTensor* shuffle_idx,
+                           MetaTensor* seed_out);

void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
const MetaTensor& mask,
MetaTensor* out);
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_no_check_list
@@ -9,3 +9,4 @@ test_randperm_op
 test_seed_op
 test_uniform_random_bf16_op
 test_uniform_random_inplace_op
+test_shuffle_batch_op
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
@@ -269,6 +269,7 @@ test_sgd_op
 test_shape_mkldnn_op
 test_shape_op
 test_shard_index_op
+test_shuffle_batch_op
 test_sigmoid_cross_entropy_with_logits_op
 test_sign_op
 test_size_op
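A note on the two white-list entries: pir_op_test_white_list opts test_shuffle_batch_op into the PIR execution path without changing the test itself, while pir_op_test_no_check_list skips the element-wise output comparison, presumably because a shuffled output is only defined up to a permutation of rows (the same reason test_seed_op and the uniform_random tests sit in that list). A small NumPy illustration of why an exact comparison is the wrong check here:

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(4, 3)
run_a = x[np.array([2, 0, 3, 1])]   # one valid shuffle of the batch
run_b = x[np.array([1, 3, 0, 2])]   # another, equally valid shuffle

def as_row_set(a):
    return sorted(map(tuple, a.tolist()))

assert not np.array_equal(run_a, run_b)                          # element-wise comparison fails
assert as_row_set(run_a) == as_row_set(run_b) == as_row_set(x)   # yet both are row permutations of x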