diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index d5a45c73c37bf0..171973c91c6108 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -428,7 +428,12 @@ struct ProtoHelper<Eigen::half> {
   static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
     proto->mutable_half_val()->Reserve(n);
     for (size_t i = 0; i < n; ++i) {
+#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
+      // Implementation of Eigen::half is different when using HIP FP16 on the GPU
+      proto->mutable_half_val()->AddAlreadyReserved(__half_as_ushort(data[i].x));
+#else
       proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
+#endif
     }
   }
 };
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index e963511f5cfe64..3b8374a8f837d8 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -679,7 +679,15 @@ PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x) {
   const uint16 val = (exp << 10) | man;
   Eigen::half result;
+
+  // The underlying Eigen::half implementation is different on the CPU and the GPU,
+  // so this assignment needs to be handled differently depending on what we are compiling for.
+#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
+  result.x = __ushort_as_half(val);
+#else
   result.x = val;
+#endif
+
   return result - Eigen::half(1.0);
 }
diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h
index 0481b3687137c8..b7b6d3d804d416 100644
--- a/tensorflow/core/platform/platform.h
+++ b/tensorflow/core/platform/platform.h
@@ -63,4 +63,17 @@ limitations under the License.
 #define PLATFORM_IS_X86
 #endif
 
+// Some of the TensorFlow code reaches into the implementation of Eigen::half.
+// The implementation of Eigen::half is platform dependent in ROCm,
+// i.e. it is different for CPU vs GPU.
+// Therefore the TensorFlow code that reaches into the Eigen::half implementation
+// needs to be different based on the platform we are compiling it for.
+// Creating a TENSORFLOW_USE_ROCM_HIP_FP16 macro for that purpose +#if defined(TENSORFLOW_USE_ROCM) + #if defined(__HIPCC__) && defined(__HIP_DEVICE_COMPILE__) + #define TENSORFLOW_USE_ROCM_HIP_FP16 + #endif +#endif + + #endif // TENSORFLOW_PLATFORM_PLATFORM_DEFINE_H_ diff --git a/tensorflow/core/util/gpu_device_functions.h b/tensorflow/core/util/gpu_device_functions.h index 035b8514b1835c..fece8c0a1c3ef7 100644 --- a/tensorflow/core/util/gpu_device_functions.h +++ b/tensorflow/core/util/gpu_device_functions.h @@ -500,7 +500,12 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) { uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) { unsigned short high = static_cast(arg >> 16); Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high)); +#if defined(TENSORFLOW_USE_ROCM_HIP_FP16) + // Implementation of Eigen::half is different when using HIP FP16 on the GPU + return (static_cast(__half_as_ushort(acc.x)) << 16) | (arg & 0xffff); +#else return (static_cast(acc.x) << 16) | (arg & 0xffff); +#endif }); return half_impl::raw_uint16_to_half(static_cast(result >> 16)); } else { @@ -509,7 +514,12 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) { uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) { unsigned short low = static_cast(arg & 0xffff); Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low)); +#if defined(TENSORFLOW_USE_ROCM_HIP_FP16) + // Implementation of Eigen::half is different when using HIP FP16 on the GPU + return (arg & 0xffff0000) | static_cast(__half_as_ushort(acc.x)); +#else return (arg & 0xffff0000) | static_cast(acc.x); +#endif }); return half_impl::raw_uint16_to_half(static_cast(result & 0xffff)); } diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc index 490c584dc5cf67..200c88f099176e 100644 --- a/tensorflow/core/util/port.cc +++ b/tensorflow/core/util/port.cc @@ -29,11 +29,22 @@ bool IsGoogleCudaEnabled() { #endif } -bool CudaSupportsHalfMatMulAndConv() { +bool IsBuiltWithROCm() { +#if TENSORFLOW_USE_ROCM + return true; +#else + return false; +#endif +} + + +bool GpuSupportsHalfMatMulAndConv() { #if GOOGLE_CUDA // NOTE: We check compile-time and not runtime, since the check for // whether we include the fp16 kernels or not is compile-time. return CUDA_VERSION >= 7050; +#elif TENSORFLOW_USE_ROCM + return true; #else return false; #endif diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h index 981def9d22a029..56f7070e3ae149 100644 --- a/tensorflow/core/util/port.h +++ b/tensorflow/core/util/port.h @@ -21,9 +21,19 @@ namespace tensorflow { // Returns true if GOOGLE_CUDA is defined. bool IsGoogleCudaEnabled(); -// Returns true if GOOGLE_CUDA is defined, and the given CUDA version supports -// half-precision matrix multiplications and convolution operations. -bool CudaSupportsHalfMatMulAndConv(); +// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm) +bool IsBuiltWithROCm(); + +// Returns true if either +// +// GOOGLE_CUDA is defined, and the given CUDA version supports +// half-precision matrix multiplications and convolution operations. 
+//
+// OR
+//
+// TENSORFLOW_USE_ROCM is defined
+//
+bool GpuSupportsHalfMatMulAndConv();
 
 // Returns true if INTEL_MKL is defined
 bool IsMklEnabled();
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index ee43945a393a8e..9e6778da19881f 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -171,7 +171,12 @@ inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
   typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
   val->Resize(n, 0);
   for (size_t i = 0; i < n; ++i) {
+#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
+    // Implementation of Eigen::half is different when using HIP FP16 on the GPU
+    val->Set(i, __half_as_ushort(data[i].x));
+#else
     val->Set(i, data[i].x);
+#endif
   }
 }
diff --git a/tensorflow/docs_src/api_guides/python/test.md b/tensorflow/docs_src/api_guides/python/test.md
index 5dc88124e7e1c2..ed30727e7028df 100644
--- a/tensorflow/docs_src/api_guides/python/test.md
+++ b/tensorflow/docs_src/api_guides/python/test.md
@@ -37,6 +37,8 @@ depending on the python version.
 * @{tf.test.assert_equal_graph_def}
 * @{tf.test.get_temp_dir}
 * @{tf.test.is_built_with_cuda}
+* @{tf.test.is_built_with_rocm}
+* @{tf.test.is_built_with_gpu_support}
 * @{tf.test.is_gpu_available}
 * @{tf.test.gpu_device_name}
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index dc56d88066cbe6..2265fe1dc4fbbf 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -204,8 +204,12 @@ def IsGoogleCudaEnabled():
   return pywrap_tensorflow.IsGoogleCudaEnabled()
 
 
-def CudaSupportsHalfMatMulAndConv():
-  return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
+def IsBuiltWithROCm():
+  return pywrap_tensorflow.IsBuiltWithROCm()
+
+
+def GpuSupportsHalfMatMulAndConv():
+  return pywrap_tensorflow.GpuSupportsHalfMatMulAndConv()
 
 
 def IsMklEnabled():
@@ -740,6 +744,21 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
     min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
       CUDA compute capability required, or None if no requirement.
 
+  Note that the keyword arg name "cuda_only" is misleading: the routine returns
+  True when a GPU device is available, irrespective of whether TF was built
+  with CUDA support or ROCm support. The name is nevertheless kept as is because
+
+  ++ Changing the name "cuda_only" to something more generic would break
+     backward compatibility
+
+  ++ Adding an equivalent "rocm_only" would require the implementation to check
+     the build type. This in turn would require doing the same for CUDA and thus
+     potentially break backward compatibility
+
+  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility, but
+     would require most (if not all) callers to update the call to use
+     "cuda_or_rocm_only" instead of "cuda_only"
+
   Returns:
     True iff a gpu device of the requested kind is available.
   """
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 8d492256aac17d..479d068c431e17 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -85,6 +85,14 @@ def testIsGoogleCudaEnabled(self):
     else:
       print("GoogleCuda is disabled")
 
+  def testIsBuiltWithROCm(self):
+    # This test doesn't assert anything. It ensures the py wrapper
+    # function is generated correctly.
+ if test_util.IsBuiltWithROCm(): + print("Tensorflow build has ROCm support") + else: + print("Tensorflow build does not have ROCm support") + def testIsMklEnabled(self): # This test doesn't assert anything. # It ensures the py wrapper function is generated correctly. diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index 0b531125f36c6d..70f24b13cbaed0 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -51,7 +51,7 @@ class Conv3DTest(test.TestCase): def _DtypesToTest(self, use_gpu): if use_gpu: - if not test_util.CudaSupportsHalfMatMulAndConv(): + if not test_util.GpuSupportsHalfMatMulAndConv(): return [dtypes.float32] else: # It is important that float32 comes before float16 here, diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index a291bef0ad6f16..254a0007d6fc3a 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -158,7 +158,7 @@ def GetTestConfigs(): class Conv2DTest(test.TestCase): def _DtypesToTest(self, use_gpu): - if use_gpu and not test_util.CudaSupportsHalfMatMulAndConv(): + if use_gpu and not test_util.GpuSupportsHalfMatMulAndConv(): return [dtypes.float32, dtypes.float64] else: # It is important that float32 comes before float16 here, diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py index 4dda9f093b5329..efd40df50a3a94 100644 --- a/tensorflow/python/kernel_tests/dense_update_ops_test.py +++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py @@ -70,7 +70,7 @@ def _testTypes(self, vals): var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False) self.assertAllEqual(x - y, var_value) self.assertAllEqual(x - y, op_value) - if test.is_built_with_cuda() and dtype in [np.float32, np.float64]: + if test.is_built_with_gpu_support() and dtype in [np.float32, np.float64]: var_value, op_value = self._initAssignFetch(x, y, use_gpu=True) self.assertAllEqual(y, var_value) self.assertAllEqual(y, op_value) diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py index b167278984cf45..104ec8afb2e4c4 100644 --- a/tensorflow/python/kernel_tests/matmul_op_test.py +++ b/tensorflow/python/kernel_tests/matmul_op_test.py @@ -62,7 +62,7 @@ def Test(self): use_gpu = True if a_np_.dtype is np.float16 and ( - not test_util.CudaSupportsHalfMatMulAndConv()): + not test_util.GpuSupportsHalfMatMulAndConv()): use_gpu = False print("Built without fp16 matmul support for Cuda, running test on CPU.") diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index a0c372db7d0a4e..7b2bddcd5af5d2 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -197,7 +197,7 @@ def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float64, expected, use_gpu, v2) - if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv(): + if not use_gpu or test_util.GpuSupportsHalfMatMulAndConv(): self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float16, expected, use_gpu, v2) diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py 
b/tensorflow/python/kernel_tests/softmax_op_test.py index dc4d4dbeabf3c5..8c36bde1637af7 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -127,7 +127,7 @@ def testFloat(self): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32)) - @unittest.skipUnless(test.is_built_with_cuda(), + @unittest.skipUnless(test.is_built_with_gpu_support(), "Test only applicable when running on GPUs") def testFloatGPU(self): if test.is_gpu_available(cuda_only=True): @@ -142,7 +142,7 @@ def testHalf(self): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16)) - @unittest.skipUnless(test.is_built_with_cuda(), + @unittest.skipUnless(test.is_built_with_gpu_support(), "Test only applicable when running on GPUs") def testHalfGPU(self): if test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py index a841fe83a7f585..22e0162fa4fb00 100644 --- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py @@ -85,7 +85,7 @@ def testInvalidLabel(self): [1., 2., 3., 4.]] labels = [4, 3, 0, -1] - if test.is_built_with_cuda() and test.is_gpu_available(): + if test.is_built_with_gpu_support() and test.is_gpu_available(): with self.test_session(use_gpu=True) as sess: loss, backprop = ( gen_nn_ops.sparse_softmax_cross_entropy_with_logits( diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 0a0fe68be569a7..e5ea8332349163 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -26,6 +26,8 @@ @@assert_equal_graph_def @@get_temp_dir @@is_built_with_cuda +@@is_built_with_rocm +@@is_built_with_gpu_support @@is_gpu_available @@gpu_device_name @@compute_gradient @@ -107,3 +109,17 @@ def test_src_dir_path(relative_path): def is_built_with_cuda(): """Returns whether TensorFlow was built with CUDA (GPU) support.""" return _test_util.IsGoogleCudaEnabled() + + +@tf_export('test.is_built_with_rocm') +def is_built_with_rocm(): + """Returns whether TensorFlow was built with ROCm (GPU) support.""" + return _test_util.IsBuiltWithROCm() + + +@tf_export('test.is_built_with_gpu_support') +def is_built_with_gpu_support(): + """Returns whether TensorFlow was built with GPU (either CUDA or ROCm) support. + """ + return is_built_with_cuda() or is_built_with_rocm() + diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i index 2f730732bee373..64681a92798937 100644 --- a/tensorflow/python/util/port.i +++ b/tensorflow/python/util/port.i @@ -22,7 +22,8 @@ limitations under the License. 
 %ignoreall
 %unignore tensorflow;
 %unignore tensorflow::IsGoogleCudaEnabled;
-%unignore tensorflow::CudaSupportsHalfMatMulAndConv;
+%unignore tensorflow::IsBuiltWithROCm;
+%unignore tensorflow::GpuSupportsHalfMatMulAndConv;
 %unignore tensorflow::IsMklEnabled;
 %include "tensorflow/core/util/port.h"
 %unignoreall
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc
index 3044c7ae0505bc..4a28a3f27168ba 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.cc
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc
@@ -1197,13 +1197,13 @@ bool MIOpenSupport::DoRnnBackwardImpl(
     const DeviceMemory& output_h_data,
     const MIOpenRnnStateTensorDescriptor& output_c_desc,
     const DeviceMemory& output_c_data,
-    const DeviceMemory& output_backprop_data,
-    const DeviceMemory& output_h_backprop_data,
-    const DeviceMemory& output_c_backprop_data,
-    DeviceMemory* input_backprop_data,
-    DeviceMemory* input_h_backprop_data,
-    DeviceMemory* input_c_backprop_data,
-    DeviceMemory* params_backprop_data,
+    const DeviceMemory& output_backprop_data,
+    const DeviceMemory& output_h_backprop_data,
+    const DeviceMemory& output_c_backprop_data,
+    DeviceMemory* input_backprop_data,
+    DeviceMemory* input_h_backprop_data,
+    DeviceMemory* input_c_backprop_data,
+    DeviceMemory* params_backprop_data,
     DeviceMemory* reserve_space_data,
     ScratchAllocator* workspace_allocator) {
@@ -1475,8 +1475,29 @@ bool MIOpenSupport::DoRnnForward(
     ScratchAllocator* reserve_space_allocator,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  LOG(ERROR) << "miopen does not support half type yet";
-  return false;
+
+  // ROCM TODO: output_profile_result is ignored for now
+
+  const MIOpenRnnDescriptor& miopen_rnn_desc =
+      static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
+  const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
+      static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
+  const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
+      static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
+
+  return DoRnnForwardImpl(
+      stream, miopen_rnn_desc, miopen_input_desc, input_data, miopen_input_h_desc,
+      input_h_data, miopen_input_c_desc, input_c_data, params, miopen_output_desc,
+      output_data, miopen_output_h_desc, output_h_data, miopen_output_c_desc,
+      output_c_data, is_training, reserve_space_allocator, workspace_allocator);
 }
 
 bool MIOpenSupport::DoRnnForward(
@@ -1568,8 +1589,32 @@ bool MIOpenSupport::DoRnnBackward(
     DeviceMemory* reserve_space_data,
     ScratchAllocator* workspace_allocator,
     dnn::ProfileResult* output_profile_result) {
-  LOG(ERROR) << "miopen does not support half type RNN bwd yet";
-  return false;
+
+  // ROCM TODO: output_profile_result is ignored for now
+
+  const MIOpenRnnDescriptor& miopen_rnn_desc =
+      static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
+  const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
+      static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
+  const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
+      static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
+  const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
static_cast(output_h_desc); + const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc = + static_cast(output_c_desc); + + return DoRnnBackwardImpl( + stream, miopen_rnn_desc, miopen_input_desc, input_data, miopen_input_h_desc, + input_h_data, miopen_input_c_desc, input_c_data, params, miopen_output_desc, + output_data, miopen_output_h_desc, output_h_data, miopen_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator); } bool MIOpenSupport::DoRnnBackward( @@ -1873,8 +1918,11 @@ bool MIOpenSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - LOG(ERROR) << "BN with x as half type not implemented yet"; - return false; + return DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kHalf, x, scale, offset, estimated_mean, + estimated_variance, x_desc, scale_offset_desc, epsilon, y, batch_mean, + batch_var, saved_mean, saved_inv_var, is_training, + std::move(var_to_inv_var), std::move(inv_var_to_var)); } bool MIOpenSupport::DoBatchNormalizationForward( @@ -1889,24 +1937,24 @@ bool MIOpenSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - return DoBatchNormalizationForwardImpl( + return DoBatchNormalizationForwardImpl( stream, dnn::DataType::kFloat, x, scale, offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var, is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)); } -template + template bool MIOpenSupport::DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType data_type, const DeviceMemory& x, - const DeviceMemory& scale, const DeviceMemory& offset, - const DeviceMemory& estimated_mean, - const DeviceMemory& estimated_variance, + const DeviceMemory& scale, const DeviceMemory& offset, + const DeviceMemory& estimated_mean, + const DeviceMemory& estimated_variance, const dnn::BatchDescriptor& x_desc, const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, - DeviceMemory* saved_mean, DeviceMemory* saved_inv_var, - bool is_training, std::function&()> var_to_inv_var, + DeviceMemory* y, DeviceMemory* batch_mean, DeviceMemory* batch_var, + DeviceMemory* saved_mean, DeviceMemory* saved_inv_var, + bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { mutex_lock lock{dnn_handle_mutex_}; auto status = wrap::miopenSetStream(parent_, ToHandle(dnn_handle_), @@ -1961,8 +2009,9 @@ bool MIOpenSupport::DoBatchNormalizationBackward( DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - LOG(ERROR) << "BN with y_backprop with half type not implemented yet"; - return false; + return DoBatchNormalizationBackwardImpl( + stream, miopenHalf, y_backprop, x, scale, mean, inv_var, x_desc, + scale_offset_desc, epsilon, x_backprop, scale_backprop, offset_backprop); } bool MIOpenSupport::DoBatchNormalizationBackward( @@ -1973,20 +2022,20 @@ bool MIOpenSupport::DoBatchNormalizationBackward( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - return DoBatchNormalizationBackwardImpl( + return 
DoBatchNormalizationBackwardImpl( stream, miopenFloat, y_backprop, x, scale, mean, variance, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, offset_backprop); } -template + template bool MIOpenSupport::DoBatchNormalizationBackwardImpl( Stream* stream, int miopen_type, const DeviceMemory& y_backprop, - const DeviceMemory& x, const DeviceMemory& scale, - const DeviceMemory& mean, const DeviceMemory& variance, + const DeviceMemory& x, const DeviceMemory& scale, + const DeviceMemory& mean, const DeviceMemory& variance, const dnn::BatchDescriptor& x_desc, const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - DeviceMemory* x_backprop, DeviceMemory* scale_backprop, - DeviceMemory* offset_backprop) { + DeviceMemory* x_backprop, DeviceMemory* scale_backprop, + DeviceMemory* offset_backprop) { mutex_lock lock{dnn_handle_mutex_}; auto status = wrap::miopenSetStream(parent_, ToHandle(dnn_handle_), AsROCMStreamValue(stream)); @@ -2061,8 +2110,11 @@ bool MIOpenSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + return DoConvolveImpl( + stream, miopenHalf, batch_descriptor, input_data, filter_descriptor, + filter_data, convolution_descriptor, + output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result); } bool MIOpenSupport::DoFusedConvolve( @@ -2421,8 +2473,11 @@ bool MIOpenSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + return DoConvolveBackwardDataImpl( + stream, miopenHalf, filter_descriptor, filter_data, + output_descriptor_in, backward_output_data, convolution_descriptor, + input_descriptor, backward_input_data, scratch_allocator, + algorithm_config, output_profile_result); } template @@ -2636,8 +2691,11 @@ bool MIOpenSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + return DoConvolveBackwardFilterImpl( + stream, miopenHalf, input_descriptor, input_data, + output_descriptor_in, backward_output_data, convolution_descriptor, + filter_descriptor, backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result); } template @@ -2700,8 +2758,9 @@ bool MIOpenSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + return DoConvolveBackwardBiasImpl(stream, miopenHalf, input_descriptor, + input_data, bias_descriptor, + backward_bias_data); } bool MIOpenSupport::DoMatMul(Stream* stream, @@ -2980,8 +3039,56 @@ bool MIOpenSupport::DoPoolForward( const dnn::BatchDescriptor& output_dimensions, DeviceMemory* output_data, ScratchAllocator* workspace_allocator) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + + auto status = wrap::miopenSetStream(parent_, ToHandle(dnn_handle_), + AsROCMStreamValue(stream)); + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to set stream for miopen handle: " << ToString(status); + return false; + } + + // Alpha is the scaling factor for 
input. + float alpha = 1.0; + // Beta is the scaling factor for output. + float beta = 0.0; + + ScopedTensorDescriptor src_desc{parent_, input_dimensions, miopenHalf}; + ScopedTensorDescriptor dest_desc{parent_, output_dimensions, + miopenHalf}; + ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; + + DeviceMemory workspace; + size_t workspace_size_in_bytes = 0; + status = wrap::miopenPoolingGetWorkSpaceSize(parent_, dest_desc.handle(), + &workspace_size_in_bytes); + + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to obtain workspace size for pooling on stream: " + << ToString(status); + return false; + } + + // Allocate the workspace. + if (workspace_size_in_bytes > 0) { + assert(workspace_allocator); + auto allocated = + workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "Failed to allocate pooling workspace"; + return false; + } + } + + status = wrap::miopenPoolingForward( + parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, + src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), + output_data->opaque(), true, workspace.opaque(), workspace_size_in_bytes); + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to enqueue forward pooling on stream: " + << ToString(status); + return false; + } + return true; } bool MIOpenSupport::DoPoolBackward( @@ -3106,8 +3213,96 @@ bool MIOpenSupport::DoPoolBackward( const DeviceMemory& input_diff_data, DeviceMemory* output_diff_data, ScratchAllocator* workspace_allocator) { - LOG(ERROR) << "miopen does not support half type yet"; - return false; + + mutex_lock lock{dnn_handle_mutex_}; + auto status = wrap::miopenSetStream(parent_, ToHandle(dnn_handle_), + AsROCMStreamValue(stream)); + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to set stream for miopen handle: " << ToString(status); + return false; + } + + // Alpha is the scaling factor for input. + float alpha = 1.0; + // Beta is the scaling factor for output. + float beta = 0.0; + + ScopedTensorDescriptor src_desc{parent_, input_dimensions, miopenHalf}; + ScopedTensorDescriptor dest_desc{parent_, output_dimensions, + miopenHalf}; + ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; + + DeviceMemory workspace; + size_t workspace_size_in_bytes = 0; + status = wrap::miopenPoolingGetWorkSpaceSize(parent_, dest_desc.handle(), + &workspace_size_in_bytes); + + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to obtain workspace size for backward pooling on stream: " + << ToString(status); + return false; + } + + // Allocate the workspace. + if (workspace_size_in_bytes > 0) { + assert(workspace_allocator); + auto allocated = + workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); + if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "Failed to allocate backward pooling workspace"; + return false; + } + } + + DeviceMemory dest2; // duplicated dest from forward: + int dest2_size = 0; + + // miopen requires the strides and dims to be ordered as BDYX. + std::vector dims64 = + output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX); + + // miopen does not use strides and must have 4D tensor. 
+ std::vector dims(4); + + std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), + &CheckedNarrowing); + + dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float); + + if (dest2_size > 0) { + assert(workspace_allocator); + auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size); + if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "Failed to allocate backward pooling workspace"; + return false; + } + } else { + LOG(ERROR) << "Failed to calcuate tensor size to chain forward and backward pooling"; + } + + status = wrap::miopenPoolingForward( + parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, + src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), + dest2.opaque(), true, workspace.opaque(), workspace_size_in_bytes); + + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to enqueue forward pooling (before backward) on stream: " + << ToString(status); + return false; + } + + status = wrap::miopenPoolingBackward( + parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, + dest_desc.handle(), dest2.opaque(), dest_desc.handle(), + input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta, + src_desc.handle(), output_diff_data->opaque(), workspace.opaque()); + + if (status != miopenStatusSuccess) { + LOG(ERROR) << "failed to enqueue backward pooling on stream: " + << ToString(status); + return false; + } + return true; } bool MIOpenSupport::DoNormalize( diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h index e5e37361f5fd1d..f4c0cb5924fc6c 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/stream_executor/rocm/rocm_dnn.h @@ -661,29 +661,29 @@ class MIOpenSupport : public dnn::DnnSupport { std::unique_ptr>* transform_scratch) EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_); - template + template bool DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType data_type, const DeviceMemory& x, - const DeviceMemory& scale, const DeviceMemory& offset, - const DeviceMemory& estimated_mean, - const DeviceMemory& estimated_variance, + const DeviceMemory& scale, const DeviceMemory& offset, + const DeviceMemory& estimated_mean, + const DeviceMemory& estimated_variance, const dnn::BatchDescriptor& x_desc, const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - DeviceMemory* y, DeviceMemory* batch_mean, - DeviceMemory* batch_var, DeviceMemory* saved_mean, - DeviceMemory* saved_inv_var, bool is_training, - std::function&()> var_to_inv_var, + DeviceMemory* y, DeviceMemory* batch_mean, + DeviceMemory* batch_var, DeviceMemory* saved_mean, + DeviceMemory* saved_inv_var, bool is_training, + std::function&()> var_to_inv_var, std::function inv_var_to_var); - template + template bool DoBatchNormalizationBackwardImpl( Stream* stream, int miopen_type, const DeviceMemory& y_backprop, - const DeviceMemory& x, const DeviceMemory& scale, - const DeviceMemory& mean, const DeviceMemory& variance, + const DeviceMemory& x, const DeviceMemory& scale, + const DeviceMemory& mean, const DeviceMemory& variance, const dnn::BatchDescriptor& x_desc, const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, - DeviceMemory* x_backprop, DeviceMemory* scale_backprop, - DeviceMemory* offset_backprop); + DeviceMemory* x_backprop, DeviceMemory* scale_backprop, + DeviceMemory* offset_backprop); template bool DoConvolveImpl(Stream* stream, @@ -767,13 +767,13 @@ class MIOpenSupport : public dnn::DnnSupport { const 
DeviceMemory& output_h_data, const MIOpenRnnStateTensorDescriptor& output_c_desc, const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator);
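
Usage sketch (hypothetical, for illustration only): the snippet below shows how a kernel test could consume the new test.is_built_with_gpu_support() and test_util.GpuSupportsHalfMatMulAndConv() helpers introduced by this change, mirroring the dtype-selection pattern the patch applies in conv_ops_test.py and matmul_op_test.py. The test class and file name are not part of the patch.

import unittest

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class GpuHalfSupportTest(test.TestCase):
  # Hypothetical test, illustrating the helpers added by this change.

  @unittest.skipUnless(test.is_built_with_gpu_support(),
                       "Test only applicable when built with GPU (CUDA or ROCm) support")
  def testDtypeSelection(self):
    # Mirror the _DtypesToTest pattern from the conv/matmul tests: only add
    # float16 when the GPU build supports half-precision matmul/conv.
    dtypes_to_test = [dtypes.float32]
    if test_util.GpuSupportsHalfMatMulAndConv():
      dtypes_to_test.append(dtypes.float16)
    self.assertIn(dtypes.float32, dtypes_to_test)


if __name__ == "__main__":
  test.main()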