Skip to content

Commit

Permalink
fix backward
Browse files Browse the repository at this point in the history
  • Loading branch information
EsdeathYZH committed Jun 9, 2022
1 parent 55b1f2a commit 26a43fc
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions paddle/phi/kernels/gpu/batch_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,22 @@ void BatchNormKernel(const Context &ctx,
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
#if CUDNN_VERSION_MIN(7, 4, 1)
// -------------- allocate reserve space for backward--------------
if (reserve_space != nullptr) {
size_t reserve_space_size = 0;
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*activationDesc=*/nullptr,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size));
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
}
#endif
} else {
#if CUDNN_VERSION_MIN(7, 4, 1)
size_t workspace_size = 0;
Expand Down

0 comments on commit 26a43fc

Please sign in to comment.