Rocprim integration #7

Merged 14 commits on Jun 1, 2018
Changes from 10 commits
22 changes: 15 additions & 7 deletions tensorflow/core/kernels/BUILD
@@ -936,6 +936,8 @@ tf_kernel_library(
deps = if_cuda([
":cuda_solvers",
"@cub_archive//:cub",
]) + if_rocm([
"@rocprim_archive//:rocprim",
]) + ARRAY_DEPS,
)

@@ -1779,7 +1781,8 @@ tf_kernel_library(
deps = DYNAMIC_DEPS + [
":fill_functor",
":gather_functor",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
@@ -2911,7 +2914,7 @@ tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
prefix = "reduction_ops",
deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]),
deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
@@ -3416,7 +3419,9 @@ tf_kernel_library(
deps = NN_DEPS + if_cuda([
":reduction_ops",
"@cub_archive//:cub",
]),
]) + if_rocm([
"@rocprim_archive//:rocprim",
]),
)

tf_kernel_library(
@@ -3434,7 +3439,7 @@ tf_kernel_library(
tf_kernel_library(
name = "topk_op",
prefix = "topk_op",
deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]) + if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
@@ -3457,7 +3462,8 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
@@ -3468,7 +3474,8 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_kernel_library(
@@ -3482,7 +3489,8 @@ tf_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_grad",
"//tensorflow/core:nn_ops_op_lib",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda(["@cub_archive//:cub"])
+ if_rocm(["@rocprim_archive//:rocprim"]),
)

tf_gpu_cc_test(
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/bincount_op.cc
@@ -120,7 +120,7 @@ class BincountOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Bincount") \
@@ -133,6 +133,6 @@ TF_CALL_int32(REGISTER_KERNELS);
TF_CALL_float(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
6 changes: 5 additions & 1 deletion tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#if GOOGLE_CUDA
#include "external/cub_archive/cub/device/device_histogram.cuh"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/determinant_op_gpu.cu.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace functor {
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -47,7 +47,7 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

namespace tensorflow {
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/histogram_op.cc
@@ -129,7 +129,7 @@ class HistogramFixedWidthOp : public OpKernel {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("HistogramFixedWidth") \
.Device(DEVICE_GPU) \
@@ -142,6 +142,6 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -13,20 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "external/cub_archive/cub/device/device_histogram.cuh"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/histogram_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

2 changes: 1 addition & 1 deletion tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -26,7 +26,7 @@ limitations under the License.
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/transform_output_iterator.h"

2 changes: 1 addition & 1 deletion tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -43,7 +43,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

2 changes: 1 addition & 1 deletion tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

// Required for sorting Eigen::half
namespace cub {
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/where_op_gpu.cu.h
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/where_op.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

17 changes: 17 additions & 0 deletions tensorflow/core/util/gpu_kernel_helper.h
@@ -45,6 +45,23 @@ limitations under the License.
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i : ::tensorflow::GpuGridRange##axis<int>(n))

#if TENSORFLOW_USE_ROCM

#define cub hipcub
Collaborator: I'll need to sit behind a computer to check this header again. Most of the cuda-to-hip changes don't sound correct to me.

Author: Please elaborate. I am happy to add more advanced macros; however, these simple define changes do compile correctly.

Collaborator: One of the goals is to restrict exposure to cuda to reduce potential legal issues. I'd feel more comfortable putting #if/#elif blocks in the operators.
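For illustration, the per-operator #if/#elif approach the reviewer describes could look like the sketch below; the gpuprim alias name is a hypothetical choice, not something from this PR.

```cpp
// Sketch of keeping the CUDA/ROCm split inside the operator's .cu.cc file
// rather than a global "#define cub hipcub" in gpu_kernel_helper.h.
#if GOOGLE_CUDA
#include "external/cub_archive/cub/device/device_histogram.cuh"
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
namespace gpuprim = ::hipcub;
#endif
// Kernel code then calls gpuprim::DeviceHistogram::HistogramEven(...) and
// compiles unchanged on either stack.
```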

#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#define cudaStream_t hipStream_t
#define cudaGetLastError hipGetLastError
#define cudaError hipError

#define CUDA_1D_KERNEL_LOOP(i, n) \
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

Collaborator: Please help check whether this macro can be abolished in favor of GPU_1D_KERNEL_LOOP, which is also defined in this header file.
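If the macro is dropped as the reviewer suggests, kernels would use the platform-neutral loop already provided by this header. A hypothetical kernel (ScaleByAlphaKernel is illustrative, not from the PR):

```cpp
// GPU_1D_KERNEL_LOOP from gpu_kernel_helper.h expands to a grid-stride loop
// on both CUDA and ROCm, so reintroducing CUDA_1D_KERNEL_LOOP should not be
// necessary.
#include "tensorflow/core/util/gpu_kernel_helper.h"

template <typename T>
__global__ void ScaleByAlphaKernel(const T* __restrict__ in,
                                   T* __restrict__ out, int n, T alpha) {
  GPU_1D_KERNEL_LOOP(i, n) {
    out[i] = alpha * in[i];
  }
}
```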

#define GetCudaStream(context) context->eigen_gpu_device().stream()
Collaborator: GetGpuStream (i.e., this macro should be renamed).

Author: Did you also remove the GetCudaStream macro when you removed cuda_kernel_helper.h?

Collaborator: I believe so. You might want to double-check, though.
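A hypothetical fragment (not from the PR) of where the macro is consumed, assuming the cub/hipcub headers are already included in the enclosing .cu.cc; only the macro's name changes if it becomes GetGpuStream, since device-wide cub/hipcub calls simply take the op's GPU stream as an argument.

```cpp
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

template <typename T>
void SumOnOpStream(tensorflow::OpKernelContext* context, const T* d_in,
                   T* d_out, void* d_temp, size_t& temp_bytes, int num_items) {
  auto stream = GetCudaStream(context);  // context->eigen_gpu_device().stream()
  // Under ROCm, "#define cub hipcub" (or a per-file alias) makes this resolve
  // to hipcub::DeviceReduce::Sum.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, num_items, stream);
}
```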


#endif

namespace tensorflow {
__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
const tensorflow::bfloat16* address) {
11 changes: 11 additions & 0 deletions tensorflow/workspace.bzl
@@ -654,6 +654,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
build_file = clean_dep("//third_party:cub.BUILD"),
)

tf_http_archive(
name = "rocprim_archive",
urls = [
"https://mirror.bazel.build/github.com/ROCmSoftwarePlatform/rocPRIM/archive/563461f.zip",
"https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/563461f.zip",
],
sha256 = "64d340057649d7643cb6ae158b168dd6d48d2b1dd77297fbdfcb6249528e0707",
strip_prefix = "rocPRIM-563461f3def38e30bcc53d9bf37b2e12f494ab99",
build_file = clean_dep("//third_party:rocprim.BUILD"),
)

tf_http_archive(
name = "cython",
sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5",
53 changes: 53 additions & 0 deletions third_party/rocprim.BUILD
@@ -0,0 +1,53 @@
# Description: rocPRIM library which is a set of primitives for GPU programming on AMD ROCm stack.

licenses(["notice"]) # BSD

exports_files(["LICENSE.TXT"])

load("@local_config_rocm//rocm:build_defs.bzl", "rocm_default_copts", "if_rocm")

filegroup(
name = "rocprim_headers",
srcs = glob([
"hipcub/include/**",
"rocprim/include/**",
]),
)

cc_library(
name = "rocprim",
hdrs = if_rocm([":rocprim_headers"]),
srcs= ["rocprim_version.hpp", "hipcub_version.hpp"],
deps = [
"@local_config_rocm//rocm:rocm_headers",
],
includes = ["hipcub/include",
"rocprim/include",
"rocprim/include/rocprim",
".",],
visibility = ["//visibility:public"],
)

genrule(
name = "rocprim_version_hpp",
message = "Creating rocPRIM version header...",
srcs = ["rocprim/include/rocprim/rocprim_version.hpp.in"],
outs = ["rocprim_version.hpp"],
cmd = ("sed " +
"-e 's/@rocprim_VERSION_MAJOR@/0/g' " +
"-e 's/@rocprim_VERSION_MINOR@/3/g' " +
"-e 's/@rocprim_VERSION_PATCH@/0/g' " +
"$< >$@"),
)

genrule(
name = "hipcub_version_hpp",
message = "Creating hipcub version header...",
srcs = ["hipcub/include/hipcub/hipcub_version.hpp.in"],
outs = ["hipcub_version.hpp"],
cmd = ("sed " +
"-e 's/@rocprim_VERSION_MAJOR@/0/g' " +
"-e 's/@rocprim_VERSION_MINOR@/3/g' " +
"-e 's/@rocprim_VERSION_PATCH@/0/g' " +
"$< >$@"),
)