-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add malloc async functionality to ROCm #3
base: main
Are you sure you want to change the base?
Conversation
c0c72e9
to
bf154a5
Compare
bf154a5
to
d7df941
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also rename GpuCudaMallocAsyncAllocator to GpuMallocAsyncAllocator ? As it now for both platforms. Other then that, it looks good to me: indeed it might be better to place CUDA/HIP functions into corresponding cuda_driver.cc/rocm_driver.cc. But this is a quite tediuos work, and also most of these functions / enum types are used only in this single place. We can add a comment about it in an upstream PR, to ask what opinion XLA reviewers have about it ?
#elif TENSORFLOW_USE_ROCM | ||
error = GpuGetErrorString(result); | ||
name = GpuGetErrorName(result); | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can just use cuGetErrorString/cuGetErrorName and hipGetErrorString/hipGetErrorName there ? Because the function prototypes differ anyway, hence there is no much sense in defining generic versions GpuGetErrorString/GpuGetErrorName above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is your PR based on upstream:main? If so please ignore my comments, otherwise I think you might include many unnessary changes in it. (We don't use this rocm/xla:main and it's outdated)
third_party/llvm/workspace.bzl
Outdated
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") | |||
|
|||
def repo(name): | |||
"""Imports LLVM.""" | |||
LLVM_COMMIT = "8d6784db04ee5d925a2d036a68f00a7c124c6cf9" | |||
LLVM_SHA256 = "05f4eddd26f28400d13235417ee9de822e255bdc0ec0b30826bb03156ea6fdc5" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume this should be not included?
xla/BUILD
Outdated
@@ -1239,7 +1239,7 @@ filegroup( | |||
# name = "xla_py_pb2", | |||
# testonly = 0, | |||
# api_version = 2, | |||
# compatible_with = ["//buildenv/target:gce"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should not be included as well?
xla/python/tpu_driver/BUILD
Outdated
@@ -121,6 +121,6 @@ cc_library( | |||
go_grpc_library( | |||
name = "tpu_service_go_grpc", | |||
srcs = [":tpu_service_proto"], | |||
compatible_with = ["//buildenv/target:gce"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one shouldn't be included?
xla/service/gpu/BUILD
Outdated
@@ -872,7 +872,7 @@ tsl_gpu_library( | |||
name = "_nccl_utils", | |||
srcs = if_gpu_is_configured(["nccl_utils.cc"]), | |||
hdrs = if_gpu_is_configured(["nccl_utils.h"]), | |||
# Override tsl_gpu_library()'s internal default value of ["//buildenv/target:gce"]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to change this file as well?
…art #3 PiperOrigin-RevId: 599039077
Currently we look for ptxas and nvlink in a few different places on the host machine, then we choose the first found binary without taking its version into account. If the chosen binary doesn't fulfill our version requirements we will later fail even if there was a suitable ptxas or nvlink in the search path in the first place. This change makes it take the version of each binary into account when going through the search path. Unsuitable binaries will be discarded right away and the search continues until we are out of locations to check. This should help with host environments that have multiple CUDA toolkits installed and should make ptxas and nvlink selection more robust. The concreate changes: 1. `FindCudaExecutable` now also takes a minimum version and a list of forbidden (think buggy) versions that are supposed to be skipped. 2. `WarnIfBadPtxAsVersion` has been removed. It was checking for ptxas < 11.1 which is way older than our minimum supported version of 11.8 and was not doing anything given the check described in #3. 3. There was another version check for `ptxas` in `NVPTXCompiler::ChooseLinkingMethod` which was checking for `version(ptxas)` < 11.8. This has also been removed/replace by the version check described in #4. 4. Version checking for `ptxas` and `nvlink` has been consolidated into 2 methods `FindPtxAsExectuable` and `FindNvLinkExecutable`. These methods hard code the current minimum version (and the list of excluded versions) of each tool in one place. It's still not great but at least less spaghetti-like. PiperOrigin-RevId: 618797392
…d phase to Initialize() Imported from GitHub PR openxla#12228 The first time that a NormThunk is executed, it will build a cudnn execution plan. This build step can hang if a NCCL collective is running at the same time. To fix this, I've moved the build step to take place during thunk initialization. We only observe this hang when using cudnn 9. Here's a backtrace from the hang that will be fixed: ``` Thread 585 (Thread 0x7fb9391ff640 (LWP 41364) "main.py"): #0 0x00007fd3d17cffd9 in ?? () from /lib/x86_64-linux-gnu/libc.so.6 #1 0x00007fd3d17da24f in pthread_rwlock_wrlock () from /lib/x86_64-linux-gnu/libc.so.6 #2 0x00007fd070967dfe in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1 #3 0x00007fd0709c928a in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1 #4 0x00007f1970d76102 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0 #5 0x00007f1970f2c999 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0 #6 0x00007f1970a7d4ab in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0 #7 0x00007f1970d0a9cb in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0 #8 0x00007fce60b2a98c in cudnn::backend::ExecutionPlan::finalize_internal() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0 #9 0x00007fce60aefbb1 in cudnn::backend::Descriptor::finalize() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0 #10 0x00007fce60b15bec in cudnnBackendFinalize () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0 #11 0x00007fd2521b8f39 in cudnn_frontend::ExecutionPlanBuilder_v8::build() () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #12 0x00007fd2521734ba in stream_executor::gpu::(anonymous namespace)::GetExecPlanFromHeuristics(cudnn_frontend::OperationGraph_v8&&, stream_executor::gpu::(anonymous namespace)::CudnnHandle const&, bool) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #13 0x00007fd25216ff9b in stream_executor::gpu::CudnnSupport::NormRunnerFromDesc(stream_executor::Stream*, stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormKind, double, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #14 0x00007fd24e36b88b in stream_executor::dnn::NormOp::RunnerFromAlgorithmDesc(stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #15 0x00007fd24e36ae37 in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}::operator()() const () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #16 0x00007fd24e36adbc in void absl::lts_20230802::base_internal::CallOnceImpl<stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}>(std::atomic<unsigned int>*, absl::lts_20230802::base_internal::SchedulingMode, stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}&&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #17 0x00007fd24e36a9bd in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #18 0x00007fd24e369d29 in xla::gpu::RunGpuNorm(xla::gpu::GpuNormConfig const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, stream_executor::DeviceMemoryBase const&, stream_executor::Stream*, xla::gpu::RunNormOptions) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so #19 0x00007fd24e368be6 in xla::gpu::NormThunk::ExecuteOnStream(xla::gpu::Thunk::ExecuteParams const&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so ``` Copybara import of the project: -- f535330 by Trevor Morris <tmorris@nvidia.com>: Fix hang with cudnn layer norm by moving cudnn init to Initialize() Merging this change closes openxla#12228 COPYBARA_INTEGRATE_REVIEW=openxla#12228 from trevor-m:tmorris-norm-init f535330 PiperOrigin-RevId: 633220207
The xla upstream put the source files implementing platform dependent (CUDA) malloc async in the directory
xla/stream_executor/gpu/.
This directory is not supposed to have platform dependent implementation. Not sure why the upstream is doing this. I think that the canonical way is adding APIs ingpu_driver.h
and putting the platform dependent implementation in the corresponding vendor directories, i.e.cuda
orrocm
. The macros in current implementation will be largely reduced if the canonical implementation is adopted. The existing APIs in the driver can be reused.The current implementation does not change macros defined in the original implementation. Those macros contain the acronym CUDA. The macros will not affect other parts of the code base if handled properly.