-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rocprim integration #7
Changes from 10 commits
31492a7
ed242a2
fe620c1
51a06d5
e76a089
3272056
3eb21d3
d60b574
a2b2eea
ccdc519
992b6a3
6da34fe
8812e69
ff03e95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,23 @@ limitations under the License. | |
#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \ | ||
for (int i : ::tensorflow::GpuGridRange##axis<int>(n)) | ||
|
||
#if TENSORFLOW_USE_ROCM | ||
|
||
#define cub hipcub | ||
#define cudaSuccess hipSuccess | ||
#define cudaGetErrorString hipGetErrorString | ||
#define cudaStream_t hipStream_t | ||
#define cudaGetLastError hipGetLastError | ||
#define cudaError hipError | ||
|
||
#define CUDA_1D_KERNEL_LOOP(i, n) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please help check if this macro can be abolished and use |
||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ | ||
i += blockDim.x * gridDim.x) | ||
|
||
#define GetCudaStream(context) context->eigen_gpu_device().stream() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you also remove the GetCudaStream macro when you removed cuda_kernel_helper.h ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe so. You might want to double check though. |
||
|
||
#endif | ||
|
||
namespace tensorflow { | ||
__host__ __device__ inline tensorflow::bfloat16 GpuLdg( | ||
const tensorflow::bfloat16* address) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Description: rocPRIM library which is a set of primitives for GPU programming on AMD ROCm stack. | ||
|
||
licenses(["notice"]) # BSD | ||
|
||
exports_files(["LICENSE.TXT"]) | ||
|
||
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_default_copts", "if_rocm") | ||
|
||
filegroup( | ||
name = "rocprim_headers", | ||
srcs = glob([ | ||
"hipcub/include/**", | ||
"rocprim/include/**", | ||
]), | ||
) | ||
|
||
cc_library( | ||
name = "rocprim", | ||
hdrs = if_rocm([":rocprim_headers"]), | ||
srcs= ["rocprim_version.hpp", "hipcub_version.hpp"], | ||
deps = [ | ||
"@local_config_rocm//rocm:rocm_headers", | ||
], | ||
includes = ["hipcub/include", | ||
"rocprim/include", | ||
"rocprim/include/rocprim", | ||
".",], | ||
visibility = ["//visibility:public"], | ||
) | ||
|
||
genrule( | ||
name = "rocprim_version_hpp", | ||
message = "Creating rocPRIM version header...", | ||
srcs = ["rocprim/include/rocprim/rocprim_version.hpp.in"], | ||
outs = ["rocprim_version.hpp"], | ||
cmd = ("sed " + | ||
"-e 's/@rocprim_VERSION_MAJOR@/0/g' " + | ||
"-e 's/@rocprim_VERSION_MINOR@/3/g' " + | ||
"-e 's/@rocprim_VERSION_PATCH@/0/g' " + | ||
"$< >$@"), | ||
) | ||
|
||
genrule( | ||
name = "hipcub_version_hpp", | ||
message = "Creating hipcub version header...", | ||
srcs = ["hipcub/include/hipcub/hipcub_version.hpp.in"], | ||
outs = ["hipcub_version.hpp"], | ||
cmd = ("sed " + | ||
"-e 's/@rocprim_VERSION_MAJOR@/0/g' " + | ||
"-e 's/@rocprim_VERSION_MINOR@/3/g' " + | ||
"-e 's/@rocprim_VERSION_PATCH@/0/g' " + | ||
"$< >$@"), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ll need to sit behind a computer to check this header again. Most of cuda to hipnchanges don’t sound correct to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please elaborate. I am happy to add more advanced macros, however, these simple define changes do compile correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the goal is to try restrict exposure to
cuda
for less potential legal issue. I'd feel more comfortable putting#if/#elif
blocks in the operators.