From d5503697ff14c24ca08567dfb0709516972f79d4 Mon Sep 17 00:00:00 2001 From: xingmingyyj Date: Wed, 3 Jan 2024 02:10:24 +0000 Subject: [PATCH] fix --- paddle/phi/infermeta/backward.cc | 9 +++++++++ paddle/phi/infermeta/backward.h | 5 +++++ paddle/phi/infermeta/binary.cc | 8 -------- paddle/phi/infermeta/binary.h | 5 ----- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a3eb7ce8c906b..5485517ab339e 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1014,6 +1014,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } } +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad) { + x_grad->share_dims(out_grad); + x_grad->share_lod(out_grad); + x_grad->set_dtype(out_grad.dtype()); +} + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c1d79f2378926..756fd93086396 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -413,6 +413,11 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* updates_grad); +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad); + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index c411e955a1c7f..4a7bf8d35b44b 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2750,14 +2750,6 @@ void ShuffleBatchInferMeta(const MetaTensor& x, shuffle_idx->set_dims(phi::make_ddim({-1})); } -void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, - const MetaTensor& out_grad, - int startup_seed, - MetaTensor* x_grad) { - x_grad->share_dims(out_grad); - x_grad->share_lod(out_grad); -} - void SequenceMaskInferMeta(const MetaTensor& x, const MetaTensor& max_len_tensor, int maxlen, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 5421c0ac0a333..fcc407a7c93f9 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -439,11 +439,6 @@ void ShuffleBatchInferMeta(const MetaTensor& x, ); -void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, - const MetaTensor& out_grad, - int startup_seed, - MetaTensor* x_grad); - void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out);