Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rocprim integration #7

Merged
merged 14 commits into from
Jun 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions rocm_docs/core_kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ Kernels under tensorflow/core/kernels
| O | | bias_op.h
| O | | bias_op_gpu.cu.cc
| O | | bias_op_gpu.h
| X |cub | bincount_op.cc
| X |cub | bincount_op.h
| X |cub | bincount_op_gpu.cu.cc
| X |cub | bincount_op_test.cc
| O |rocPRIM | bincount_op.cc
| O |rocPRIM | bincount_op.h
| O |rocPRIM | bincount_op_gpu.cu.cc
| O |rocPRIM | bincount_op_test.cc
| O | | bitcast_op.cc
| O | | bitcast_op.h
| O | | boosted_trees
Expand Down Expand Up @@ -557,9 +557,9 @@ Kernels under tensorflow/core/kernels
| O | | hexagon/soc_interface.cc
| O | | hexagon/soc_interface.h
| O | | hinge-loss.h
| X |cub | histogram_op.cc
| X |cub | histogram_op.h
| X |cub | histogram_op_gpu.cu.cc
| O |rocPRIM | histogram_op.cc
| O |rocPRIM | histogram_op.h
| O |rocPRIM | histogram_op_gpu.cu.cc
| O | | i_remote_fused_graph_executor.h
| O | | i_remote_fused_graph_ops_definitions.cc
| O | | i_remote_fused_graph_ops_definitions.h
Expand Down
22 changes: 15 additions & 7 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,8 @@ tf_kernel_library(
deps = if_cuda([
":cuda_solvers",
"@cub_archive//:cub",
]) + if_rocm([
"@rocprim_archive//:rocprim",
]) + ARRAY_DEPS,
)

Expand Down Expand Up @@ -1779,7 +1781,8 @@ tf_kernel_library(
deps = DYNAMIC_DEPS + [
":fill_functor",
":gather_functor",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
Expand Down Expand Up @@ -2911,7 +2914,7 @@ tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
prefix = "reduction_ops",
deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]),
deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
Expand Down Expand Up @@ -3416,7 +3419,9 @@ tf_kernel_library(
deps = NN_DEPS + if_cuda([
":reduction_ops",
"@cub_archive//:cub",
]),
]) + if_rocm([
"@rocprim_archive//:rocprim",
]),
)

tf_kernel_library(
Expand All @@ -3434,7 +3439,7 @@ tf_kernel_library(
tf_kernel_library(
name = "topk_op",
prefix = "topk_op",
deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
Expand All @@ -3457,7 +3462,8 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
Expand All @@ -3468,7 +3474,8 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
Expand All @@ -3482,7 +3489,8 @@ tf_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_grad",
"//tensorflow/core:nn_ops_op_lib",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_gpu_cc_test(
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/bincount_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class BincountOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Bincount") \
Expand All @@ -133,6 +133,6 @@ TF_CALL_int32(REGISTER_KERNELS);
TF_CALL_float(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
26 changes: 18 additions & 8 deletions tensorflow/core/kernels/bincount_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#if GOOGLE_CUDA
#include "external/cub_archive/cub/device/device_histogram.cuh"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
Expand All @@ -27,6 +31,12 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
Expand Down Expand Up @@ -55,11 +65,11 @@ struct BincountFunctor<GPUDevice, T> {
int32 lower_level = 0;
int32 upper_level = output.size();
int num_samples = arr.size();
const cudaStream_t& stream = GetCudaStream(context);
const gpuStream_t& stream = GetGPUStream(context);

// The first HistogramEven is to obtain the temp storage size required
// with d_temp_storage = NULL passed to the call.
auto err = cub::DeviceHistogram::HistogramEven(
auto err = gpuprim::DeviceHistogram::HistogramEven(
/* d_temp_storage */ NULL,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
Expand All @@ -69,10 +79,10 @@ struct BincountFunctor<GPUDevice, T> {
/* upper_level */ upper_level,
/* num_samples */ num_samples,
/* stream */ stream);
if (err != cudaSuccess) {
if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramEven to get temp storage: ",
cudaGetErrorString(err), ".");
GPUGETERRORSTRING(err), ".");
}
Tensor temp_storage;
TF_RETURN_IF_ERROR(context->allocate_temp(
Expand All @@ -82,7 +92,7 @@ struct BincountFunctor<GPUDevice, T> {
void* d_temp_storage = temp_storage.flat<int8>().data();
// The second HistogramEven is to actual run with d_temp_storage
// allocated with temp_storage_bytes.
err = cub::DeviceHistogram::HistogramEven(
err = gpuprim::DeviceHistogram::HistogramEven(
/* d_temp_storage */ d_temp_storage,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
Expand All @@ -92,9 +102,9 @@ struct BincountFunctor<GPUDevice, T> {
/* upper_level */ upper_level,
/* num_samples */ num_samples,
/* stream */ stream);
if (err != cudaSuccess) {
if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramEven: ", cudaGetErrorString(err), ".");
"Could not launch HistogramEven: ", GPUGETERRORSTRING(err), ".");
}
return Status::OK();
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/determinant_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace functor {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

namespace tensorflow {
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/histogram_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \
.Device(DEVICE_GPU) \
Expand All @@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
28 changes: 19 additions & 9 deletions tensorflow/core/kernels/histogram_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,30 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "external/cub_archive/cub/device/device_histogram.cuh"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/histogram_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif

namespace tensorflow {

Expand Down Expand Up @@ -66,11 +76,11 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
int num_levels = levels.size();
T* d_levels = levels.data();
int num_samples = values.size();
const cudaStream_t& stream = GetCudaStream(context);
const gpuStream_t& stream = GetGPUStream(context);

// The first HistogramRange is to obtain the temp storage size required
// with d_temp_storage = NULL passed to the call.
auto err = cub::DeviceHistogram::HistogramRange(
auto err = gpuprim::DeviceHistogram::HistogramRange(
/* d_temp_storage */ NULL,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
Expand All @@ -79,10 +89,10 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
/* d_levels */ d_levels,
/* num_samples */ num_samples,
/* stream */ stream);
if (err != cudaSuccess) {
if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramRange to get temp storage: ",
cudaGetErrorString(err), ".");
GPUGETERRORSTRING(err), ".");
}

Tensor temp_storage;
Expand All @@ -94,7 +104,7 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {

// The second HistogramRange is to actual run with d_temp_storage
// allocated with temp_storage_bytes.
err = cub::DeviceHistogram::HistogramRange(
err = gpuprim::DeviceHistogram::HistogramRange(
/* d_temp_storage */ d_temp_storage,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_samples */ d_samples,
Expand All @@ -103,9 +113,9 @@ struct HistogramFixedWidthFunctor<GPUDevice, T, Tout> {
/* d_levels */ d_levels,
/* num_samples */ num_samples,
/* stream */ stream);
if (err != cudaSuccess) {
if (err != gpuSuccess) {
return errors::Internal(
"Could not launch HistogramRange: ", cudaGetErrorString(err), ".");
"Could not launch HistogramRange: ", GPUGETERRORSTRING(err), ".");
}

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/reduction_gpu_kernels.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ limitations under the License.
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/transform_output_iterator.h"

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/softmax_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/svd_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/topk_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

// Required for sorting Eigen::half
namespace cub {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/where_op_gpu.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/where_op.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

Expand Down
20 changes: 20 additions & 0 deletions tensorflow/core/util/gpu_kernel_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ limitations under the License.
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i : ::tensorflow::GpuGridRange##axis<int>(n))

#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
#define GPUGETERRORSTRING(error) cudaGetErrorString(error)
using gpuStream_t = cudaStream_t;
#define GPUGETLASTERROR() cudaGetLastError()
using gpuError_t = cudaError_t;
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
#define GPUGETERRORSTRING(error) hipGetErrorString(error)
using gpuStream_t = hipStream_t;
#define GPUGETLASTERROR() hipGetLastError()
using gpuError_t = hipError_t;
#endif

#if GOOGLE_CUDA
#define GetGPUStream(context) GetCudaStream(context)
#elif TENSORFLOW_USE_ROCM
#define GetGPUStream(context) context->eigen_gpu_device().stream()
#endif

namespace tensorflow {
__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
const tensorflow::bfloat16* address) {
Expand Down
Loading