diff --git a/rocm_docs/core_kernels.md b/rocm_docs/core_kernels.md
index 9152059d31d820..c972cbcc576e85 100644
--- a/rocm_docs/core_kernels.md
+++ b/rocm_docs/core_kernels.md
@@ -79,10 +79,10 @@ Kernels under tensorflow/core/kernels
 | O | | bias_op.h
 | O | | bias_op_gpu.cu.cc
 | O | | bias_op_gpu.h
-| X |cub | bincount_op.cc
-| X |cub | bincount_op.h
-| X |cub | bincount_op_gpu.cu.cc
-| X |cub | bincount_op_test.cc
+| O |rocPRIM | bincount_op.cc
+| O |rocPRIM | bincount_op.h
+| O |rocPRIM | bincount_op_gpu.cu.cc
+| O |rocPRIM | bincount_op_test.cc
 | O | | bitcast_op.cc
 | O | | bitcast_op.h
 | O | | boosted_trees
@@ -557,9 +557,9 @@ Kernels under tensorflow/core/kernels
 | O | | hexagon/soc_interface.cc
 | O | | hexagon/soc_interface.h
 | O | | hinge-loss.h
-| X |cub | histogram_op.cc
-| X |cub | histogram_op.h
-| X |cub | histogram_op_gpu.cu.cc
+| O |rocPRIM | histogram_op.cc
+| O |rocPRIM | histogram_op.h
+| O |rocPRIM | histogram_op_gpu.cu.cc
 | O | | i_remote_fused_graph_executor.h
 | O | | i_remote_fused_graph_ops_definitions.cc
 | O | | i_remote_fused_graph_ops_definitions.h
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 380226dd07fe9a..615820833ec338 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -936,6 +936,8 @@ tf_kernel_library(
     deps = if_cuda([
         ":cuda_solvers",
         "@cub_archive//:cub",
+    ]) + if_rocm([
+        "@rocprim_archive//:rocprim",
     ]) + ARRAY_DEPS,
 )
 
@@ -1779,7 +1781,8 @@ tf_kernel_library(
     deps = DYNAMIC_DEPS + [
         ":fill_functor",
         ":gather_functor",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda(["@cub_archive//:cub"]) +
+        if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_kernel_library(
@@ -2911,7 +2914,7 @@ tf_kernel_library(
     name = "reduction_ops",
     gpu_srcs = ["reduction_gpu_kernels.cu.h"],
     prefix = "reduction_ops",
-    deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]),
+    deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_kernel_library(
@@ -3416,7 +3419,9 @@ tf_kernel_library(
     deps = NN_DEPS + if_cuda([
         ":reduction_ops",
         "@cub_archive//:cub",
-    ]),
+    ]) + if_rocm([
+        "@rocprim_archive//:rocprim",
+    ]),
 )
 
 tf_kernel_library(
@@ -3434,7 +3439,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "topk_op",
     prefix = "topk_op",
-    deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
+    deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_kernel_library(
@@ -3457,7 +3462,8 @@ tf_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//third_party/eigen3",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda(["@cub_archive//:cub"]) +
+        if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_kernel_library(
@@ -3468,7 +3474,8 @@ tf_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//third_party/eigen3",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda(["@cub_archive//:cub"]) +
+        if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_kernel_library(
@@ -3482,7 +3489,8 @@ tf_kernel_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:nn_grad",
         "//tensorflow/core:nn_ops_op_lib",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda(["@cub_archive//:cub"]) +
+        if_rocm(["@rocprim_archive//:rocprim"]),
 )
 
 tf_gpu_cc_test(
diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc
index 890fa3121bbf71..d2b531bae3da93 100644
--- a/tensorflow/core/kernels/bincount_op.cc
+++ b/tensorflow/core/kernels/bincount_op.cc
@@ -120,7 +120,7 @@ class BincountOp : public OpKernel {
 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_KERNELS(type)                            \
   REGISTER_KERNEL_BUILDER(Name("Bincount")                \
@@ -133,6 +133,6 @@ TF_CALL_int32(REGISTER_KERNELS);
 TF_CALL_float(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 57a26e59d04c68..b9e01a34e1d43b 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
+#if GOOGLE_CUDA
 #include "external/cub_archive/cub/device/device_histogram.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -27,6 +31,12 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
@@ -55,11 +65,11 @@ struct BincountFunctor<GPUDevice, T> {
     int32 lower_level = 0;
     int32 upper_level = output.size();
     int num_samples = arr.size();
-    const cudaStream_t& stream = GetCudaStream(context);
+    const gpuStream_t& stream = GetGPUStream(context);
 
     // The first HistogramEven is to obtain the temp storage size required
     // with d_temp_storage = NULL passed to the call.
-    auto err = cub::DeviceHistogram::HistogramEven(
+    auto err = gpuprim::DeviceHistogram::HistogramEven(
         /* d_temp_storage */ NULL,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -69,10 +79,10 @@ struct BincountFunctor<GPUDevice, T> {
         /* upper_level */ upper_level,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
           "Could not launch HistogramEven to get temp storage: ",
-          cudaGetErrorString(err), ".");
+          GPUGETERRORSTRING(err), ".");
     }
     Tensor temp_storage;
     TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -82,7 +92,7 @@ struct BincountFunctor<GPUDevice, T> {
     void* d_temp_storage = temp_storage.flat<int8>().data();
     // The second HistogramEven is to actual run with d_temp_storage
    // allocated with temp_storage_bytes.
-    err = cub::DeviceHistogram::HistogramEven(
+    err = gpuprim::DeviceHistogram::HistogramEven(
         /* d_temp_storage */ d_temp_storage,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -92,9 +102,9 @@ struct BincountFunctor<GPUDevice, T> {
         /* upper_level */ upper_level,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
-          "Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
+          "Could not launch HistogramEven: ", GPUGETERRORSTRING(err), ".");
     }
     return Status::OK();
   }
diff --git a/tensorflow/core/kernels/determinant_op_gpu.cu.cc b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
index c866204c97e6ac..f4487c0c17f6b6 100644
--- a/tensorflow/core/kernels/determinant_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/determinant_op_gpu.cu.cc
@@ -23,7 +23,7 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/cuda_solvers.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
 
 namespace tensorflow {
 namespace functor {
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723fd644..45f9a1da2650b6 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -47,7 +47,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
 #include "tensorflow/core/util/transform_output_iterator.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/kernels/histogram_op.cc b/tensorflow/core/kernels/histogram_op.cc
index 4e035286f6f154..75f896b407d35e 100644
--- a/tensorflow/core/kernels/histogram_op.cc
+++ b/tensorflow/core/kernels/histogram_op.cc
@@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define REGISTER_KERNELS(type)                           \
   REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth")    \
                               .Device(DEVICE_GPU)        \
@@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddcdda6..ab0a561fa2d77e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -13,12 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
 #include "external/cub_archive/cub/device/device_histogram.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -26,7 +30,13 @@ limitations under the License.
 #include "tensorflow/core/kernels/histogram_op.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
 
 namespace tensorflow {
 
@@ -66,11 +76,11 @@ struct HistogramFixedWidthFunctor<GPUDevice, T> {
     int num_levels = levels.size();
     T* d_levels = levels.data();
    int num_samples = values.size();
-    const cudaStream_t& stream = GetCudaStream(context);
+    const gpuStream_t& stream = GetGPUStream(context);
 
     // The first HistogramRange is to obtain the temp storage size required
     // with d_temp_storage = NULL passed to the call.
-    auto err = cub::DeviceHistogram::HistogramRange(
+    auto err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ NULL,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -79,10 +89,10 @@ struct HistogramFixedWidthFunctor<GPUDevice, T> {
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
           "Could not launch HistogramRange to get temp storage: ",
-          cudaGetErrorString(err), ".");
+          GPUGETERRORSTRING(err), ".");
     }
 
     Tensor temp_storage;
@@ -94,7 +104,7 @@ struct HistogramFixedWidthFunctor<GPUDevice, T> {
 
     // The second HistogramRange is to actual run with d_temp_storage
     // allocated with temp_storage_bytes.
-    err = cub::DeviceHistogram::HistogramRange(
+    err = gpuprim::DeviceHistogram::HistogramRange(
         /* d_temp_storage */ d_temp_storage,
         /* temp_storage_bytes */ temp_storage_bytes,
         /* d_samples */ d_samples,
@@ -103,9 +113,9 @@ struct HistogramFixedWidthFunctor<GPUDevice, T> {
         /* num_levels */ num_levels,
         /* d_levels */ d_levels,
         /* num_samples */ num_samples,
         /* stream */ stream);
-    if (err != cudaSuccess) {
+    if (err != gpuSuccess) {
       return errors::Internal(
-          "Could not launch HistogramRange: ", cudaGetErrorString(err), ".");
+          "Could not launch HistogramRange: ", GPUGETERRORSTRING(err), ".");
     }
     return Status::OK();
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 0de2ebb5907caa..eff2f7b17b6773 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -26,7 +26,7 @@ limitations under the License.
#include "cuda/include/cuComplex.h" #include "tensorflow/core/kernels/reduction_ops.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/transform_output_iterator.h" diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index b63dcbb163b1b7..e177705598762b 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc index 8c3a58b108abe6..5ba0da28d5d86f 100644 --- a/tensorflow/core/kernels/svd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc @@ -43,7 +43,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc index ca296d5aa044d6..c0b5eabfa6c3a6 100644 --- a/tensorflow/core/kernels/topk_op_gpu.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" // Required for sorting Eigen::half namespace cub { diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index 57f51889de94d9..4172f75a61b5ce 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/where_op.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h index 54223c0e72ef77..78461f5f8e8b18 100644 --- a/tensorflow/core/util/gpu_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -45,6 +45,26 @@ limitations under the License. 
 #define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
   for (int i : ::tensorflow::GpuGridRange##axis(n))
 
+#if GOOGLE_CUDA
+#define gpuSuccess cudaSuccess
+#define GPUGETERRORSTRING(error) cudaGetErrorString(error)
+using gpuStream_t = cudaStream_t;
+#define GPUGETLASTERROR() cudaGetLastError()
+using gpuError_t = cudaError_t;
+#elif TENSORFLOW_USE_ROCM
+#define gpuSuccess hipSuccess
+#define GPUGETERRORSTRING(error) hipGetErrorString(error)
+using gpuStream_t = hipStream_t;
+#define GPUGETLASTERROR() hipGetLastError()
+using gpuError_t = hipError_t;
+#endif
+
+#if GOOGLE_CUDA
+#define GetGPUStream(context) GetCudaStream(context)
+#elif TENSORFLOW_USE_ROCM
+#define GetGPUStream(context) context->eigen_gpu_device().stream()
+#endif
+
 namespace tensorflow {
 __host__ __device__ inline tensorflow::bfloat16 GpuLdg(
     const tensorflow::bfloat16* address) {
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 944de58f26e510..8411240548b97f 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -654,6 +654,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
       build_file = clean_dep("//third_party:cub.BUILD"),
   )
 
+  tf_http_archive(
+      name = "rocprim_archive",
+      urls = [
+          "https://mirror.bazel.build/github.com/ROCmSoftwarePlatform/rocPRIM/archive/563461f.zip",
+          "https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/563461f.zip",
+      ],
+      sha256 = "64d340057649d7643cb6ae158b168dd6d48d2b1dd77297fbdfcb6249528e0707",
+      strip_prefix = "rocPRIM-563461f3def38e30bcc53d9bf37b2e12f494ab99",
+      build_file = clean_dep("//third_party:rocprim.BUILD"),
+  )
+
   tf_http_archive(
       name = "cython",
       sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5",
diff --git a/third_party/rocprim.BUILD b/third_party/rocprim.BUILD
new file mode 100644
index 00000000000000..a9ec83159f4f60
--- /dev/null
+++ b/third_party/rocprim.BUILD
@@ -0,0 +1,53 @@
+# Description: rocPRIM library which is a set of primitives for GPU programming on AMD ROCm stack.
+
+licenses(["notice"])  # BSD
+
+exports_files(["LICENSE.TXT"])
+
+load("@local_config_rocm//rocm:build_defs.bzl", "rocm_default_copts", "if_rocm")
+
+filegroup(
+    name = "rocprim_headers",
+    srcs = glob([
+        "hipcub/include/**",
+        "rocprim/include/**",
+    ]),
+)
+
+cc_library(
+    name = "rocprim",
+    hdrs = if_rocm([":rocprim_headers"]),
+    srcs = ["rocprim_version.hpp", "hipcub_version.hpp"],
+    deps = [
+        "@local_config_rocm//rocm:rocm_headers",
+    ],
+    includes = ["hipcub/include",
+                "rocprim/include",
+                "rocprim/include/rocprim",
+                ".",],
+    visibility = ["//visibility:public"],
+)
+
+genrule(
+    name = "rocprim_version_hpp",
+    message = "Creating rocPRIM version header...",
+    srcs = ["rocprim/include/rocprim/rocprim_version.hpp.in"],
+    outs = ["rocprim_version.hpp"],
+    cmd = ("sed " +
+           "-e 's/@rocprim_VERSION_MAJOR@/0/g' " +
+           "-e 's/@rocprim_VERSION_MINOR@/3/g' " +
+           "-e 's/@rocprim_VERSION_PATCH@/0/g' " +
+           "$< >$@"),
+)
+
+genrule(
+    name = "hipcub_version_hpp",
+    message = "Creating hipcub version header...",
+    srcs = ["hipcub/include/hipcub/hipcub_version.hpp.in"],
+    outs = ["hipcub_version.hpp"],
+    cmd = ("sed " +
+           "-e 's/@rocprim_VERSION_MAJOR@/0/g' " +
+           "-e 's/@rocprim_VERSION_MINOR@/3/g' " +
+           "-e 's/@rocprim_VERSION_PATCH@/0/g' " +
+           "$< >$@"),
+)
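Reviewer note (not part of the patch): the core portability idea above is the `gpuprim` namespace alias plus the small `gpu*` helpers added to gpu_kernel_helper.h, so kernel code written against CUB's device-wide primitives compiles unchanged against hipCUB on ROCm. Below is a minimal standalone sketch of the same pattern outside TensorFlow. It assumes a stock CUB or hipCUB install (`<cub/cub.cuh>` / `<hipcub/hipcub.hpp>`); the `gpuMalloc`/`gpuMemcpy`/`gpuFree` aliases, the file name, and the compiler invocations are illustrative assumptions only and are not added anywhere by this change.

// gpuprim_demo.cu -- illustrative sketch only, not part of the patch.
// Assumed builds:  nvcc -o demo gpuprim_demo.cu   or   hipcc -o demo gpuprim_demo.cu
#include <cstdio>
#include <vector>

#if defined(__HIPCC__)
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
namespace gpuprim = ::hipcub;
// Example-local aliases; the patch itself only adds gpuSuccess, gpuStream_t,
// gpuError_t, GPUGETERRORSTRING, GPUGETLASTERROR and GetGPUStream to TF.
#define gpuMalloc hipMalloc
#define gpuFree hipFree
#define gpuMemcpy hipMemcpy
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuSuccess hipSuccess
#define GPUGETERRORSTRING(e) hipGetErrorString(e)
#else
#include <cuda_runtime.h>
#include <cub/cub.cuh>
namespace gpuprim = ::cub;
#define gpuMalloc cudaMalloc
#define gpuFree cudaFree
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuSuccess cudaSuccess
#define GPUGETERRORSTRING(e) cudaGetErrorString(e)
#endif

int main() {
  std::vector<int> h_in = {1, 2, 3, 4, 5};
  const int n = static_cast<int>(h_in.size());
  int *d_in = nullptr, *d_out = nullptr;
  gpuMalloc(reinterpret_cast<void**>(&d_in), n * sizeof(int));
  gpuMalloc(reinterpret_cast<void**>(&d_out), sizeof(int));
  gpuMemcpy(d_in, h_in.data(), n * sizeof(int), gpuMemcpyHostToDevice);

  // Same two-phase pattern as the bincount/histogram kernels above: the first
  // call only reports temp_storage_bytes, the second call does the reduction.
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, n);
  gpuMalloc(&d_temp_storage, temp_storage_bytes);
  auto err = gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                        d_in, d_out, n);
  if (err != gpuSuccess) {
    std::printf("DeviceReduce::Sum failed: %s\n", GPUGETERRORSTRING(err));
  }

  int sum = 0;
  gpuMemcpy(&sum, d_out, sizeof(int), gpuMemcpyDeviceToHost);
  std::printf("sum = %d\n", sum);  // expected: 15

  gpuFree(d_temp_storage);
  gpuFree(d_in);
  gpuFree(d_out);
  return 0;
}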