@@ -16,12 +16,12 @@ limitations under the License.
1616// The algorithm for dynamic partition has the following steps:
1717// 1. Let N be the size of partitions. We initialize a new vector indices_in
1818// with the values 0, 1, 2, ..., N-1.
19- // 2. We apply cub::DeviceRadixSort::SortPairs to the key-value pairs given
19+ // 2. We apply gpuprim::DeviceRadixSort::SortPairs to the key-value pairs given
2020// by partitions and indices_in. This will result in two new vectors
2121// partitions_out and indices_out, with partitions_out sorted.
2222// 3. The first dimension of outputs[i] is equal to the number of i-values in
2323// partitions_out. We determine it in two steps:
24- // - apply cub::DeviceReduce::ReduceByKey to count how many times each value
24+ // - apply gpuprim::DeviceReduce::ReduceByKey to count how many times each value
2525// appears in partitions_out,
2626// - move the results to partition_count. This handles missing values
2727// (corresponding to empty parts).
@@ -31,14 +31,18 @@ limitations under the License.
3131// This works, because for each interval of i-values, indices_out points
3232// to the slices which should form output[i].
3333
34- #if GOOGLE_CUDA
34+ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3535
3636#define EIGEN_USE_GPU
3737
38+ #if GOOGLE_CUDA
3839#include " external/cub_archive/cub/device/device_radix_sort.cuh"
3940#include " external/cub_archive/cub/device/device_reduce.cuh"
4041#include " external/cub_archive/cub/iterator/constant_input_iterator.cuh"
4142#include " external/cub_archive/cub/thread/thread_operators.cuh"
43+ #elif TENSORFLOW_USE_ROCM
44+ #include " external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
45+ #endif
4246#include " tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
4347#include " tensorflow/core/framework/op_kernel.h"
4448#include " tensorflow/core/framework/register_types.h"
@@ -50,6 +54,12 @@ limitations under the License.
5054#include " tensorflow/core/util/gpu_kernel_helper.h"
5155#include " tensorflow/core/util/transform_output_iterator.h"
5256
// Backend-neutral alias for the device-wide primitives namespace: `gpuprim`
// names ::cub when GOOGLE_CUDA is set and ::hipcub when TENSORFLOW_USE_ROCM is
// set, so the sort/reduce calls further down can be spelled once for both.
57+ #if GOOGLE_CUDA
58+ namespace gpuprim = ::cub;
59+ #elif TENSORFLOW_USE_ROCM
60+ namespace gpuprim = ::hipcub;
61+ #endif
62+
5363namespace tensorflow {
5464
5565typedef Eigen::GpuDevice GPUDevice;
@@ -59,14 +69,14 @@ namespace {
// RangeInitKernel: device kernel that fills `out` with an arithmetic
// progression, out[i] = start + i * delta.
//   start - first value of the progression
//   delta - increment between consecutive elements
//   size  - number of elements handed to the loop macro
//   out   - destination buffer written by the kernel
// The only change in this hunk is the loop macro: CUDA_1D_KERNEL_LOOP (removed
// line) becomes GPU_1D_KERNEL_LOOP (added line); the loop body is identical.
// NOTE(review): the macro is defined outside this view; presumably it iterates
// i over [0, size) across the launched threads -- confirm in the helper header.
5969template <typename T>
6070__global__ void RangeInitKernel (const T start, const T delta, const int32 size,
6171 T* out) {
62- CUDA_1D_KERNEL_LOOP (i, size) { out[i] = start + i * delta; }
72+ GPU_1D_KERNEL_LOOP (i, size) { out[i] = start + i * delta; }
6373}
6474
6575__global__ void MoveValuesKernel (const int32* keys, const int32* values,
6676 const int32* size, int32 out_size,
6777 int32* out) {
6878 int32 N = min (ldg (size), out_size);
69- CUDA_1D_KERNEL_LOOP (i, N) {
79+ GPU_1D_KERNEL_LOOP (i, N) {
7080 int32 key = ldg (keys + i);
7181 int32 value = ldg (values + i);
7282 if (FastBoundsCheck (key, out_size)) out[key] = value;
@@ -78,9 +88,9 @@ __global__ void MoveValuesKernel(const int32* keys, const int32* values,
// RangeInit: host-side wrapper that launches RangeInitKernel<T> on the
// device's stream (d.stream()), forwarding start, delta, size and the raw
// pointer of the flat output tensor (out.data()).
// The launch configuration is derived from `size` and the device `d`.
// This hunk swaps the CUDA-only spelling (CudaLaunchConfig /
// GetCudaLaunchConfig and the <<<...>>> launch, removed lines) for the
// backend-neutral one (GpuLaunchConfig / GetGpuLaunchConfig and the
// GPU_LAUNCH_KERNEL macro, added lines). Both forms pass block_count,
// thread_per_block, 0 and d.stream() in that order.
// NOTE(review): the literal 0 sits in the dynamic shared-memory slot of the
// <<<...>>> form; presumably GPU_LAUNCH_KERNEL keeps the same meaning -- confirm.
7888template <typename T>
7989void RangeInit (const GPUDevice& d, const T start, const T delta,
8090 const int32 size, typename TTypes<T>::Flat out) {
81- CudaLaunchConfig config = GetCudaLaunchConfig (size, d);
82- RangeInitKernel<T>
83- <<<config. block_count , config.thread_per_block , 0 , d.stream ()>>>(
91+ GpuLaunchConfig config = GetGpuLaunchConfig (size, d);
92+ GPU_LAUNCH_KERNEL ( RangeInitKernel<T>, dim3 (config. block_count ),
93+ dim3 ( config.thread_per_block ) , 0 , d.stream (),
8494 start, delta, size, out.data ());
8595}
8696
@@ -93,18 +103,19 @@ void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
93103 // This is valid for correct inputs, because then out_size >= *num_runs.
94104 // For wrong inputs, we may have out_size < *num_runs. In this case we will
95105 // only handle the first out_size values.
96- CudaLaunchConfig config = GetCudaLaunchConfig (out_size, d);
97- MoveValuesKernel<<<config.block_count , config.thread_per_block , 0 ,
98- d.stream ()>>>(keys, values, num_runs, out_size, out);
106+ GpuLaunchConfig config = GetGpuLaunchConfig (out_size, d);
107+ GPU_LAUNCH_KERNEL (MoveValuesKernel, dim3 (config.block_count ),
108+ dim3 (config.thread_per_block ), 0 , d.stream (), keys, values,
109+ num_runs, out_size, out);
99110}
100111
// CallGatherKernel: host-side wrapper that launches
// GatherOpKernel<T, int32, true> on d.stream(), forwarding params, indices,
// out, gather_dim_size, indices_size, slice_size and out_size unchanged.
// The launch configuration is derived from `out_size` and the device `d`.
// This hunk replaces the CUDA-only launch (CudaLaunchConfig and <<<...>>>,
// removed lines) with the backend-neutral GpuLaunchConfig plus the
// GPU_LAUNCH_KERNEL macro (added lines); the kernel arguments are the same.
// NOTE(review): GatherOpKernel is declared outside this view; the meaning of
// its third template argument (`true`) cannot be determined from here.
101112template <typename T>
102113void CallGatherKernel (const GPUDevice& d, const T* params, const int32* indices,
103114 T* out, int64 gather_dim_size, int64 indices_size,
104115 int64 slice_size, int64 out_size) {
105- CudaLaunchConfig config = GetCudaLaunchConfig (out_size, d);
106- GatherOpKernel<T, int32, true >
107- <<< config.block_count , config.thread_per_block , 0 , d.stream ()>>>(
116+ GpuLaunchConfig config = GetGpuLaunchConfig (out_size, d);
117+ GPU_LAUNCH_KERNEL ( GatherOpKernel<T, int32, true >,
118+ dim3 ( config.block_count ), dim3 ( config.thread_per_block ) , 0 , d.stream (),
108119 params, indices, out, gather_dim_size, indices_size, slice_size,
109120 out_size);
110121}
@@ -180,7 +191,7 @@ class BoundedOutputIterator
180191// I + P + max(3N + R + P, O + N), where:
181192// I - the size of the input
182193// N - the size of the partitions tensor
183- // R - the temporary storage used by cub::RadixSort, about 2N
194+ // R - the temporary storage used by gpuprim::RadixSort, about 2N
184195// P - the number of partitions
185196// O - the size of the output
186197// So roughly the cost is I + P + max(5N, O + N).
@@ -325,7 +336,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
325336 Tensor* indices_out, DoneCallback done) {
326337 int32 N = partitions->NumElements ();
327338 const GPUDevice& device = c->eigen_device <GPUDevice>();
328- const cudaStream_t & cu_stream = GetCudaStream (c);
339+ const gpuStream_t & cu_stream = GetGPUStream (c);
329340
330341 // Initialize the indices_in tensor using the Range GPU kernel.
331342 RangeInit (device, 0 , 1 , N, indices_in->flat <int32>());
@@ -337,7 +348,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
337348 // Determine temporary device storage requirements.
338349 Tensor cub_temp_storage;
339350 size_t temp_storage_bytes = 0 ;
340- cub ::DeviceRadixSort::SortPairs (
351+ gpuprim ::DeviceRadixSort::SortPairs (
341352 NULL , temp_storage_bytes, partitions_ptr, partitions_out_ptr,
342353 indices_in_ptr, indices_out_ptr, N, 0 , sizeof (int32) * 8 , cu_stream);
343354 // Allocate temporary storage.
@@ -348,7 +359,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
348359 &cub_temp_storage),
349360 done);
350361 // Radix-sort the partition information.
351- cub ::DeviceRadixSort::SortPairs (
362+ gpuprim ::DeviceRadixSort::SortPairs (
352363 cub_temp_storage.flat <int8>().data (), temp_storage_bytes,
353364 partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
354365 0 , sizeof (int32) * 8 , cu_stream);
@@ -358,7 +369,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
358369 Tensor* partition_count, Tensor* indices_out,
359370 DoneCallback done) {
360371 const GPUDevice& device = c->eigen_device <GPUDevice>();
361- const cudaStream_t & cu_stream = GetCudaStream (c);
372+ const gpuStream_t & cu_stream = GetGPUStream (c);
362373 int32 N = partitions->NumElements ();
363374 Tensor indices_in;
364375 Tensor partitions_out;
@@ -395,8 +406,13 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
395406 BoundedOutputIterator aggregates_out_it (aggregates_out_ptr, id_op,
396407 num_partitions_);
397408
409+ #if GOOGLE_CUDA
398410 cub::ConstantInputIterator<int32> values_in (1 );
399- cub::Sum reduction_op;
411+ #elif TENSORFLOW_USE_ROCM
412+ using ConstantInputIterator = ::rocprim::constant_iterator<int32, ptrdiff_t >;
413+ ConstantInputIterator values_in (1 );
414+ #endif
415+ gpuprim::Sum reduction_op;
400416
401417 // Allocate space on GPU for the number of runs. This is required by CUB
// (hipCUB on ROCm): ReduceByKey stores the run count through this pointer.
402418 Tensor num_runs;
@@ -407,7 +423,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
407423 // Determine temporary device storage requirements
408424 Tensor cub_temp_storage;
409425 size_t temp_storage_bytes = 0 ;
410- cub ::DeviceReduce::ReduceByKey (NULL , temp_storage_bytes, keys_in_ptr,
426+ gpuprim ::DeviceReduce::ReduceByKey (NULL , temp_storage_bytes, keys_in_ptr,
411427 unique_out_it, values_in, aggregates_out_it,
412428 num_runs_ptr, reduction_op, N, cu_stream);
413429 // Allocate temporary storage.
@@ -421,7 +437,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
421437 // each index appears in partitions. The distinct indices are stored
422438 // in unique_out, while the count is stored in aggregates_out.
423439 // The total number of distinct indices is stored in num_runs.
424- cub ::DeviceReduce::ReduceByKey (cub_temp_storage.flat <int8>().data (),
440+ gpuprim ::DeviceReduce::ReduceByKey (cub_temp_storage.flat <int8>().data (),
425441 temp_storage_bytes, keys_in_ptr,
426442 unique_out_it, values_in, aggregates_out_it,
427443 num_runs_ptr, reduction_op, N, cu_stream);
@@ -467,4 +483,4 @@ TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);
467483
468484} // namespace tensorflow
469485
470- #endif // GOOGLE_CUDA
486+ #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
0 commit comments