Skip to content

Commit 87c1245

Browse files
whchungsunway513
authored and committed
Merge pull request #72 from ROCmSoftwarePlatform/refactor-algorithmconfig-profileresult
Change CUDA StreamExecutor to use new member function in AlgorithmDesc
1 parent f40e780 commit 87c1245

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensorflow/stream_executor/cuda/cuda_dnn.cc

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1980,7 +1980,7 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
19801980

19811981
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
19821982
Stream* stream, const CudnnHandle& cudnn,
1983-
const dnn::AlgorithmDesc& algorithm_desc,
1983+
dnn::AlgorithmDesc& algorithm_desc,
19841984
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
19851985
const CudnnConvolutionDescriptor& conv,
19861986
const CudnnTensorDescriptor& output_nd,
@@ -1998,6 +1998,7 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
19981998
/*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
19991999
/*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
20002000
/*sizeInBytes=*/&size_in_bytes));
2001+
algorithm_desc.set_scratch_size(size_in_bytes);
20012002
int64 size_in_bytes_int64 = size_in_bytes;
20022003

20032004
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2042,6 +2043,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
20422043
/*dxDesc=*/input_nd.handle(),
20432044
/*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
20442045
/*sizeInBytes=*/&size_in_bytes));
2046+
algorithm_desc.set_scratch_size(size_in_bytes);
20452047
int64 size_in_bytes_int64 = size_in_bytes;
20462048

20472049
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2066,7 +2068,7 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
20662068
port::StatusOr<DeviceMemory<uint8>>
20672069
AllocateCudnnConvolutionBackwardFilterWorkspace(
20682070
Stream* stream, const CudnnHandle& cudnn,
2069-
const dnn::AlgorithmDesc& algorithm_desc,
2071+
dnn::AlgorithmDesc& algorithm_desc,
20702072
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
20712073
const CudnnConvolutionDescriptor& conv,
20722074
const CudnnTensorDescriptor& output_nd,
@@ -2086,6 +2088,7 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
20862088
/*gradDesc=*/filter.handle(),
20872089
/*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
20882090
/*sizeInBytes=*/&size_in_bytes));
2091+
algorithm_desc.set_scratch_size(size_in_bytes);
20892092
int64 size_in_bytes_int64 = size_in_bytes;
20902093

20912094
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {

0 commit comments

Comments
 (0)