Commit b056653

Merge pull request #23 from iotamudelta/cubkernels_v2

Cubkernels v2

2 parents: 8b9a948 + d7ea152

4 files changed (+119, -91 lines)

rocm_docs/core_kernels.md

Lines changed: 32 additions & 32 deletions
@@ -441,7 +441,7 @@ Kernels under tensorflow/core/kernels
 | O | | dilation_ops_gpu.cu.cc
 | O | | draw_bounding_box_op.cc
 | O | | dynamic_partition_op.cc
-| X |cub | dynamic_partition_op_gpu.cu.cc
+| P |rocPRIM | dynamic_partition_op_gpu.cu.cc
 | O | | dynamic_partition_op_test.cc
 | O | | dynamic_stitch_op.cc
 | O | | dynamic_stitch_op_gpu.cu.cc
@@ -581,9 +581,9 @@ Kernels under tensorflow/core/kernels
 | O | | inplace_ops.cc
 | O | | inplace_ops_functor.h
 | O | | inplace_ops_functor_gpu.cu.cc
-| P |cub | l2loss_op.cc
-| P |cub | l2loss_op.h
-| P |cub | l2loss_op_gpu.cu.cc
+| P |rocPRIM | l2loss_op.cc
+| P |rocPRIM | l2loss_op.h
+| P |rocPRIM | l2loss_op_gpu.cu.cc
 | O | | linalg_ops_common.cc
 | O | | linalg_ops_common.h
 | O | | list_kernels.cc
@@ -796,26 +796,26 @@ Kernels under tensorflow/core/kernels
 | O | | record_yielder.cc
 | O | | record_yielder.h
 | O | | reduce_join_op.cc
-| P |cub | reduction_gpu_kernels.cu.h
-| P |cub | reduction_ops.h
-| P |cub | reduction_ops_all.cc
-| P |cub | reduction_ops_any.cc
-| P |cub | reduction_ops_common.cc
-| P |cub | reduction_ops_common.h
-| P |cub | reduction_ops_gpu_bool.cu.cc
+| P |rocPRIM | reduction_gpu_kernels.cu.h
+| P |rocPRIM | reduction_ops.h
+| P |rocPRIM | reduction_ops_all.cc
+| P |rocPRIM | reduction_ops_any.cc
+| P |rocPRIM | reduction_ops_common.cc
+| P |rocPRIM | reduction_ops_common.h
+| P |rocPRIM | reduction_ops_gpu_bool.cu.cc
 | X |cub | reduction_ops_gpu_complex128.cu.cc
 | X |cub | reduction_ops_gpu_complex64.cu.cc
-| P |cub | reduction_ops_gpu_double.cu.cc
-| P |cub | reduction_ops_gpu_float.cu.cc
-| P |cub | reduction_ops_gpu_int.cu.cc
-| P |cub | reduction_ops_half_mean_sum.cu.cc
-| P |cub | reduction_ops_half_prod_max_min.cu.cc
-| P |cub | reduction_ops_max.cc
-| P |cub | reduction_ops_mean.cc
-| P |cub | reduction_ops_min.cc
-| P |cub | reduction_ops_prod.cc
-| P |cub | reduction_ops_sum.cc
-| P |cub | reduction_ops_test.cc
+| P |rocPRIM | reduction_ops_gpu_double.cu.cc
+| P |rocPRIM | reduction_ops_gpu_float.cu.cc
+| P |rocPRIM | reduction_ops_gpu_int.cu.cc
+| P |rocPRIM | reduction_ops_half_mean_sum.cu.cc
+| P |rocPRIM | reduction_ops_half_prod_max_min.cu.cc
+| P |rocPRIM | reduction_ops_max.cc
+| P |rocPRIM | reduction_ops_mean.cc
+| P |rocPRIM | reduction_ops_min.cc
+| P |rocPRIM | reduction_ops_prod.cc
+| P |rocPRIM | reduction_ops_sum.cc
+| P |rocPRIM | reduction_ops_test.cc
 | O | | reference_gemm.h
 | O | | regex_replace_op.cc
 | O | | relu_op.cc
@@ -940,9 +940,9 @@ Kernels under tensorflow/core/kernels
 | O | | snapshot_op.cc
 | O | | snapshot_op.h
 | O | | snapshot_op_gpu.cu.cc
-| P |cub | softmax_op.cc
-| P |cub | softmax_op_functor.h
-| P |cub | softmax_op_gpu.cu.cc
+| P |rocPRIM | softmax_op.cc
+| P |rocPRIM | softmax_op_functor.h
+| P |rocPRIM | softmax_op_gpu.cu.cc
 | O | | softplus_op.cc
 | O | | softplus_op.h
 | O | | softplus_op_gpu.cu.cc
@@ -1107,13 +1107,13 @@ Kernels under tensorflow/core/kernels
 | O | | warn_about_ints.cc
 | O | | warn_about_ints.h
 | X |cub | where_op.cc
-| X |cub | where_op.h
-| X |cub | where_op_gpu.cu.h
-| X |cub | where_op_gpu_impl_1.cu.cc
-| X |cub | where_op_gpu_impl_2.cu.cc
-| X |cub | where_op_gpu_impl_3.cu.cc
-| X |cub | where_op_gpu_impl_4.cu.cc
-| X |cub | where_op_gpu_impl_5.cu.cc
+| O | | where_op.h
+| P |rocPRIM | where_op_gpu.cu.h
+| P |rocPRIM | where_op_gpu_impl_1.cu.cc
+| P |rocPRIM | where_op_gpu_impl_2.cu.cc
+| P |rocPRIM | where_op_gpu_impl_3.cu.cc
+| P |rocPRIM | where_op_gpu_impl_4.cu.cc
+| P |rocPRIM | where_op_gpu_impl_5.cu.cc
 | O | | whole_file_read_ops.cc
 | O | | winograd_transform.h
 | O | | word2vec_kernels.cc

tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc

Lines changed: 39 additions & 23 deletions
@@ -16,12 +16,12 @@ limitations under the License.
 // The algorithm for dynamic partition has the following steps:
 // 1. Let N be the size of partitions. We initialize a new vector indices_in
 //    with the values 0, 1, 2, ..., N-1.
-// 2. We apply cub::DeviceRadixSort::SortPairs to the key - value pairs given
+// 2. We apply gpuprim::DeviceRadixSort::SortPairs to the key - value pairs given
 //    by partitions and indices_in. This will result in two new vectors
 //    partitions_out and indices_out, with partitions_out sorted.
 // 3. The first dimension of outputs[i] is equal to the number of i-values in
 //    partitions_out. We determine it in two steps:
-//    - apply cub::DeviceReduce::ReduceByKey to count how many times each value
+//    - apply gpuprim::DeviceReduce::ReduceByKey to count how many times each value
 //      appears in partitions_out,
 //    - move the results to partition_count. This handles missing values
 //      (corresponding to empty parts).
@@ -31,14 +31,18 @@ limitations under the License.
 // This works, because for each interval of i-values, indices_out points
 // to the slices which should form output[i].
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
 
+#if GOOGLE_CUDA
 #include "external/cub_archive/cub/device/device_radix_sort.cuh"
 #include "external/cub_archive/cub/device/device_reduce.cuh"
 #include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
 #include "external/cub_archive/cub/thread/thread_operators.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -50,6 +54,12 @@ limitations under the License.
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 #include "tensorflow/core/util/transform_output_iterator.h"
 
+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
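The namespace alias added here is the heart of the port: the rest of the file now calls gpuprim::, and the preprocessor decides whether that resolves to cub:: (CUDA) or hipcub:: (ROCm). A minimal, hedged sketch of the same pattern outside TensorFlow (it uses the upstream <cub/cub.cuh> and <hipcub/hipcub.hpp> headers rather than the external/ archive paths above; gpuprim_sum_demo is a made-up helper for illustration only):

#if defined(TENSORFLOW_USE_ROCM)
#include <hipcub/hipcub.hpp>
namespace gpuprim = ::hipcub;   // ROCm build: gpuprim:: resolves to hipCUB
#else
#include <cub/cub.cuh>
namespace gpuprim = ::cub;      // CUDA build: gpuprim:: resolves to CUB
#endif

// cub::Sum and hipcub::Sum expose the same functor interface, so code written
// against gpuprim::Sum compiles unchanged on either backend.
inline int gpuprim_sum_demo(int a, int b) {
  gpuprim::Sum sum_op;
  return sum_op(a, b);
}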
@@ -59,14 +69,14 @@ namespace {
 template <typename T>
 __global__ void RangeInitKernel(const T start, const T delta, const int32 size,
                                 T* out) {
-  CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
+  GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
 }
 
 __global__ void MoveValuesKernel(const int32* keys, const int32* values,
                                  const int32* size, int32 out_size,
                                  int32* out) {
   int32 N = min(ldg(size), out_size);
-  CUDA_1D_KERNEL_LOOP(i, N) {
+  GPU_1D_KERNEL_LOOP(i, N) {
     int32 key = ldg(keys + i);
     int32 value = ldg(values + i);
     if (FastBoundsCheck(key, out_size)) out[key] = value;
@@ -78,9 +88,9 @@ __global__ void MoveValuesKernel(const int32* keys, const int32* values,
 template <typename T>
 void RangeInit(const GPUDevice& d, const T start, const T delta,
                const int32 size, typename TTypes<T>::Flat out) {
-  CudaLaunchConfig config = GetCudaLaunchConfig(size, d);
-  RangeInitKernel<T>
-      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
+  GPU_LAUNCH_KERNEL(RangeInitKernel<T>, dim3(config.block_count),
+                    dim3(config.thread_per_block), 0, d.stream(),
           start, delta, size, out.data());
 }
 
@@ -93,18 +103,19 @@ void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
   // This is valid for correct inputs, because then out_size >= *num_runs.
   // For wrong inputs, we may have out_size < *num_runs. In this case we will
   // only handle the first out_size values.
-  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
-  MoveValuesKernel<<<config.block_count, config.thread_per_block, 0,
-                     d.stream()>>>(keys, values, num_runs, out_size, out);
+  GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
+  GPU_LAUNCH_KERNEL(MoveValuesKernel, dim3(config.block_count),
+                    dim3(config.thread_per_block), 0, d.stream(), keys, values,
+                    num_runs, out_size, out);
 }
 
 template <typename T>
 void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices,
                       T* out, int64 gather_dim_size, int64 indices_size,
                       int64 slice_size, int64 out_size) {
-  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
-  GatherOpKernel<T, int32, true>
-      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+  GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
+  GPU_LAUNCH_KERNEL(GatherOpKernel<T, int32, true>,
+      dim3(config.block_count), dim3(config.thread_per_block), 0, d.stream(),
       params, indices, out, gather_dim_size, indices_size, slice_size,
       out_size);
 }
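The three hunks above are the launch-syntax half of the port: CUDA_1D_KERNEL_LOOP becomes GPU_1D_KERNEL_LOOP, CudaLaunchConfig/GetCudaLaunchConfig become GpuLaunchConfig/GetGpuLaunchConfig, and the CUDA-only triple-chevron launches become GPU_LAUNCH_KERNEL calls that hipcc can also compile. As a rough illustration only (plain CUDA, not the TensorFlow helpers; block and thread counts are hard-coded here, whereas GetGpuLaunchConfig derives them from the device), a grid-stride RangeInit kernel and its launch look roughly like this:

#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride loop: each thread handles indices i, i + stride, i + 2*stride, ...
// This is the looping structure that the *_1D_KERNEL_LOOP macros express in one line.
__global__ void RangeInitKernel(float start, float delta, int size, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += gridDim.x * blockDim.x) {
    out[i] = start + i * delta;
  }
}

int main() {
  const int size = 1000;
  float* d_out = nullptr;
  cudaMalloc(&d_out, size * sizeof(float));

  // Hand-rolled launch configuration for the sketch.
  const int threads = 256;
  const int blocks = (size + threads - 1) / threads;
  RangeInitKernel<<<blocks, threads, 0 /* shared mem */, 0 /* default stream */>>>(
      0.0f, 1.0f, size, d_out);
  cudaDeviceSynchronize();

  float first = 0.0f, last = 0.0f;
  cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&last, d_out + size - 1, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("out[0]=%.1f out[%d]=%.1f\n", first, size - 1, last);
  cudaFree(d_out);
  return 0;
}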
@@ -180,7 +191,7 @@ class BoundedOutputIterator
 // I + P + max(3N + R + P, O + N), where:
 // I - the size of the input
 // N - the size of the partitions tensor
-// R - the temporary storage used by cub::RadixSort, about 2N
+// R - the temporary storage used by gpuprim::RadixSort, about 2N
 // P - the number of partitions
 // O - the size of the output
 // So roughly the cost is I + P + max(5N, O + N).
@@ -325,7 +336,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
                      Tensor* indices_out, DoneCallback done) {
     int32 N = partitions->NumElements();
     const GPUDevice& device = c->eigen_device<GPUDevice>();
-    const cudaStream_t& cu_stream = GetCudaStream(c);
+    const gpuStream_t& cu_stream = GetGPUStream(c);
 
     // Initialize the indices_in tensor using the Range GPU kernel.
     RangeInit(device, 0, 1, N, indices_in->flat<int32>());
@@ -337,7 +348,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
     // Determine temporary device storage requirements.
     Tensor cub_temp_storage;
     size_t temp_storage_bytes = 0;
-    cub::DeviceRadixSort::SortPairs(
+    gpuprim::DeviceRadixSort::SortPairs(
         NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
         indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream);
     // Allocate temporary storage.
// Allocate temporary storage.
@@ -348,7 +359,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
348359
&cub_temp_storage),
349360
done);
350361
// Radix-sort the partition information.
351-
cub::DeviceRadixSort::SortPairs(
362+
gpuprim::DeviceRadixSort::SortPairs(
352363
cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
353364
partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
354365
0, sizeof(int32) * 8, cu_stream);
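Both SortPairs calls follow CUB's two-phase protocol: the first call passes NULL as the temp-storage pointer, so the library only reports temp_storage_bytes; the caller allocates that many bytes (here in the cub_temp_storage tensor) and then repeats the identical call to do the actual sort. A standalone, hedged sketch of the same protocol against upstream CUB, sorting a small partitions/indices pair on the default stream:

#include <cstdio>
#include <cuda_runtime.h>
#include <cub/device/device_radix_sort.cuh>

int main() {
  const int N = 8;
  int h_partitions[N] = {2, 0, 1, 0, 2, 1, 0, 2};   // keys
  int h_indices[N]    = {0, 1, 2, 3, 4, 5, 6, 7};   // values (0..N-1)

  int *d_keys_in, *d_keys_out, *d_vals_in, *d_vals_out;
  cudaMalloc(&d_keys_in, N * sizeof(int));
  cudaMalloc(&d_keys_out, N * sizeof(int));
  cudaMalloc(&d_vals_in, N * sizeof(int));
  cudaMalloc(&d_vals_out, N * sizeof(int));
  cudaMemcpy(d_keys_in, h_partitions, N * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_vals_in, h_indices, N * sizeof(int), cudaMemcpyHostToDevice);

  // Phase 1: null storage pointer, CUB only computes temp_storage_bytes.
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys_in,
                                  d_keys_out, d_vals_in, d_vals_out, N, 0,
                                  sizeof(int) * 8);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Phase 2: the same call with real storage performs the sort.
  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys_in,
                                  d_keys_out, d_vals_in, d_vals_out, N, 0,
                                  sizeof(int) * 8);
  cudaDeviceSynchronize();

  int keys[N], vals[N];
  cudaMemcpy(keys, d_keys_out, N * sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(vals, d_vals_out, N * sizeof(int), cudaMemcpyDeviceToHost);
  for (int i = 0; i < N; ++i)
    std::printf("partition %d <- index %d\n", keys[i], vals[i]);

  cudaFree(d_keys_in); cudaFree(d_keys_out);
  cudaFree(d_vals_in); cudaFree(d_vals_out); cudaFree(d_temp_storage);
  return 0;
}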
@@ -358,7 +369,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
                      Tensor* partition_count, Tensor* indices_out,
                      DoneCallback done) {
     const GPUDevice& device = c->eigen_device<GPUDevice>();
-    const cudaStream_t& cu_stream = GetCudaStream(c);
+    const gpuStream_t& cu_stream = GetGPUStream(c);
     int32 N = partitions->NumElements();
     Tensor indices_in;
     Tensor partitions_out;
@@ -395,8 +406,13 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
     BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
                                             num_partitions_);
 
+#if GOOGLE_CUDA
     cub::ConstantInputIterator<int32> values_in(1);
-    cub::Sum reduction_op;
+#elif TENSORFLOW_USE_ROCM
+    using ConstantInputIterator = ::rocprim::constant_iterator<int32, ptrdiff_t>;
+    ConstantInputIterator values_in(1);
+#endif
+    gpuprim::Sum reduction_op;
 
     // Allocate space on GPU for the number of runs. This is required by CUB.
     Tensor num_runs;
@@ -407,7 +423,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
     // Determine temporary device storage requirements
     Tensor cub_temp_storage;
     size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr,
+    gpuprim::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr,
                                    unique_out_it, values_in, aggregates_out_it,
                                    num_runs_ptr, reduction_op, N, cu_stream);
     // Allocate temporary storage.
// Allocate temporary storage.
@@ -421,7 +437,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
421437
// each index appears in partitions. The distinct indices are stored
422438
// in unique_out, while the count is stored in aggregates_out.
423439
// The total number of distinct indices is stored in num_runs.
424-
cub::DeviceReduce::ReduceByKey(cub_temp_storage.flat<int8>().data(),
440+
gpuprim::DeviceReduce::ReduceByKey(cub_temp_storage.flat<int8>().data(),
425441
temp_storage_bytes, keys_in_ptr,
426442
unique_out_it, values_in, aggregates_out_it,
427443
num_runs_ptr, reduction_op, N, cu_stream);
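This ReduceByKey pair uses the same query-then-run protocol, and the trick that makes it count partition sizes is the values iterator: every value is the constant 1 (cub::ConstantInputIterator on CUDA, rocprim::constant_iterator on ROCm, as in the hunk further up), so summing the values within each run of equal keys yields the run length. A hedged, standalone CUB sketch of that counting pattern:

#include <cstdio>
#include <cuda_runtime.h>
#include <cub/device/device_reduce.cuh>
#include <cub/iterator/constant_input_iterator.cuh>
#include <cub/thread/thread_operators.cuh>

int main() {
  const int N = 8;
  int h_keys[N] = {0, 0, 0, 1, 1, 2, 2, 2};  // already sorted, as after SortPairs

  int *d_keys, *d_unique, *d_counts, *d_num_runs;
  cudaMalloc(&d_keys, N * sizeof(int));
  cudaMalloc(&d_unique, N * sizeof(int));
  cudaMalloc(&d_counts, N * sizeof(int));
  cudaMalloc(&d_num_runs, sizeof(int));
  cudaMemcpy(d_keys, h_keys, N * sizeof(int), cudaMemcpyHostToDevice);

  // Every value is 1, so summing values per key run yields the run length.
  cub::ConstantInputIterator<int> values_in(1);
  cub::Sum reduction_op;

  // Query scratch-space size, allocate, then run the keyed reduction.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::ReduceByKey(d_temp, temp_bytes, d_keys, d_unique, values_in,
                                 d_counts, d_num_runs, reduction_op, N);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::ReduceByKey(d_temp, temp_bytes, d_keys, d_unique, values_in,
                                 d_counts, d_num_runs, reduction_op, N);
  cudaDeviceSynchronize();

  int num_runs = 0;
  cudaMemcpy(&num_runs, d_num_runs, sizeof(int), cudaMemcpyDeviceToHost);
  int unique[N], counts[N];
  cudaMemcpy(unique, d_unique, N * sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(counts, d_counts, N * sizeof(int), cudaMemcpyDeviceToHost);
  for (int i = 0; i < num_runs; ++i)
    std::printf("partition %d appears %d times\n", unique[i], counts[i]);
  return 0;
}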
@@ -467,4 +483,4 @@ TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
