@@ -1980,7 +1980,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
19801980
19811981port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace (
19821982 Stream* stream, const CudnnHandle& cudnn,
1983- const dnn::AlgorithmDesc& algorithm_desc,
1983+ dnn::AlgorithmDesc& algorithm_desc,
19841984 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
19851985 const CudnnConvolutionDescriptor& conv,
19861986 const CudnnTensorDescriptor& output_nd,
@@ -1998,6 +1998,7 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
19981998 /* wDesc=*/ filter.handle (), /* convDesc=*/ conv.handle (),
19991999 /* yDesc=*/ output_nd.handle (), /* algo=*/ ToConvForwardAlgo (algorithm_desc),
20002000 /* sizeInBytes=*/ &size_in_bytes));
2001+ algorithm_desc.set_scratch_size (size_in_bytes);
20012002 int64 size_in_bytes_int64 = size_in_bytes;
20022003
20032004 if (TF_PREDICT_FALSE (size_in_bytes_int64 < 0 )) {
@@ -2042,6 +2043,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
20422043 /* dxDesc=*/ input_nd.handle (),
20432044 /* algo=*/ ToConvBackwardDataAlgo (algorithm_desc),
20442045 /* sizeInBytes=*/ &size_in_bytes));
2046+ algorithm_desc.set_scratch_size (size_in_bytes);
20452047 int64 size_in_bytes_int64 = size_in_bytes;
20462048
20472049 if (TF_PREDICT_FALSE (size_in_bytes_int64 < 0 )) {
@@ -2066,7 +2068,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
20662068port::StatusOr<DeviceMemory<uint8>>
20672069AllocateCudnnConvolutionBackwardFilterWorkspace (
20682070 Stream* stream, const CudnnHandle& cudnn,
2069- const dnn::AlgorithmDesc& algorithm_desc,
2071+ dnn::AlgorithmDesc& algorithm_desc,
20702072 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
20712073 const CudnnConvolutionDescriptor& conv,
20722074 const CudnnTensorDescriptor& output_nd,
@@ -2086,6 +2088,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
20862088 /* gradDesc=*/ filter.handle (),
20872089 /* algo=*/ ToConvBackwardFilterAlgo (algorithm_desc),
20882090 /* sizeInBytes=*/ &size_in_bytes));
2091+ algorithm_desc.set_scratch_size (size_in_bytes);
20892092 int64 size_in_bytes_int64 = size_in_bytes;
20902093
20912094 if (TF_PREDICT_FALSE (size_in_bytes_int64 < 0 )) {
0 commit comments