
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi part #2

PiperOrigin-RevId: 598915673
ezhulenev authored and copybara-github committed Jan 16, 2024
1 parent 5834c40 commit f29ab1e
Showing 23 changed files with 824 additions and 401 deletions.
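The pattern applied throughout this change is visible in the hunks below: call sites stop invoking raw NCCL symbols (ncclGetUniqueId, ncclCommGetAsyncError, ncclComm_t) and go through the NcclApi wrapper instead. A minimal sketch of the wrapper surface implied by those call sites — the actual declarations in nccl_api.h may differ:

// Sketch only: the NcclApi surface as implied by call sites in this
// commit, not the verbatim contents of xla/service/gpu/nccl_api.h.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/service/gpu/nccl_clique_key.h"  // for NcclCliqueId

namespace xla::gpu {

// Assumed opaque communicator handle; mock_nccl_utils.h below
// reinterpret_casts it to/from ncclComm_t, so a pointer-sized opaque
// type is implied.
struct NcclCommImpl;
using NcclCommHandle = NcclCommImpl*;

struct NcclApi {
  // Replaces direct ncclGetUniqueId calls (see nccl_id_store.cc below).
  static absl::StatusOr<NcclCliqueId> GetUniqueId();
  // Replaces direct ncclCommGetAsyncError handling (see
  // mock_nccl_utils.cc below).
  static absl::Status CommGetAsyncError(NcclCommHandle comm);
};

}  // namespace xla::gpu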
43 changes: 5 additions & 38 deletions xla/pjrt/gpu/BUILD
@@ -2,7 +2,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm", "if_gpu_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@tsl//tsl:tsl.bzl", "if_nccl")
load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
@@ -44,6 +43,7 @@ cc_library(
":gpu_helpers",
":gpu_metrics",
":gpu_topology",
":nccl_id_store",
"//xla:literal",
"//xla:shape_util",
"//xla:status",
@@ -116,11 +116,9 @@ cc_library(
] + if_cuda_or_rocm([
"//xla/service/gpu:gpu_compiler",
]) + if_cuda([
":nccl_id_store_cuda",
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
]) + if_rocm([
":nccl_id_store_rocm",
"@local_config_rocm//rocm:rocm_headers",
]),
)
@@ -161,15 +159,10 @@ xla_cc_test(
],
)

# We actually wish we could write if_cuda(if_nccl(...)) in :gpu_device,
# but Bazel does not allow nested selects. We can work around the problem using
# an intermediate library that has the conditional NCCL pieces that is only
# itself included as a dependency if CUDA is enabled.
cc_library(
name = "nccl_id_store_cuda",
name = "nccl_id_store",
srcs = ["nccl_id_store.cc"],
hdrs = ["nccl_id_store.h"],
defines = if_nccl(["NCCL_ENABLED=1"]),
deps = [
"//xla:status_macros",
"//xla:statusor",
@@ -178,6 +171,7 @@ cc_library(
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:global_device_id",
"//xla/service/gpu:nccl_api",
"//xla/service/gpu:nccl_clique_key",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
@@ -188,33 +182,7 @@ cc_library(
"@com_google_absl//absl/time",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
] + if_nccl(["@local_config_nccl//:nccl"]),
)

cc_library(
name = "nccl_id_store_rocm",
srcs = ["nccl_id_store.cc"],
hdrs = ["nccl_id_store.h"],
defines = if_nccl(["NCCL_ENABLED=1"]),
deps = [
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla/pjrt:pjrt_client",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:global_device_id",
"//xla/service/gpu:nccl_clique_key",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
] + if_nccl(["@local_config_nccl//:nccl"]),
],
)

xla_cc_test(
@@ -255,6 +223,7 @@ cc_library(
hdrs = ["se_gpu_pjrt_compiler.h"],
defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]),
deps = [
":nccl_id_store",
":se_gpu_pjrt_client",
"//xla:status_macros",
"//xla/client:local_client",
@@ -281,13 +250,11 @@ cc_library(
] + if_cuda_or_rocm([
"//xla/service/gpu:gpu_compiler",
]) + if_cuda([
":nccl_id_store_cuda",
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor/cuda:cuda_activation_header",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
"//xla/service/gpu:nvptx_compiler_impl",
]) + if_rocm([
":nccl_id_store_rocm",
"@local_config_rocm//rocm:rocm_headers",
"//xla/service/gpu:amdgpu_compiler_impl",
]),
24 changes: 2 additions & 22 deletions xla/pjrt/gpu/nccl_id_store.cc
@@ -20,25 +20,13 @@ limitations under the License.

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "xla/service/gpu/nccl_api.h"
#include "xla/service/gpu/nccl_clique_key.h"
#include "xla/status_macros.h"
#include "xla/statusor.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

#ifdef NCCL_ENABLED
#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#if (TF_ROCM_VERSION >= 50200)
#include "rocm/include/rccl/rccl.h"
#else
#include "rocm/include/rccl.h"
#endif
#else
#include "third_party/nccl/nccl.h"
#endif
#endif // NCCL_ENABLED

namespace xla {

StatusOr<gpu::NcclCliqueId> NcclIdStore::GetNcclUniqueId(
@@ -55,16 +43,8 @@ StatusOr<gpu::NcclCliqueId> NcclIdStore::GetNcclUniqueId(
gpu::NcclCliqueId clique_id;
int primary_node_id = device_to_node_.at(key.devices()[0]);
if (node_id_ == primary_node_id) {
#ifdef NCCL_ENABLED
ncclUniqueId id;
ncclResult_t r = ncclGetUniqueId(&id);
TF_RET_CHECK(r == ncclSuccess);
clique_id = gpu::NcclCliqueId(id.internal);
TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclApi::GetUniqueId());
TF_RETURN_IF_ERROR(kv_store_->Set(key.ToString(), clique_id.ToString()));
#else
return absl::FailedPreconditionError(
"NCCL support was not built into XLA binary.");
#endif
} else {
TF_ASSIGN_OR_RETURN(std::string id_str,
kv_store_->Get(key.ToString(), absl::Minutes(10)));
65 changes: 62 additions & 3 deletions xla/service/gpu/BUILD
@@ -845,13 +845,15 @@ tsl_gpu_library(
":backend_configs_cc",
":buffer_allocations",
":ir_emission_utils",
":nccl_api",
":nccl_clique",
":nccl_clique_key",
":nccl_errors",
":nccl_utils",
":thunk",
"//xla:shape_util",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
@@ -892,6 +894,15 @@ tsl_gpu_library(
]),
)

#===-------------------------------------------------------------------------------------------===//
# NCCL integration
#===-------------------------------------------------------------------------------------------===//

# A lot of the build complexity below exists because the NCCL dependency is not always available,
# and `if_nccl` and `if_gpu_configured` do not compose. The NCCL header is included directly in
# only a few targets (:nccl_types and :nccl_api); all other targets should go through these
# headers to launch collective operations (see the :nccl_api alias below). This minimizes the
# spread of #ifdefs across the XLA code base.

# Empty library to implement nested dependency conditions.
cc_library(
name = "empty",
@@ -944,6 +955,55 @@ tsl_gpu_library(
]),
)

alias(
name = "nccl_api",
actual = if_nccl(":_nccl_api_impl", ":_nccl_api_stub"),
)

cc_library(
name = "_nccl_api_impl",
srcs = if_cuda_is_configured(
["nccl_api.cc"],
["nccl_api_stub.cc"],
),
hdrs = ["nccl_api.h"],
compatible_with = get_compatible_with_portable(),
defines = if_cuda_is_configured(["XLA_ENABLE_XCCL"]),
deps = [
":nccl_clique_key",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/service:collective_ops_utils",
"//xla/stream_executor",
"//xla/stream_executor/gpu:gpu_stream",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
] + if_cuda_is_configured([
"@local_config_nccl//:nccl",
]),
)

cc_library(
name = "_nccl_api_stub",
srcs = ["nccl_api_stub.cc"],
hdrs = ["nccl_api.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":nccl_clique_key",
"//xla:xla_data_proto_cc",
"//xla/service:collective_ops_utils",
"//xla/stream_executor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
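For builds without NCCL, the stub presumably keeps the same surface but fails at runtime, much like the FailedPreconditionError branch deleted from nccl_id_store.cc above. A hypothetical sketch of nccl_api_stub.cc, not the actual source:

// Hypothetical stub bodies for NCCL-free builds; the wording and
// coverage of the real nccl_api_stub.cc may differ.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/service/gpu/nccl_api.h"

namespace xla::gpu {

absl::StatusOr<NcclCliqueId> NcclApi::GetUniqueId() {
  // Mirrors the error message deleted from nccl_id_store.cc.
  return absl::FailedPreconditionError(
      "NCCL support was not built into XLA binary.");
}

absl::Status NcclApi::CommGetAsyncError(NcclCommHandle comm) {
  return absl::FailedPreconditionError(
      "NCCL support was not built into XLA binary.");
}

}  // namespace xla::gpu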

tsl_gpu_library(
name = "_nccl_errors",
srcs = ["nccl_errors.cc"],
@@ -965,8 +1025,7 @@ tsl_gpu_library(
srcs = ["nccl_clique.cc"],
hdrs = ["nccl_clique.h"],
deps = [
":_nccl_errors",
":_nccl_types",
":nccl_api",
":nccl_clique_key",
"//xla:debug_options_flags",
"//xla:executable_run_options",
@@ -978,7 +1037,6 @@ tsl_gpu_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@@ -1112,6 +1170,7 @@ cc_library(
":mock_nccl_sleep_kernel_cuda",
":mock_nccl_xml_google",
":nccl_collective_thunks",
":nccl_api",
":nccl_errors",
":nccl_utils",
":nccl_clique_key",
21 changes: 6 additions & 15 deletions xla/service/gpu/mock_nccl_utils.cc
@@ -61,6 +61,7 @@ limitations under the License.
#include "xla/service/gpu/mock_nccl_sleep_kernel.h"
#include "xla/service/gpu/mock_nccl_topo_config.h"
#include "xla/service/gpu/mock_nccl_xml.h"
#include "xla/service/gpu/nccl_api.h"
#include "xla/service/gpu/nccl_clique.h"
#include "xla/service/gpu/nccl_clique_key.h"
#include "xla/service/gpu/nccl_collective_thunk.h"
@@ -497,26 +498,15 @@ absl::Status RunMockCollectivePermute(

namespace {
void CheckNcclAsyncError(NcclComm& lockable_comm) {
ncclComm_t comm = *lockable_comm.Acquire();
NcclCommHandle comm = *lockable_comm.Acquire();
if (comm == nullptr) return;

absl::Status status = [comm] {
ncclResult_t async_err;
XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err));
if (async_err != ncclSuccess) {
LOG(ERROR) << "Aborting communicator: " << comm
<< " due to async NCCL error: "
<< ncclGetErrorString(async_err);
XLA_NCCL_RETURN_IF_ERROR(ncclCommAbort(comm));
}
return XLA_NCCL_STATUS(async_err);
}();

absl::Status status = NcclApi::CommGetAsyncError(comm);
if (!status.ok()) LOG(ERROR) << status;
}

struct NcclCliqueState {
ncclUniqueId unique_id;
NcclCliqueId clique_id;
int64_t run_id = -1;

// `mu` guards `communicators` and `status` during initialization.
@@ -726,7 +716,8 @@ absl::StatusOr<NcclComm::Lock> AcquireMockNcclComm(
size_t num_initialized = [&] {
absl::MutexLock lock(&state.mu);
state.status.Update(status);
state.communicators[rank] = std::make_unique<NcclComm>(comm);
state.communicators[rank] =
std::make_unique<NcclComm>(reinterpret_cast<NcclCommHandle>(comm));
return state.communicators.size();
}();

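The lambda deleted from CheckNcclAsyncError above is presumably what NcclApi::CommGetAsyncError now does internally. A sketch reconstructed from the removed lines, assuming the wrapper keeps the same abort-on-error behavior and error macros:

// Assumed CUDA-build implementation, rebuilt from the code this commit
// deletes; the real nccl_api.cc may structure this differently.
absl::Status NcclApi::CommGetAsyncError(NcclCommHandle comm) {
  ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
  ncclResult_t async_err;
  XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(nccl_comm, &async_err));
  if (async_err != ncclSuccess) {
    LOG(ERROR) << "Aborting communicator: " << nccl_comm
               << " due to async NCCL error: "
               << ncclGetErrorString(async_err);
    XLA_NCCL_RETURN_IF_ERROR(ncclCommAbort(nccl_comm));
  }
  return XLA_NCCL_STATUS(async_err);
}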
16 changes: 16 additions & 0 deletions xla/service/gpu/mock_nccl_utils.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/service/collective_ops_utils.h"
#include "xla/service/global_device_id.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/gpu/nccl_api.h"
#include "xla/service/gpu/nccl_clique_key.h"
#include "xla/service/gpu/nccl_collective_thunk.h"
#include "xla/service/gpu/nccl_p2p_thunk_common.h"
@@ -59,11 +60,26 @@ absl::Status RunMockNcclCollectives(std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, ncclComm_t comm,
Thunk::Kind reduce_op);

inline absl::Status RunMockNcclCollectives(
std::vector<DeviceBufferPair>& buffers, se::Stream& stream,
NcclCommHandle comm, Thunk::Kind reduce_op) {
return RunMockNcclCollectives(buffers, stream,
reinterpret_cast<ncclComm_t>(comm), reduce_op);
}

// Mock a NCCL-based All-To-All op.
absl::Status RunMockNcclAllToAll(bool has_split_dimension,
std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, ncclComm_t comm);

inline absl::Status RunMockNcclAllToAll(bool has_split_dimension,
std::vector<DeviceBufferPair>& buffers,
se::Stream& stream,
NcclCommHandle comm) {
return RunMockNcclAllToAll(has_split_dimension, buffers, stream,
reinterpret_cast<ncclComm_t>(comm));
}

// Mock a collective permute op.
absl::Status RunMockCollectivePermute(
NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer,
Expand Down