5 changes: 5 additions & 0 deletions tensorflow/core/framework/tensor.cc
@@ -428,7 +428,12 @@ struct ProtoHelper<Eigen::half> {
static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
proto->mutable_half_val()->Reserve(n);
for (size_t i = 0; i < n; ++i) {
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
// Implementation of Eigen::half is different when using HIP FP16 on the GPU
proto->mutable_half_val()->AddAlreadyReserved(__half_as_ushort(data[i].x));
#else
proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
#endif
}
}
};
8 changes: 8 additions & 0 deletions tensorflow/core/lib/random/random_distributions.h
@@ -679,7 +679,15 @@ PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16 x) {
const uint16 val = (exp << 10) | man;

Eigen::half result;

// The underlying Eigen::half implementation is different on the CPU vs the GPU,
// so we need to handle this assignment differently depending on what we are compiling for.
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
result.x = __ushort_as_half(val);
#else
result.x = val;
#endif

return result - Eigen::half(1.0);
}

13 changes: 13 additions & 0 deletions tensorflow/core/platform/platform.h
@@ -63,4 +63,17 @@ limitations under the License.
#define PLATFORM_IS_X86
#endif

// Some of the TensorFlow code reaches into the implementation of Eigen::half.
// The implementation of Eigen::half is platform dependent in ROCm,
// i.e. it is different for CPU vs GPU.
// Therefore the TensorFlow code that reaches into the Eigen::half implementation
// needs to be different based on the platform we are compiling it for.
// The TENSORFLOW_USE_ROCM_HIP_FP16 macro is defined for that purpose.
#if defined(TENSORFLOW_USE_ROCM)
#if defined(__HIPCC__) && defined(__HIP_DEVICE_COMPILE__)
#define TENSORFLOW_USE_ROCM_HIP_FP16
#endif
#endif


#endif // TENSORFLOW_PLATFORM_PLATFORM_DEFINE_H_
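
The hunks that follow all use this macro to guard direct access to the raw bits of an Eigen::half. A minimal sketch of the pattern, assuming the platform.h definition above and TensorFlow's usual Eigen include path (the helper name HalfBits is hypothetical, not part of this PR):

#include "tensorflow/core/platform/platform.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Return the raw 16-bit pattern of an Eigen::half regardless of whether this
// translation unit is compiled for the host or for the HIP device.
inline unsigned short HalfBits(const Eigen::half& h) {
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
  // On the HIP device, Eigen::half::x is a native __half, so its bits must be
  // reinterpreted via the HIP intrinsic.
  return __half_as_ushort(h.x);
#else
  // On the host (and in CUDA builds), .x already holds the 16-bit pattern.
  return h.x;
#endif
}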
10 changes: 10 additions & 0 deletions tensorflow/core/util/gpu_device_functions.h
@@ -500,7 +500,12 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short high = static_cast<unsigned short>(arg >> 16);
Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
// Implementation of Eigen::half is different when using HIP FP16 on the GPU
return (static_cast<uint32>(__half_as_ushort(acc.x)) << 16) | (arg & 0xffff);
#else
return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
#endif
});
return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
} else {
@@ -509,7 +514,12 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short low = static_cast<unsigned short>(arg & 0xffff);
Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
// Implementation of Eigen::half is different when using HIP FP16 on the GPU
return (arg & 0xffff0000) | static_cast<uint32>(__half_as_ushort(acc.x));
#else
return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
#endif
});
return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
}
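
The two hunks above splice the re-accumulated Eigen::half back into one half of an aligned 32-bit word before retrying the CAS; only the packing of the 16 bits changes under TENSORFLOW_USE_ROCM_HIP_FP16. A rough host-side analogue of that read-modify-CAS loop, using std::atomic purely for illustration (AtomicUpdateHigh16 is not part of the PR; the real code uses the GPU's 32-bit atomicCAS):

#include <atomic>
#include <cstdint>

// Atomically apply `accumulate` to the upper 16 bits of a 32-bit word,
// mirroring the structure of GpuAtomicCasHelper for the aligned case.
uint16_t AtomicUpdateHigh16(std::atomic<uint32_t>* word,
                            uint16_t (*accumulate)(uint16_t)) {
  uint32_t expected = word->load();
  while (true) {
    uint16_t high = static_cast<uint16_t>(expected >> 16);
    uint32_t desired =
        (static_cast<uint32_t>(accumulate(high)) << 16) | (expected & 0xffffu);
    // On failure, compare_exchange_weak refreshes `expected` with the current
    // value, so the next iteration recomputes from up-to-date data.
    if (word->compare_exchange_weak(expected, desired)) {
      return static_cast<uint16_t>(desired >> 16);
    }
  }
}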
13 changes: 12 additions & 1 deletion tensorflow/core/util/port.cc
@@ -29,11 +29,22 @@ bool IsGoogleCudaEnabled() {
#endif
}

bool CudaSupportsHalfMatMulAndConv() {
bool IsBuiltWithROCm() {
#if TENSORFLOW_USE_ROCM
return true;
#else
return false;
#endif
}


bool GpuSupportsHalfMatMulAndConv() {
#if GOOGLE_CUDA
// NOTE: We check compile-time and not runtime, since the check for
// whether we include the fp16 kernels or not is compile-time.
return CUDA_VERSION >= 7050;
#elif TENSORFLOW_USE_ROCM
return true;
#else
return false;
#endif
16 changes: 13 additions & 3 deletions tensorflow/core/util/port.h
@@ -21,9 +21,19 @@ namespace tensorflow {
// Returns true if GOOGLE_CUDA is defined.
bool IsGoogleCudaEnabled();

// Returns true if GOOGLE_CUDA is defined, and the given CUDA version supports
// half-precision matrix multiplications and convolution operations.
bool CudaSupportsHalfMatMulAndConv();
// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm)
bool IsBuiltWithROCm();

// Returns true if either
//
// GOOGLE_CUDA is defined, and the given CUDA version supports
// half-precision matrix multiplications and convolution operations.
//
// OR
//
// TENSORFLOW_USE_ROCM is defined
//
bool GpuSupportsHalfMatMulAndConv();

// Returns true if INTEL_MKL is defined
bool IsMklEnabled();
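
For reference, a hedged sketch of how a C++ caller might consume this declaration; GpuTestDtypes is hypothetical and not part of the PR (the Python kernel tests below do the equivalent through test_util.GpuSupportsHalfMatMulAndConv):

#include <vector>
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/util/port.h"

// Choose the dtypes a GPU test should exercise: float32 always, float16 only
// when the build's GPU stack compiled the half-precision matmul/conv kernels.
std::vector<tensorflow::DataType> GpuTestDtypes() {
  std::vector<tensorflow::DataType> dtypes = {tensorflow::DT_FLOAT};
  if (tensorflow::GpuSupportsHalfMatMulAndConv()) {
    dtypes.push_back(tensorflow::DT_HALF);
  }
  return dtypes;
}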
5 changes: 5 additions & 0 deletions tensorflow/core/util/saved_tensor_slice_util.h
@@ -171,7 +171,12 @@ inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
val->Resize(n, 0);
for (size_t i = 0; i < n; ++i) {
#if defined(TENSORFLOW_USE_ROCM_HIP_FP16)
// Implementation of Eigen::half is different when using HIP FP16 on the GPU
val->Set(i, __half_as_ushort(data[i].x));
#else
val->Set(i, data[i].x);
#endif
}
}

2 changes: 2 additions & 0 deletions tensorflow/docs_src/api_guides/python/test.md
@@ -37,6 +37,8 @@ depending on the python version.
* @{tf.test.assert_equal_graph_def}
* @{tf.test.get_temp_dir}
* @{tf.test.is_built_with_cuda}
* @{tf.test.is_built_with_rocm}
* @{tf.test.is_built_with_gpu_support}
* @{tf.test.is_gpu_available}
* @{tf.test.gpu_device_name}

23 changes: 21 additions & 2 deletions tensorflow/python/framework/test_util.py
@@ -204,8 +204,12 @@ def IsGoogleCudaEnabled():
return pywrap_tensorflow.IsGoogleCudaEnabled()


def CudaSupportsHalfMatMulAndConv():
return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
def IsBuiltWithROCm():
return pywrap_tensorflow.IsBuiltWithROCm()


def GpuSupportsHalfMatMulAndConv():
return pywrap_tensorflow.GpuSupportsHalfMatMulAndConv()


def IsMklEnabled():
@@ -740,6 +744,21 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
CUDA compute capability required, or None if no requirement.

Note that the keyword arg name "cuda_only" is misleading, since this routine
returns True when a GPU device is available irrespective of whether TF was
built with CUDA support or ROCm support. The name is kept unchanged because:

++ Changing the name "cuda_only" to something more generic would break
backward compatibility

++ Adding an equivalent "rocm_only" would require the implementation to check
the build type. This in turn would require doing the same for CUDA, and thus
potentially break backward compatibility

++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
but would require most (if not all) callers to update their calls to use
"cuda_or_rocm_only" instead of "cuda_only"

Returns:
True iff a gpu device of the requested kind is available.
"""
8 changes: 8 additions & 0 deletions tensorflow/python/framework/test_util_test.py
@@ -85,6 +85,14 @@ def testIsGoogleCudaEnabled(self):
else:
print("GoogleCuda is disabled")

def testIsBuiltWithROCm(self):
# The test doesn't assert anything. It ensures the py wrapper
# function is generated correctly.
if test_util.IsBuiltWithROCm():
print("Tensorflow build has ROCm support")
else:
print("Tensorflow build does not have ROCm support")

def testIsMklEnabled(self):
# This test doesn't assert anything.
# It ensures the py wrapper function is generated correctly.
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -51,7 +51,7 @@ class Conv3DTest(test.TestCase):

def _DtypesToTest(self, use_gpu):
if use_gpu:
if not test_util.CudaSupportsHalfMatMulAndConv():
if not test_util.GpuSupportsHalfMatMulAndConv():
return [dtypes.float32]
else:
# It is important that float32 comes before float16 here,
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/conv_ops_test.py
@@ -158,7 +158,7 @@ def GetTestConfigs():
class Conv2DTest(test.TestCase):

def _DtypesToTest(self, use_gpu):
if use_gpu and not test_util.CudaSupportsHalfMatMulAndConv():
if use_gpu and not test_util.GpuSupportsHalfMatMulAndConv():
return [dtypes.float32, dtypes.float64]
else:
# It is important that float32 comes before float16 here,
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -70,7 +70,7 @@ def _testTypes(self, vals):
var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
self.assertAllEqual(x - y, var_value)
self.assertAllEqual(x - y, op_value)
if test.is_built_with_cuda() and dtype in [np.float32, np.float64]:
if test.is_built_with_gpu_support() and dtype in [np.float32, np.float64]:
var_value, op_value = self._initAssignFetch(x, y, use_gpu=True)
self.assertAllEqual(y, var_value)
self.assertAllEqual(y, op_value)
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/matmul_op_test.py
@@ -62,7 +62,7 @@ def Test(self):

use_gpu = True
if a_np_.dtype is np.float16 and (
not test_util.CudaSupportsHalfMatMulAndConv()):
not test_util.GpuSupportsHalfMatMulAndConv()):
use_gpu = False
print("Built without fp16 matmul support for Cuda, running test on CPU.")

2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -197,7 +197,7 @@ def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float64, expected, use_gpu, v2)

if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv():
if not use_gpu or test_util.GpuSupportsHalfMatMulAndConv():
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float16, expected, use_gpu, v2)

4 changes: 2 additions & 2 deletions tensorflow/python/kernel_tests/softmax_op_test.py
@@ -127,7 +127,7 @@ def testFloat(self):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))

@unittest.skipUnless(test.is_built_with_cuda(),
@unittest.skipUnless(test.is_built_with_gpu_support(),
"Test only applicable when running on GPUs")
def testFloatGPU(self):
if test.is_gpu_available(cuda_only=True):
@@ -142,7 +142,7 @@ def testHalf(self):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16))

@unittest.skipUnless(test.is_built_with_cuda(),
@unittest.skipUnless(test.is_built_with_gpu_support(),
"Test only applicable when running on GPUs")
def testHalfGPU(self):
if test.is_gpu_available(cuda_only=True):
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -85,7 +85,7 @@ def testInvalidLabel(self):
[1., 2., 3., 4.]]
labels = [4, 3, 0, -1]

if test.is_built_with_cuda() and test.is_gpu_available():
if test.is_built_with_gpu_support() and test.is_gpu_available():
with self.test_session(use_gpu=True) as sess:
loss, backprop = (
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
16 changes: 16 additions & 0 deletions tensorflow/python/platform/test.py
@@ -26,6 +26,8 @@
@@assert_equal_graph_def
@@get_temp_dir
@@is_built_with_cuda
@@is_built_with_rocm
@@is_built_with_gpu_support
@@is_gpu_available
@@gpu_device_name
@@compute_gradient
@@ -107,3 +109,17 @@ def test_src_dir_path(relative_path):
def is_built_with_cuda():
"""Returns whether TensorFlow was built with CUDA (GPU) support."""
return _test_util.IsGoogleCudaEnabled()


@tf_export('test.is_built_with_rocm')
def is_built_with_rocm():
"""Returns whether TensorFlow was built with ROCm (GPU) support."""
return _test_util.IsBuiltWithROCm()


@tf_export('test.is_built_with_gpu_support')
def is_built_with_gpu_support():
"""Returns whether TensorFlow was built with GPU (either CUDA or ROCm) support.
"""
return is_built_with_cuda() or is_built_with_rocm()

3 changes: 2 additions & 1 deletion tensorflow/python/util/port.i
@@ -22,7 +22,8 @@ limitations under the License.
%ignoreall
%unignore tensorflow;
%unignore tensorflow::IsGoogleCudaEnabled;
%unignore tensorflow::CudaSupportsHalfMatMulAndConv;
%unignore tensorflow::IsBuiltWithROCm;
%unignore tensorflow::GpuSupportsHalfMatMulAndConv;
%unignore tensorflow::IsMklEnabled;
%include "tensorflow/core/util/port.h"
%unignoreall