From 98db01aef2af43ef4c51e6205d8877384e2fd8c6 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Thu, 19 Jul 2018 19:42:54 +0000 Subject: [PATCH] Change CUDA StreamExecutor implementation to use new member function in AlgorithmDesc --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index c12eb1c61f003e..9213c3f72ca7d0 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1980,7 +1980,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, + dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, @@ -1998,6 +1998,7 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + algorithm_desc.set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2042,6 +2043,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( /*dxDesc=*/input_nd.handle(), /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + algorithm_desc.set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2066,7 +2068,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace( port::StatusOr> AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, + dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, const CudnnTensorDescriptor& output_nd, @@ -2086,6 +2088,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( /*gradDesc=*/filter.handle(), /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + algorithm_desc.set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {