diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index dd91b916bafe31..ed8f4f73adc267 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -2,7 +2,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm", "if_gpu_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("@tsl//tsl:tsl.bzl", "if_nccl") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") @@ -44,6 +43,7 @@ cc_library( ":gpu_helpers", ":gpu_metrics", ":gpu_topology", + ":nccl_id_store", "//xla:literal", "//xla:shape_util", "//xla:status", @@ -116,11 +116,9 @@ cc_library( ] + if_cuda_or_rocm([ "//xla/service/gpu:gpu_compiler", ]) + if_cuda([ - ":nccl_id_store_cuda", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", ]) + if_rocm([ - ":nccl_id_store_rocm", "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -161,15 +159,10 @@ xla_cc_test( ], ) -# We actually wish we could write if_cuda(if_nccl(...)) in :gpu_device, -# but Bazel does not allow nested selects. We can work around the problem using -# an intermediate library that has the conditional NCCL pieces that is only -# itself included as a dependency if CUDA is enabled. cc_library( - name = "nccl_id_store_cuda", + name = "nccl_id_store", srcs = ["nccl_id_store.cc"], hdrs = ["nccl_id_store.h"], - defines = if_nccl(["NCCL_ENABLED=1"]), deps = [ "//xla:status_macros", "//xla:statusor", @@ -178,6 +171,7 @@ cc_library( "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:global_device_id", + "//xla/service/gpu:nccl_api", "//xla/service/gpu:nccl_clique_key", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -188,33 +182,7 @@ cc_library( "@com_google_absl//absl/time", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:statusor", - ] + if_nccl(["@local_config_nccl//:nccl"]), -) - -cc_library( - name = "nccl_id_store_rocm", - srcs = ["nccl_id_store.cc"], - hdrs = ["nccl_id_store.h"], - defines = if_nccl(["NCCL_ENABLED=1"]), - deps = [ - "//xla:status_macros", - "//xla:statusor", - "//xla:util", - "//xla/pjrt:pjrt_client", - "//xla/pjrt/distributed:client", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:global_device_id", - "//xla/service/gpu:nccl_clique_key", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", - ] + if_nccl(["@local_config_nccl//:nccl"]), + ], ) xla_cc_test( @@ -255,6 +223,7 @@ cc_library( hdrs = ["se_gpu_pjrt_compiler.h"], defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]), deps = [ + ":nccl_id_store", ":se_gpu_pjrt_client", "//xla:status_macros", "//xla/client:local_client", @@ -281,13 +250,11 @@ cc_library( ] + if_cuda_or_rocm([ "//xla/service/gpu:gpu_compiler", ]) + if_cuda([ - ":nccl_id_store_cuda", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor/cuda:cuda_activation_header", "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", "//xla/service/gpu:nvptx_compiler_impl", ]) + 
if_rocm([ - ":nccl_id_store_rocm", "@local_config_rocm//rocm:rocm_headers", "//xla/service/gpu:amdgpu_compiler_impl", ]), diff --git a/xla/pjrt/gpu/nccl_id_store.cc b/xla/pjrt/gpu/nccl_id_store.cc index 48cb3b4fb35ec2..c61529f1a2b356 100644 --- a/xla/pjrt/gpu/nccl_id_store.cc +++ b/xla/pjrt/gpu/nccl_id_store.cc @@ -20,25 +20,13 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/status_macros.h" #include "xla/statusor.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" -#ifdef NCCL_ENABLED -#if TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#if (TF_ROCM_VERSION >= 50200) -#include "rocm/include/rccl/rccl.h" -#else -#include "rocm/include/rccl.h" -#endif -#else -#include "third_party/nccl/nccl.h" -#endif -#endif // NCCL_ENABLED - namespace xla { StatusOr NcclIdStore::GetNcclUniqueId( @@ -55,16 +43,8 @@ StatusOr NcclIdStore::GetNcclUniqueId( gpu::NcclCliqueId clique_id; int primary_node_id = device_to_node_.at(key.devices()[0]); if (node_id_ == primary_node_id) { -#ifdef NCCL_ENABLED - ncclUniqueId id; - ncclResult_t r = ncclGetUniqueId(&id); - TF_RET_CHECK(r == ncclSuccess); - clique_id = gpu::NcclCliqueId(id.internal); + TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclApi::GetUniqueId()); TF_RETURN_IF_ERROR(kv_store_->Set(key.ToString(), clique_id.ToString())); -#else - return absl::FailedPreconditionError( - "NCCL support was not built into XLA binary."); -#endif } else { TF_ASSIGN_OR_RETURN(std::string id_str, kv_store_->Get(key.ToString(), absl::Minutes(10))); diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index f366e019476807..3116a3a5e38eca 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -845,6 +845,7 @@ tsl_gpu_library( ":backend_configs_cc", ":buffer_allocations", ":ir_emission_utils", + ":nccl_api", ":nccl_clique", ":nccl_clique_key", ":nccl_errors", @@ -852,6 +853,7 @@ tsl_gpu_library( ":thunk", "//xla:shape_util", "//xla:status", + "//xla:status_macros", "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", @@ -892,6 +894,15 @@ tsl_gpu_library( ]), ) +#===-------------------------------------------------------------------------------------------===// +# NCCL integration +#===-------------------------------------------------------------------------------------------===// + +# A lot of build complexity below is because NCCL dependency might not always be available and we +# have `if_nccl` and `if_gpu_configured` that do not compose. NCCL header included directly in a few +# targets (:nccl_types and :nccl_api) and all other targets should use these headers to launch +# collective operations. This allows to minimize the spreading of #ifdef all over the XLA code base. + # Empty library to implement nested dependency conditions. 
cc_library( name = "empty", @@ -944,6 +955,55 @@ tsl_gpu_library( ]), ) +alias( + name = "nccl_api", + actual = if_nccl(":_nccl_api_impl", ":_nccl_api_stub"), +) + +cc_library( + name = "_nccl_api_impl", + srcs = if_cuda_is_configured( + ["nccl_api.cc"], + ["nccl_api_stub.cc"], + ), + hdrs = ["nccl_api.h"], + compatible_with = get_compatible_with_portable(), + defines = if_cuda_is_configured(["XLA_ENABLE_XCCL"]), + deps = [ + ":nccl_clique_key", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + "@local_config_nccl//:nccl", + ]), +) + +cc_library( + name = "_nccl_api_stub", + srcs = ["nccl_api_stub.cc"], + hdrs = ["nccl_api.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":nccl_clique_key", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/stream_executor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + tsl_gpu_library( name = "_nccl_errors", srcs = ["nccl_errors.cc"], @@ -965,8 +1025,7 @@ tsl_gpu_library( srcs = ["nccl_clique.cc"], hdrs = ["nccl_clique.h"], deps = [ - ":_nccl_errors", - ":_nccl_types", + ":nccl_api", ":nccl_clique_key", "//xla:debug_options_flags", "//xla:executable_run_options", @@ -978,7 +1037,6 @@ tsl_gpu_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1112,6 +1170,7 @@ cc_library( ":mock_nccl_sleep_kernel_cuda", ":mock_nccl_xml_google", ":nccl_collective_thunks", + ":nccl_api", ":nccl_errors", ":nccl_utils", ":nccl_clique_key", diff --git a/xla/service/gpu/mock_nccl_utils.cc b/xla/service/gpu/mock_nccl_utils.cc index 4b2f9bef19e477..03c300a0e2a1a7 100644 --- a/xla/service/gpu/mock_nccl_utils.cc +++ b/xla/service/gpu/mock_nccl_utils.cc @@ -61,6 +61,7 @@ limitations under the License. 
#include "xla/service/gpu/mock_nccl_sleep_kernel.h" #include "xla/service/gpu/mock_nccl_topo_config.h" #include "xla/service/gpu/mock_nccl_xml.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/nccl_collective_thunk.h" @@ -497,26 +498,15 @@ absl::Status RunMockCollectivePermute( namespace { void CheckNcclAsyncError(NcclComm& lockable_comm) { - ncclComm_t comm = *lockable_comm.Acquire(); + NcclCommHandle comm = *lockable_comm.Acquire(); if (comm == nullptr) return; - absl::Status status = [comm] { - ncclResult_t async_err; - XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err)); - if (async_err != ncclSuccess) { - LOG(ERROR) << "Aborting communicator: " << comm - << " due to async NCCL error: " - << ncclGetErrorString(async_err); - XLA_NCCL_RETURN_IF_ERROR(ncclCommAbort(comm)); - } - return XLA_NCCL_STATUS(async_err); - }(); - + absl::Status status = NcclApi::CommGetAsyncError(comm); if (!status.ok()) LOG(ERROR) << status; } struct NcclCliqueState { - ncclUniqueId unique_id; + NcclCliqueId clique_id; int64_t run_id = -1; // `mu` guards `communicators` and `status` during initialization. @@ -726,7 +716,8 @@ absl::StatusOr AcquireMockNcclComm( size_t num_initialized = [&] { absl::MutexLock lock(&state.mu); state.status.Update(status); - state.communicators[rank] = std::make_unique(comm); + state.communicators[rank] = + std::make_unique(reinterpret_cast(comm)); return state.communicators.size(); }(); diff --git a/xla/service/gpu/mock_nccl_utils.h b/xla/service/gpu/mock_nccl_utils.h index 9d3052694638bf..6c280de3bc3a18 100644 --- a/xla/service/gpu/mock_nccl_utils.h +++ b/xla/service/gpu/mock_nccl_utils.h @@ -26,6 +26,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/gpu_executable_run_options.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/nccl_p2p_thunk_common.h" @@ -59,11 +60,26 @@ absl::Status RunMockNcclCollectives(std::vector& buffers, se::Stream& stream, ncclComm_t comm, Thunk::Kind reduce_op); +inline absl::Status RunMockNcclCollectives( + std::vector& buffers, se::Stream& stream, + NcclCommHandle comm, Thunk::Kind reduce_op) { + return RunMockNcclCollectives(buffers, stream, + reinterpret_cast(comm), reduce_op); +} + // Mock a NCCL-based All-To-All op. absl::Status RunMockNcclAllToAll(bool has_split_dimension, std::vector& buffers, se::Stream& stream, ncclComm_t comm); +inline absl::Status RunMockNcclAllToAll(bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, + NcclCommHandle comm) { + return RunMockNcclAllToAll(has_split_dimension, buffers, stream, + reinterpret_cast(comm)); +} + // Mock a collective permute op. absl::Status RunMockCollectivePermute( NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer, diff --git a/xla/service/gpu/nccl_all_gather_thunk.cc b/xla/service/gpu/nccl_all_gather_thunk.cc index 1072ded003f3d9..aeabefa7b27847 100644 --- a/xla/service/gpu/nccl_all_gather_thunk.cc +++ b/xla/service/gpu/nccl_all_gather_thunk.cc @@ -16,27 +16,25 @@ limitations under the License. 
#include "xla/service/gpu/nccl_all_gather_thunk.h" #include -#include #include #include +#include "absl/status/status.h" #include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status.h" +#include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -144,49 +142,30 @@ absl::Status NcclAllGatherStartThunk::RunNcclCollective( std::vector device_buffers, ConvertToDeviceBuffers(params, buffers_, config_.config.operand_element_type)); - return xla::gpu::RunAllGather(device_buffers, stream, comm); + return xla::gpu::RunAllGather( + device_buffers, stream, reinterpret_cast(comm)); } absl::Status RunAllGather(std::vector& buffers, - se::Stream& stream, ncclComm_t comm) { -#if XLA_ENABLE_XCCL + se::Stream& stream, NcclApi::NcclCommHandle comm) { int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal; - TF_RETURN_IF_ERROR(MaybeRegisterBuffers(device_ordinal, buffers, comm)); - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); - - PrimitiveType element_type = buffer.element_type; - TF_ASSIGN_OR_RETURN( - auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier(element_type, Thunk::kNcclAllGather)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - VLOG(3) << absl::StreamFormat( - "Calling ncclAllGather(send_buffer=%p, recv_buffer=%p, sendcount=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, element_count, static_cast(comm), - gpu_stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclAllGather( - send_buffer, recv_buffer, element_count, dtype, comm, gpu_stream)); + TF_RETURN_IF_ERROR(MaybeRegisterBuffers(device_ordinal, buffers, + reinterpret_cast(comm))); + + TF_RETURN_IF_ERROR(NcclApi::GroupStart()); + + for (DeviceBufferPair& buffer : buffers) { + TF_RETURN_IF_ERROR(NcclApi::AllGather( + buffer.source_buffer, buffer.destination_buffer, buffer.element_type, + buffer.element_count, reinterpret_cast(comm), + &stream)); } - XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + + TF_RETURN_IF_ERROR(NcclApi::GroupEnd()); VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal; return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL } } // namespace gpu diff --git a/xla/service/gpu/nccl_all_gather_thunk.h b/xla/service/gpu/nccl_all_gather_thunk.h index 1c6696586386ce..87dd51363ea7ac 100644 --- a/xla/service/gpu/nccl_all_gather_thunk.h +++ b/xla/service/gpu/nccl_all_gather_thunk.h @@ -19,11 +19,13 @@ 
limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" -#include "xla/status.h" +#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { @@ -72,7 +74,7 @@ class NcclAllGatherStartThunk : public NcclCollectiveThunk { }; absl::Status RunAllGather(std::vector& buffers, - se::Stream& stream, ncclComm_t comm); + se::Stream& stream, NcclApi::NcclCommHandle comm); } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/nccl_all_reduce_thunk.cc b/xla/service/gpu/nccl_all_reduce_thunk.cc index b5ec060c53a67b..1d845bce1b5d85 100644 --- a/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -17,14 +17,12 @@ limitations under the License. #include #include -#include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" @@ -32,19 +30,18 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/thunk.h" +#include "xla/status_macros.h" #include "xla/stream_executor/stream.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif - namespace xla { namespace gpu { @@ -54,43 +51,19 @@ using mlir::lmhlo_gpu::ReduceScatterStartOp; absl::Status RunAllReduce(ReductionKind reduction_kind, std::vector& buffers, se::Stream& stream, ncclComm_t comm) { -#if XLA_ENABLE_XCCL int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR(MaybeRegisterBuffers(device_ordinal, buffers, comm)); - TF_ASSIGN_OR_RETURN(ncclRedOp_t reduce_op, ToNcclReduction(reduction_kind)); - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllReduce)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - VLOG(3) << absl::StreamFormat( - "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, element_count, static_cast(comm), - gpu_stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer, - element_count, dtype, reduce_op, - comm, gpu_stream)); + TF_RETURN_IF_ERROR(NcclApi::GroupStart()); + for (DeviceBufferPair& buffer : buffers) { + TF_RETURN_IF_ERROR(NcclApi::AllReduce( + buffer.source_buffer, 
buffer.destination_buffer, buffer.element_type, + buffer.element_count, reduction_kind, + reinterpret_cast(comm), &stream)); } - return XLA_NCCL_STATUS(ncclGroupEnd()); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL + + return NcclApi::GroupEnd(); } namespace { @@ -375,57 +348,31 @@ absl::Status NcclReduceScatterStartThunk::RunNcclCollective( absl::Status RunReduceScatter(ReductionKind reduction_kind, std::vector& buffers, se::Stream& stream, ncclComm_t comm) { -#if XLA_ENABLE_XCCL int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing reduce-scatter from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR(MaybeRegisterBuffers(device_ordinal, buffers, comm)); - TF_ASSIGN_OR_RETURN(ncclRedOp_t reduce_op, ToNcclReduction(reduction_kind)); - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - - int num_participants = 0; - XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants)); - - XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const void* send_buffer = buffer.source_buffer.opaque(); - void* recv_buffer = buffer.destination_buffer.opaque(); + TF_ASSIGN_OR_RETURN( + int32_t num_participants, + NcclApi::CommCount(reinterpret_cast(comm))); - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclReduceScatter)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; + TF_RETURN_IF_ERROR(NcclApi::GroupStart()); + for (DeviceBufferPair& buffer : buffers) { // buffer.element_count is the source buffers element count. For // ncclReduceScatter, we need the destination buffers element count. - TF_RET_CHECK(element_count % num_participants == 0) + TF_RET_CHECK(buffer.element_count % num_participants == 0) << "Source buffer was not an exact multiple of the number of " "participants."; - int64_t recv_count = element_count / num_participants; - VLOG(3) << absl::StreamFormat( - "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, " - "recvcount=%d, " - "comm=%p, stream=%p)", - send_buffer, recv_buffer, recv_count, static_cast(comm), - gpu_stream); - XLA_NCCL_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer, - recv_count, dtype, reduce_op, - comm, gpu_stream)); + TF_RETURN_IF_ERROR(NcclApi::ReduceScatter( + buffer.source_buffer, buffer.destination_buffer, buffer.element_type, + buffer.element_count / num_participants, reduction_kind, + reinterpret_cast(comm), &stream)); } - XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); - VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal; - return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL + return NcclApi::GroupEnd(); } } // namespace gpu diff --git a/xla/service/gpu/nccl_all_reduce_thunk.h b/xla/service/gpu/nccl_all_reduce_thunk.h index ae0a703ecbb75c..0f3398dbbfd59d 100644 --- a/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/xla/service/gpu/nccl_all_reduce_thunk.h @@ -20,12 +20,15 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/status.h" +#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { @@ -132,10 +135,24 @@ absl::Status RunAllReduce(ReductionKind reduction_kind, std::vector& buffers, se::Stream& stream, ncclComm_t comm); +inline absl::Status RunAllReduce(ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, NcclCommHandle comm) { + return RunAllReduce(reduction_kind, buffers, stream, + reinterpret_cast(comm)); +} + absl::Status RunReduceScatter(ReductionKind reduction_kind, std::vector& buffers, se::Stream& stream, ncclComm_t comm); +inline absl::Status RunReduceScatter(ReductionKind reduction_kind, + std::vector& buffers, + se::Stream& stream, NcclCommHandle comm) { + return RunReduceScatter(reduction_kind, buffers, stream, + reinterpret_cast(comm)); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/nccl_all_to_all_thunk.cc b/xla/service/gpu/nccl_all_to_all_thunk.cc index dd5f6757783edc..95b28e1bb59f44 100644 --- a/xla/service/gpu/nccl_all_to_all_thunk.cc +++ b/xla/service/gpu/nccl_all_to_all_thunk.cc @@ -22,21 +22,21 @@ limitations under the License. #include #include "absl/status/status.h" -#include "absl/strings/str_format.h" #include "absl/strings/substitute.h" #include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_memory.h" #include "tsl/platform/errors.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -150,59 +150,44 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective( absl::Status RunAllToAll(bool has_split_dimension, std::vector& buffers, se::Stream& stream, ncclComm_t comm) { -#if XLA_ENABLE_XCCL int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal; - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); + TF_ASSIGN_OR_RETURN( + int32_t num_participants, + NcclApi::CommCount(reinterpret_cast(comm))); - int num_participants; - XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants)); + TF_RETURN_IF_ERROR(NcclApi::GroupStart()); - XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); // AllToAll can operate in two modes. Either it specifies a split dimension, // in which case inputs are split and outputs concatenated in that dimension // (here, we only support dimension 0), or it takes a list of inputs // and produces a tuple of outputs. 
if (has_split_dimension) { - for (size_t i = 0; i < buffers.size(); ++i) { - DeviceBufferPair& buffer = buffers[i]; - const uint8_t* send_buffer = - static_cast(buffer.source_buffer.opaque()); - uint8_t* recv_buffer = - static_cast(buffer.destination_buffer.opaque()); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllToAll)); - auto [dtype, multiplier] = dtype_and_multiplier; - int64_t element_count = buffer.element_count; - - TF_RET_CHECK(element_count % num_participants == 0) + for (DeviceBufferPair& buffer : buffers) { + TF_RET_CHECK(buffer.element_count % num_participants == 0) << "Buffer was not an exact multiple of the number of participants."; - size_t chunk_elements = element_count / num_participants; - size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType( - buffer.element_type); - - for (int rank = 0; rank < num_participants; ++rank) { - VLOG(3) << absl::StreamFormat( - "Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - send_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank, - static_cast(comm), gpu_stream); - XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes, - chunk_elements * multiplier, dtype, - rank, comm, gpu_stream)); - - VLOG(3) << absl::StreamFormat( - "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - recv_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank, - static_cast(comm), gpu_stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes, - chunk_elements * multiplier, dtype, - rank, comm, gpu_stream)); + + size_t chunk_elements = buffer.element_count / num_participants; + + for (int peer = 0; peer < num_participants; ++peer) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase send_slice, + NcclApi::Slice(buffer.source_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements)); + + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase recv_slice, + NcclApi::Slice(buffer.destination_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements)); + + TF_RETURN_IF_ERROR(NcclApi::Send( + send_slice, buffer.element_type, chunk_elements, peer, + reinterpret_cast(comm), &stream)); + + TF_RETURN_IF_ERROR(NcclApi::Recv( + recv_slice, buffer.element_type, chunk_elements, peer, + reinterpret_cast(comm), &stream)); } } } else { @@ -211,45 +196,18 @@ absl::Status RunAllToAll(bool has_split_dimension, for (size_t i = 0; i < buffers.size(); ++i) { DeviceBufferPair& buffer = buffers[i]; - const uint8_t* send_buffer = - static_cast(buffer.source_buffer.opaque()); - uint8_t* recv_buffer = - static_cast(buffer.destination_buffer.opaque()); - - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclAllToAll)); - auto [dtype, multiplier] = dtype_and_multiplier; - int64_t element_count = buffer.element_count * multiplier; - - VLOG(3) << absl::StreamFormat( - "Calling ncclSend(sendbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - send_buffer, element_count, i, static_cast(comm), - gpu_stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype, - /*rank=*/i, comm, gpu_stream)); - - VLOG(3) << absl::StreamFormat( - "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d " - "comm=%p, stream=%p)", - recv_buffer, element_count, i, static_cast(comm), - gpu_stream); - - XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype, - /*rank=*/i, comm, gpu_stream)); + + 
TF_RETURN_IF_ERROR(NcclApi::Send( + buffer.source_buffer, buffer.element_type, buffer.element_count, i, + reinterpret_cast(comm), &stream)); + + TF_RETURN_IF_ERROR(NcclApi::Recv( + buffer.destination_buffer, buffer.element_type, buffer.element_count, + i, reinterpret_cast(comm), &stream)); } } - XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); - - VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal; - return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL + + return NcclApi::GroupEnd(); } } // namespace gpu diff --git a/xla/service/gpu/nccl_all_to_all_thunk.h b/xla/service/gpu/nccl_all_to_all_thunk.h index 5bd44b9aee5fd6..352e76c0327798 100644 --- a/xla/service/gpu/nccl_all_to_all_thunk.h +++ b/xla/service/gpu/nccl_all_to_all_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" namespace xla { @@ -69,6 +70,13 @@ absl::Status RunAllToAll(bool has_split_dimension, std::vector& buffers, se::Stream& stream, ncclComm_t comm); +inline absl::Status RunAllToAll(bool has_split_dimension, + std::vector& buffers, + se::Stream& stream, NcclCommHandle comm) { + return RunAllToAll(has_split_dimension, buffers, stream, + reinterpret_cast(comm)); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/nccl_api.cc b/xla/service/gpu/nccl_api.cc new file mode 100644 index 00000000000000..d47466fb5f52d2 --- /dev/null +++ b/xla/service/gpu/nccl_api.cc @@ -0,0 +1,335 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/nccl_api.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "third_party/nccl/nccl.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +//==-----------------------------------------------------------------------===// +// Macros to return or warn on NCCL errors. 
+//==-----------------------------------------------------------------------===// + +static absl::Status ToStatus(ncclResult_t s, const char* file, int64_t line, + const char* expr) { + if (s == ncclSuccess) return absl::OkStatus(); + + return absl::InternalError(absl::StrFormat( + "%s:%d: NCCL operation %s failed: %s." + " Last NCCL warning(error) log entry (may be unrelated) '%s'.", + file, line, expr, ncclGetErrorString(s), ncclGetLastError(nullptr))); +} + +#define XLA_NCCL_STATUS(expr) \ + xla::gpu::ToStatus(expr, __FILE__, __LINE__, #expr) + +#define XLA_NCCL_RETURN_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + return s; \ + } \ + } while (0) + +#define XLA_NCCL_LOG_IF_ERROR(expr) \ + do { \ + absl::Status s = XLA_NCCL_STATUS(expr); \ + if (!s.ok()) { \ + LOG(ERROR) << s.ToString(); \ + } \ + } while (0) + +//==-----------------------------------------------------------------------===// +// Conversions between XLA and NCCL data types +//==-----------------------------------------------------------------------===// + +static size_t ToNcclCount(PrimitiveType dtype, size_t count) { + return primitive_util::IsComplexType(dtype) ? count * 2 : count; +} + +static absl::StatusOr ToNcclDataType(PrimitiveType dtype, + bool is_reduction_op) { + switch (dtype) { + case S8: + case F8E5M2: + case F8E4M3FN: + return ncclInt8; + case PRED: + case U8: + return ncclUint8; + case S32: + return ncclInt32; + case U32: + return ncclUint32; + case S64: + return ncclInt64; + case U64: + return ncclUint64; + case F16: + return ncclFloat16; + case F32: + case C64: + return ncclFloat32; + case F64: + case C128: + return ncclFloat64; + case S16: + case U16: + // For reductions we expect 16 bit integer types to be promoted to 32-bit. + if (is_reduction_op) { + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported data type for reduction operation: %s", + primitive_util::LowercasePrimitiveTypeName(dtype))); + } + // For collectives that just move data around, we can use ncclFloat16 for + // 16-bit integer data types. 
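+      // (Both are 2-byte types and non-reduction collectives treat buffers as
+      // opaque payload, so reinterpreting the element type is safe here.)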
+ return ncclFloat16; + case BF16: + return ncclBfloat16; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Unsupported data type: %s", + primitive_util::LowercasePrimitiveTypeName(dtype))); + } +} + +static ncclRedOp_t ToNcclReduction(ReductionKind kind) { + switch (kind) { + case ReductionKind::SUM: + return ncclSum; + case ReductionKind::PRODUCT: + return ncclProd; + case ReductionKind::MIN: + return ncclMin; + case ReductionKind::MAX: + return ncclMax; + } +} + +static std::string_view ToString(ReductionKind reduction_kind) { + switch (reduction_kind) { + case ReductionKind::SUM: + return "sum"; + case ReductionKind::PRODUCT: + return "prod"; + case ReductionKind::MIN: + return "min"; + case ReductionKind::MAX: + return "max"; + } +} + +//==-----------------------------------------------------------------------===// +// NcclApi +//==-----------------------------------------------------------------------===// + +static_assert(NCCL_UNIQUE_ID_BYTES == NcclCliqueId::kSize, + "size of nccl unique id must match the clique id size"); + +static NcclCommHandle Cast(ncclComm_t comm) { + return reinterpret_cast(comm); +} + +static ncclComm_t Cast(NcclCommHandle comm) { + return reinterpret_cast(comm); +} + +static ncclUniqueId AsNcclUniqueId(const NcclCliqueId& clique_id) { + ncclUniqueId id; + absl::c_copy(clique_id.data(), id.internal); + return id; +} + +absl::StatusOr NcclApi::Slice(se::DeviceMemoryBase buff, + PrimitiveType dtype, + size_t offset, + size_t count) { + size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype); + return buff.GetByteSlice(offset * multiplier, count * multiplier); +} + +absl::StatusOr NcclApi::GetUniqueId() { + VLOG(3) << "Get NCCL unique id"; + ncclUniqueId id; + XLA_NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&id)); + return NcclCliqueId(id.internal); +} + +absl::StatusOr NcclApi::CommInitRank( + int32_t nranks, const NcclCliqueId& clique_id, int32_t rank) { + VLOG(1) << "Initialize NCCL communicator for rank #" << rank << " of " + << nranks << "; hash(id)=" << absl::HashOf(clique_id.data()); + + if (rank < 0 || rank >= nranks) + return absl::InvalidArgumentError(absl::StrFormat( + "Invalid rank %d, it must be in [0, %d) range", rank, nranks)); + + ncclComm_t comm = nullptr; + absl::Status status = XLA_NCCL_STATUS( + ncclCommInitRank(&comm, nranks, AsNcclUniqueId(clique_id), rank)); + + return Cast(comm); +} + +absl::Status NcclApi::CommAbort(NcclCommHandle comm) { + VLOG(1) << "Abort NCCL communicator: " << comm; + return XLA_NCCL_STATUS(ncclCommAbort(Cast(comm))); +} + +absl::StatusOr NcclApi::CommCount(NcclCommHandle comm) { + VLOG(5) << "Get the number of ranks in NCCL communicator: " << comm; + int32_t count; + XLA_NCCL_RETURN_IF_ERROR(ncclCommCount(Cast(comm), &count)); + return count; +} + +absl::Status NcclApi::CommGetAsyncError(NcclCommHandle comm) { + VLOG(5) << "Get last async error for NCCL communicator: " << comm; + + ncclResult_t async_err; + XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(Cast(comm), &async_err)); + if (async_err == ncclSuccess) return absl::OkStatus(); + + return absl::InternalError(absl::StrCat( + ncclGetErrorString(async_err), + ". 
Last NCCL error (maybe unrelated): ", ncclGetLastError(Cast(comm)))); +} + +absl::Status NcclApi::GroupStart() { + VLOG(5) << "Start NCCL group"; + return XLA_NCCL_STATUS(ncclGroupStart()); +} + +absl::Status NcclApi::GroupEnd() { + VLOG(5) << "End NCCL group"; + return XLA_NCCL_STATUS(ncclGroupEnd()); +} + +absl::Status NcclApi::AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL AllReduce operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; reduction_kind=%s; comm=%p; " + "stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, ToString(reduction_kind), comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclAllReduce( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, ToNcclReduction(reduction_kind), Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status NcclApi::ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL ReduceScatter operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; reduction_kind=%s; comm=%p; " + "stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, ToString(reduction_kind), comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclReduceScatter( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, ToNcclReduction(reduction_kind), Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status NcclApi::AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + NcclCommHandle comm, se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL AllGather operation on device #%d; send_buffer=%p; " + "recv_buffer=%p; dtype=%s; count=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + recv_buffer.opaque(), primitive_util::LowercasePrimitiveTypeName(dtype), + count, comm, stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS(ncclAllGather( + send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), + nccl_dtype, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +absl::Status NcclApi::Send(se::DeviceMemoryBase send_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; " + "count=%d; peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), send_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm, + stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS( + ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); 
+} + +absl::Status NcclApi::Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; " + "count=%d; peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), recv_buffer.opaque(), + primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm, + stream); + + TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false)); + + return XLA_NCCL_STATUS( + ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, + peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/nccl_api.h b/xla/service/gpu/nccl_api.h new file mode 100644 index 00000000000000..775840c19f9ec2 --- /dev/null +++ b/xla/service/gpu/nccl_api.h @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_NCCL_API_H_ +#define XLA_SERVICE_GPU_NCCL_API_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "xla/xla_data.pb.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// NcclApi +//===----------------------------------------------------------------------===// + +// NcclApi hides implementation detail of collective operations built on top of +// NCCL library so that no other parts of XLA should include nccl.h header +// directly (or indirectly). + +struct NcclApi { + // Forward declarations of opaque structs corresponding to underlying platform + // types (also defined as opaque structs). + struct NcclComm; + + // Convenience handles for defining API functions. + using NcclCommHandle = NcclComm*; + + // Returns a slice of device memory `buff` containing `count` values of data + // type `dtype` starting from `offset`. + static absl::StatusOr Slice(se::DeviceMemoryBase buff, + PrimitiveType dtype, + size_t offset, + size_t count); + + // Creates a new unique clique id. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid + static absl::StatusOr GetUniqueId(); + + // Creates a new communicator. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrank + static absl::StatusOr CommInitRank( + int32_t nranks, const NcclCliqueId& clique_id, int32_t rank); + + // Frees resources that are allocated to a communicator object comm. Will + // abort any uncompleted operations before destroying the communicator. 
+ // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommabort + static absl::Status CommAbort(NcclCommHandle comm); + + // Returns the number of ranks in the NCCL communicator comm. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount + static absl::StatusOr CommCount(NcclCommHandle comm); + + // Queries the progress and potential errors of asynchronous operations + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommgetasyncerror + static absl::Status CommGetAsyncError(NcclCommHandle comm); + + // Starts a group call. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupstart + static absl::Status GroupStart(); + + // Ends a group call. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupend + static absl::Status GroupEnd(); + + // Reduce buffers of length `count` in `send_buff` using `reduction_kind` + // reduction and leaves identical copies of the result on each `recv_buff`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallreduce + static absl::Status AllReduce(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream); + + // Reduce data in `send_buff` from all GPUs using the `reduction_kind` + // operation and leave the reduced result scattered over the devices so that + // the `recv_buff` on rank `i` will contain the i-th block of the result. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclreducescatter + static absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + ReductionKind reduction_kind, + NcclCommHandle comm, se::Stream* stream); + + // Gather `count` values from all GPUs into recv_buffer, receiving data from + // rank `i` at offset `i * sendcount`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallgather + static absl::Status AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, + NcclCommHandle comm, se::Stream* stream); + + // Send data from `send_buff` to rank `peer`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend + static absl::Status Send(se::DeviceMemoryBase send_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream); + + // Receive data from rank `peer` into `recv_buff`. + // + // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv + static absl::Status Recv(se::DeviceMemoryBase recv_buffer, + PrimitiveType dtype, size_t count, int32_t peer, + NcclCommHandle comm, se::Stream* stream); +}; + +//===----------------------------------------------------------------------===// +// NcclApi Handles +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): Remove these aliases once all users migrated to new API. 
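+
+// A minimal usage sketch (illustrative only; real thunks also register buffers
+// and handle group-level errors): issue per-buffer operations between
+// GroupStart/GroupEnd and pass the opaque NcclCommHandle instead of a raw
+// ncclComm_t:
+//
+//   TF_RETURN_IF_ERROR(NcclApi::GroupStart());
+//   for (DeviceBufferPair& buffer : buffers) {
+//     TF_RETURN_IF_ERROR(NcclApi::AllReduce(
+//         buffer.source_buffer, buffer.destination_buffer, buffer.element_type,
+//         buffer.element_count, ReductionKind::SUM, comm, &stream));
+//   }
+//   TF_RETURN_IF_ERROR(NcclApi::GroupEnd());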
+using NcclCommHandle = NcclApi::NcclCommHandle; // NOLINT + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_NCCL_API_H_ diff --git a/xla/service/gpu/nccl_api_stub.cc b/xla/service/gpu/nccl_api_stub.cc new file mode 100644 index 00000000000000..2c53dafc807ec4 --- /dev/null +++ b/xla/service/gpu/nccl_api_stub.cc @@ -0,0 +1,87 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" +#include "xla/service/gpu/nccl_clique_key.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" + +namespace xla::gpu { + +absl::StatusOr NcclApi::GetUniqueId() { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::StatusOr NcclApi::CommInitRank(int32_t, + const NcclCliqueId&, + int32_t) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::CommAbort(NcclCommHandle) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::StatusOr NcclApi::CommCount(NcclCommHandle) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::CommGetAsyncError(NcclCommHandle) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::GroupStart() { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::GroupEnd() { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::AllReduce(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, ReductionKind, + NcclCommHandle, se::Stream*) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::ReduceScatter(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, ReductionKind, + NcclCommHandle, se::Stream*) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::AllGather(se::DeviceMemoryBase, se::DeviceMemoryBase, + PrimitiveType, size_t, NcclCommHandle, + se::Stream*) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::Send(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, + NcclCommHandle, se::Stream*) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +absl::Status NcclApi::Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, + NcclCommHandle, se::Stream*) { + return absl::UnimplementedError("XLA compiled without NCCL support"); +} + +} // namespace xla::gpu diff --git a/xla/service/gpu/nccl_clique.cc b/xla/service/gpu/nccl_clique.cc index 875e00f7a9d031..f028dc39225ce6 100644 --- a/xla/service/gpu/nccl_clique.cc +++ b/xla/service/gpu/nccl_clique.cc @@ -18,16 +18,13 @@ 
limitations under the License. #include #include #include -#include #include #include #include -#include "absl/algorithm/container.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" @@ -38,9 +35,8 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/service/global_device_id.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" -#include "xla/service/gpu/nccl_errors.h" -#include "xla/service/gpu/nccl_types.h" #include "xla/service/lockable.h" #include "xla/service/rendezvous.h" #include "xla/status_macros.h" @@ -62,14 +58,7 @@ bool IsGlobalNcclConfig() { // Creates a new NCCL unique id for local communication. static absl::StatusOr LocalNcclUniqueId(const NcclCliqueKey&) { -#ifdef XLA_ENABLE_XCCL - NcclUniqueId id; - XLA_NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&id)); - static_assert(sizeof(NcclUniqueId) == sizeof(NcclCliqueId), - "size of nccl unique id must match the clique id"); - return NcclCliqueId(id.internal); -#endif - return absl::InternalError("XLA compiled without NCCL support."); + return NcclApi::GetUniqueId(); } absl::StatusOr GetNcclCliqueIdCallback( @@ -91,7 +80,7 @@ absl::StatusOr GetNcclCliqueIdCallback( namespace { struct NcclCliqueState { - NcclUniqueId unique_id; + NcclCliqueId clique_id; int64_t run_id = -1; // `mu` guards `communicators` and `status` during initialization. @@ -115,17 +104,6 @@ struct NcclCliques { absl::node_hash_map cliques ABSL_GUARDED_BY(mu); }; -absl::StatusOr ToNcclUniqueId(const NcclCliqueId& id) { -#ifdef XLA_ENABLE_XCCL - static_assert(sizeof(NcclUniqueId) == sizeof(NcclCliqueId), - "size of nccl unique id must match the clique id"); - NcclUniqueId nccl_id; - absl::c_copy(id.data(), nccl_id.internal); - return nccl_id; -#endif - return absl::InternalError("XLA compiled without NCCL support."); -} - std::shared_ptr> AcquireNcclClique( RunId run_id, OpId op_id, NcclCliqueKey clique_key, const NcclCliqueIdCallback& clique_id_callback, @@ -161,8 +139,8 @@ std::shared_ptr> AcquireNcclClique( const NcclCliqueKey& clique_key = std::get<2>(rendezvous_key); NcclClique::Lock clique = cliques[clique_key].Acquire(); if (clique->run_id < 0) { - TF_ASSIGN_OR_RETURN(NcclCliqueId id, clique_id_callback(clique_key)); - TF_ASSIGN_OR_RETURN(clique->unique_id, ToNcclUniqueId(id)); + TF_ASSIGN_OR_RETURN(clique->clique_id, + clique_id_callback(clique_key)); } // If multiple executable are running simultaneously while using // multiple hosts, it is possible that different executables could @@ -181,7 +159,6 @@ std::shared_ptr> AcquireNcclClique( // Adds NCCL communicator to a global per-process state that tracks NCCL // communicators health. 
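+// A background thread periodically polls CommGetAsyncError() for every
+// tracked communicator and aborts any communicator that reports an error.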
void TrackNcclCommunicatorHealth(NcclComm* comm) { -#ifdef XLA_ENABLE_XCCL struct AllCommunicators { absl::Mutex mu; std::vector communicators ABSL_GUARDED_BY(mu); @@ -199,19 +176,14 @@ void TrackNcclCommunicatorHealth(NcclComm* comm) { NcclCommHandle comm = *lockable_comm->Acquire(); if (comm == nullptr) return absl::OkStatus(); - NcclStatus async_err; - XLA_NCCL_RETURN_IF_ERROR(ncclCommGetAsyncError(comm, &async_err)); - - if (async_err != ncclSuccess) { + absl::Status async_err = NcclApi::CommGetAsyncError(comm); + if (!async_err.ok()) { LOG(ERROR) << "Aborting communicator: " << comm - << " due to async NCCL error: " - << ncclGetErrorString(async_err) - << ". Last NCCL warning(error) log entry (may be unrelated): " - << ncclGetLastError(nullptr); - XLA_NCCL_RETURN_IF_ERROR(ncclCommAbort(comm)); + << " due to async NCCL error: " << async_err; + TF_RETURN_IF_ERROR(NcclApi::CommAbort(comm)); } - return XLA_NCCL_STATUS(async_err); + return async_err; }; // Launch a thread that periodically checks all NCCL communicators for @@ -233,7 +205,6 @@ void TrackNcclCommunicatorHealth(NcclComm* comm) { } }); (void)check_async_error_thread; // Silence unused variable warning. -#endif } } // namespace @@ -243,7 +214,6 @@ absl::StatusOr AcquireNcclComm( size_t num_local_participants, const NcclCliqueIdCallback& clique_id_callback, int32_t rank, int64_t stream_id, bool enable_clique_optimization) { -#ifdef XLA_ENABLE_XCCL // Ensure that this group of threads have exclusive access to the clique to // prevent threads from different groups locking communicators in the clique. // The enable_clique_optimization value is only used for asynchronous @@ -264,19 +234,18 @@ absl::StatusOr AcquireNcclComm( if (!state.ready.HasBeenNotified()) { int nranks = clique_key.devices().size(); - const ncclUniqueId& id = state.unique_id; - VLOG(3) << "Initialize NCCL communicator for rank #" << rank << " of " - << nranks << "; id=" << absl::HashOf(absl::MakeSpan(id.internal)); - - ncclComm_t comm = nullptr; - absl::Status status = - XLA_NCCL_STATUS(ncclCommInitRank(&comm, nranks, id, rank)); + absl::StatusOr comm = + NcclApi::CommInitRank(nranks, state.clique_id, rank); size_t num_initialized = [&] { absl::MutexLock lock(&state.mu); - state.status.Update(status); - state.communicators[rank] = std::make_unique(comm); + if (comm.ok()) { + state.communicators[rank] = std::make_unique(*comm); + } else { + state.status.Update(comm.status()); + state.communicators[rank] = std::make_unique(nullptr); + } return state.communicators.size(); }(); @@ -288,7 +257,7 @@ absl::StatusOr AcquireNcclComm( if (num_initialized == num_local_participants) { state.ready.Notify(); } else { - TF_RETURN_IF_ERROR(status); + TF_RETURN_IF_ERROR(comm.status()); state.ready.WaitForNotification(); } @@ -298,9 +267,6 @@ absl::StatusOr AcquireNcclComm( TF_RETURN_IF_ERROR(state.status); return state.communicators[rank]->Acquire(); -#endif - - return absl::InternalError("XLA compiled without NCCL support."); } } // namespace xla::gpu diff --git a/xla/service/gpu/nccl_clique.h b/xla/service/gpu/nccl_clique.h index 6d9109dd339ca2..0e3c3c75b40f73 100644 --- a/xla/service/gpu/nccl_clique.h +++ b/xla/service/gpu/nccl_clique.h @@ -23,8 +23,8 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "xla/executable_run_options.h" #include "xla/service/global_device_id.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_clique_key.h" -#include "xla/service/gpu/nccl_types.h" #include "xla/service/lockable.h" #include "tsl/lib/gtl/int_type.h" diff --git a/xla/service/gpu/nccl_collective_permute_thunk.cc b/xla/service/gpu/nccl_collective_permute_thunk.cc index 348f22da510454..1f470690590564 100644 --- a/xla/service/gpu/nccl_collective_permute_thunk.cc +++ b/xla/service/gpu/nccl_collective_permute_thunk.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/service/gpu/nccl_collective_thunk.h" #include "xla/service/gpu/nccl_p2p_thunk_common.h" #include "xla/service/gpu/thunk.h" @@ -274,11 +275,10 @@ absl::Status RunCollectivePermute( device_string, current_id, source_id.value_or(-1), target_id.value_or(-1)); - // ncclGroupStart/end API is needed only if we will issue both ncclSend and - // ncclRecv API calls. + // GroupStart/End API is needed only if we will issue both send & recv calls. const bool is_nccl_group_needed = (target_id && source_id); if (is_nccl_group_needed) { - XLA_NCCL_RETURN_IF_ERROR(ncclGroupStart()); + TF_RETURN_IF_ERROR(NcclApi::GroupStart()); } TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, @@ -311,7 +311,7 @@ absl::Status RunCollectivePermute( *source_id, comm, gpu_stream)); } if (is_nccl_group_needed) { - XLA_NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + TF_RETURN_IF_ERROR(NcclApi::GroupEnd()); } if (!source_id) { diff --git a/xla/service/gpu/nccl_collective_thunk.cc b/xla/service/gpu/nccl_collective_thunk.cc index 973e1f0815fba3..a05a2629c26088 100644 --- a/xla/service/gpu/nccl_collective_thunk.cc +++ b/xla/service/gpu/nccl_collective_thunk.cc @@ -413,14 +413,15 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { // Run the collective on main stream or using the async executor. absl::Status status = [&]() { if (!IsAsync()) { - return RunNcclCollective(params, *params.stream, *comm); + return RunNcclCollective(params, *params.stream, + reinterpret_cast(*comm)); } return async_->Execute( [this](const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) { return RunNcclCollective(params, stream, comm); }, - params, *comm, GetAsyncStreamKind()); + params, reinterpret_cast(*comm), GetAsyncStreamKind()); }(); TF_RETURN_IF_ERROR(status); diff --git a/xla/service/gpu/nccl_recv_thunk.cc b/xla/service/gpu/nccl_recv_thunk.cc index e7039534d06b6f..595196de4183a0 100644 --- a/xla/service/gpu/nccl_recv_thunk.cc +++ b/xla/service/gpu/nccl_recv_thunk.cc @@ -15,20 +15,18 @@ limitations under the License. 
#include "xla/service/gpu/nccl_recv_thunk.h" +#include #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/stream_executor/stream.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif +#include "tsl/platform/errors.h" namespace xla { namespace gpu { @@ -102,7 +100,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, absl::string_view device_string, int64_t current_id) { -#if XLA_ENABLE_XCCL // Determine the source IDs for this instance. The source ID is the ID for // the peer that will copy its data to this instance. If there is no source, // just memzero() the destination buffer. @@ -116,23 +113,12 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d", device_string, current_id, source_id.value_or(-1)); - TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier, - ToNcclDataTypeAndCountMultiplier( - buffer.element_type, Thunk::kNcclCollectivePermute)); - ncclDataType_t dtype = dtype_and_multiplier.first; - int64_t element_count = buffer.element_count * dtype_and_multiplier.second; - - se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream); - // Receive data from the source peer to the destination buffer. if (source_id) { - VLOG(3) << absl::StreamFormat( - "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, " - "stream=%p)", - device_string, dest_addr.opaque(), element_count, *source_id, - static_cast(comm), gpu_stream); - XLA_NCCL_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype, - *source_id, comm, gpu_stream)); + TF_RETURN_IF_ERROR(NcclApi::Recv( + dest_addr, buffer.element_type, buffer.element_count, *source_id, + reinterpret_cast(comm), &stream)); + } else { // If there is no source peer, i.e. no sender to this instance, zero out // the destination buffer. @@ -141,11 +127,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target, stream.ThenMemZero(&dest_addr, dest_addr.size()); } return absl::OkStatus(); -#else // XLA_ENABLE_XCCL - return Unimplemented( - "NCCL support is not available: this binary was not built with a CUDA " - "compiler, which is necessary to build the NCCL source library."); -#endif // XLA_ENABLE_XCCL } } // namespace gpu diff --git a/xla/service/gpu/nccl_send_thunk.cc b/xla/service/gpu/nccl_send_thunk.cc index 263bdc79d0aa1c..24ba2f68535410 100644 --- a/xla/service/gpu/nccl_send_thunk.cc +++ b/xla/service/gpu/nccl_send_thunk.cc @@ -15,20 +15,18 @@ limitations under the License. #include "xla/service/gpu/nccl_send_thunk.h" +#include #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/collective_ops_utils.h" +#include "xla/service/gpu/nccl_api.h" #include "xla/stream_executor/stream.h" - -#if XLA_ENABLE_XCCL -#include "xla/stream_executor/gpu/gpu_stream.h" -#endif +#include "tsl/platform/errors.h" namespace xla { namespace gpu { @@ -102,10 +100,8 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer, se::Stream& stream, ncclComm_t comm, absl::string_view device_string, int64_t current_id) { -#if XLA_ENABLE_XCCL // Determine the target IDs for this instance. 
   // to which this instance will copy its data.
-
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing collective permute from device ordinal: "
           << device_ordinal << "current_id " << current_id;
@@ -116,30 +112,14 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
   VLOG(3) << absl::StreamFormat("%s : id = %d, target_id = %d", device_string,
                                 current_id, target_id.value_or(-1));
 
-  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
-                      ToNcclDataTypeAndCountMultiplier(
-                          buffer.element_type, Thunk::kNcclCollectivePermute));
-  ncclDataType_t dtype = dtype_and_multiplier.first;
-  int64_t element_count = buffer.element_count * dtype_and_multiplier.second;
-
-  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);
-
   // Send source buffer to target peer if needed.
   if (target_id) {
-    VLOG(3) << absl::StreamFormat(
-        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
-        "comm=%p, stream=%p)",
-        device_string, src_addr.opaque(), element_count, *target_id,
-        static_cast(comm), gpu_stream);
-    XLA_NCCL_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
-                                      *target_id, comm, gpu_stream));
+    TF_RETURN_IF_ERROR(NcclApi::Send(
+        src_addr, buffer.element_type, buffer.element_count, *target_id,
+        reinterpret_cast(comm), &stream));
   }
+
   return absl::OkStatus();
-#else   // XLA_ENABLE_XCCL
-  return Unimplemented(
-      "NCCL support is not available: this binary was not built with a CUDA "
-      "compiler, which is necessary to build the NCCL source library.");
-#endif  // XLA_ENABLE_XCCL
 }
 
 }  // namespace gpu
diff --git a/xla/service/gpu/nccl_types.h b/xla/service/gpu/nccl_types.h
index 339cc8f878071a..c9ff598a033dc6 100644
--- a/xla/service/gpu/nccl_types.h
+++ b/xla/service/gpu/nccl_types.h
@@ -34,7 +34,6 @@ namespace xla::gpu {
 
 #if defined(XLA_ENABLE_XCCL)
 
-using NcclCommHandle = ncclComm_t;
 using NcclDataType = ncclDataType_t;
 using NcclRedOp = ncclRedOp_t;
 using NcclStatus = ncclResult_t;
@@ -46,7 +45,6 @@ using NcclUniqueId = ncclUniqueId;
 // pointers and always return errors in implementations. By doing this we can
 // keep all XLA headers compilable even if NCCL is not available and do not
 // spread ifdefs throughout the code base.
-using NcclCommHandle = void*;
 using NcclDataType = void*;
 using NcclRedOp = void*;
 using NcclStatus = void*;
diff --git a/xla/service/gpu/nccl_utils.cc b/xla/service/gpu/nccl_utils.cc
index 69537a3a6ea057..2d40f683b2c53e 100644
--- a/xla/service/gpu/nccl_utils.cc
+++ b/xla/service/gpu/nccl_utils.cc
@@ -172,15 +172,18 @@ absl::Status NcclPersistentPlanAllocator::Deallocate(
 ScopedNcclPersistentPlanAllocator::ScopedNcclPersistentPlanAllocator(
     NcclComm::Lock* comm, ncclPersistentPlanAllocator* allocator)
     : comm_(comm) {
-  CHECK(ncclCommGetPersistentPlanAllocator(**comm_, &recover_) == ncclSuccess)
+  CHECK(ncclCommGetPersistentPlanAllocator(
+            reinterpret_cast(**comm_), &recover_) == ncclSuccess)
       << "Failed to get NCCL persistent plan allocator";
-  CHECK(ncclCommSetPersistentPlanAllocator(**comm, allocator) == ncclSuccess)
+  CHECK(ncclCommSetPersistentPlanAllocator(
+            reinterpret_cast(**comm_), allocator) == ncclSuccess)
      << "Faield to set NCCL persistent plan allocator";
 }
 
 ScopedNcclPersistentPlanAllocator::~ScopedNcclPersistentPlanAllocator() {
-  CHECK(ncclCommSetPersistentPlanAllocator(**comm_, recover_) == ncclSuccess)
+  CHECK(ncclCommSetPersistentPlanAllocator(
+            reinterpret_cast(**comm_), recover_) == ncclSuccess)
       << "Faield to set NCCL persistent plan allocator";
 }
 #endif
diff --git a/xla/service/gpu/runtime/collectives.cc b/xla/service/gpu/runtime/collectives.cc
index 514a5335293018..a0e2282c1ffb0e 100644
--- a/xla/service/gpu/runtime/collectives.cc
+++ b/xla/service/gpu/runtime/collectives.cc
@@ -297,8 +297,9 @@ absl::Status MockNcclP2PImplCommon(
   const NcclP2PConfig::SourceTargetMapEntry source_target =
       NcclP2PConfig::GetSourceTarget(id_to_source_target, current_id);
 
-  return runner(source_target, (*device_buffers)[0], *stream, **comm,
-                device_string, current_id);
+  return runner(source_target, (*device_buffers)[0], *stream,
+                reinterpret_cast(**comm), device_string,
+                current_id);
 }
 
 absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options,
@@ -353,8 +354,9 @@ absl::Status P2PImplCommon(const ServiceExecutableRunOptions* run_options,
   return RunRepeated(debug_options->xla_gpu_collective_inflation_factor(),
                      [&]() -> absl::Status {
                        return runner(source_target, (*device_buffers)[0],
-                                     *stream, **comm, device_string,
-                                     current_id);
+                                     *stream,
+                                     reinterpret_cast(**comm),
+                                     device_string, current_id);
                      });
 }
 #endif  // XLA_ENABLE_XCCL
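A note for readers following the refactoring: every NCCL call in the files above now goes through the NcclApi facade from xla/service/gpu/nccl_api.h instead of calling ncclXxx directly under #ifdef guards. The sketch below reconstructs the shape of that facade purely from the call sites touched in this patch; parameter names, parameter order details, and exact types are assumptions, not the actual header.

// Approximate sketch of the NcclApi surface, inferred only from the call
// sites in this patch (nccl_clique.cc, nccl_*_thunk.cc). Treat every
// signature below as an assumption.
#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/service/gpu/nccl_clique_key.h"    // NcclCliqueId
#include "xla/stream_executor/device_memory.h"  // se::DeviceMemoryBase
#include "xla/stream_executor/stream.h"         // se::Stream
#include "xla/xla_data.pb.h"                    // PrimitiveType

namespace xla::gpu {

struct NcclApi {
  // Opaque communicator handle; the thunks reinterpret_cast between this and
  // ncclComm_t. Modeled as void* only for the purpose of this sketch.
  using NcclCommHandle = void*;

  // Wraps ncclGetUniqueId and packages the result as an NcclCliqueId.
  static absl::StatusOr<NcclCliqueId> GetUniqueId();

  // Creates a communicator for `rank` out of `nranks` participants that all
  // share the same clique id.
  static absl::StatusOr<NcclCommHandle> CommInitRank(
      int32_t nranks, const NcclCliqueId& clique_id, int32_t rank);

  // Async error query and abort, used by the communicator health checks.
  static absl::Status CommGetAsyncError(NcclCommHandle comm);
  static absl::Status CommAbort(NcclCommHandle comm);

  // Grouping for fused point-to-point pairs (ncclGroupStart/ncclGroupEnd).
  static absl::Status GroupStart();
  static absl::Status GroupEnd();

  // Point-to-point transfers issued on `stream`.
  static absl::Status Send(se::DeviceMemoryBase send_buffer,
                           PrimitiveType dtype, int64_t count, int32_t peer,
                           NcclCommHandle comm, se::Stream* stream);
  static absl::Status Recv(se::DeviceMemoryBase recv_buffer,
                           PrimitiveType dtype, int64_t count, int32_t peer,
                           NcclCommHandle comm, se::Stream* stream);
};

}  // namespace xla::gpu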
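The nccl_clique.cc changes reduce communicator setup to two NcclApi calls: resolve an NcclCliqueId through the NcclCliqueIdCallback (locally this is just NcclApi::GetUniqueId; in multi-host runs the id travels through the key-value store), then create the per-rank communicator with NcclApi::CommInitRank. A minimal sketch of that flow, assuming the facade sketched above plus the usual tsl status macros, and ignoring the rendezvous, locking, and caching that the real code performs:

// Hypothetical helper, not part of the patch: obtains a communicator for one
// rank of a clique. The real code additionally rendezvouses all local
// participants and caches communicators in NcclCliqueState.
absl::StatusOr<NcclApi::NcclCommHandle> InitRankCommunicator(
    const NcclCliqueKey& clique_key,
    const NcclCliqueIdCallback& clique_id_callback, int32_t rank) {
  // The clique id must be identical on every participating rank, otherwise
  // the CommInitRank calls will never find each other.
  TF_ASSIGN_OR_RETURN(NcclCliqueId clique_id, clique_id_callback(clique_key));

  // Communicator size is the number of devices in the clique.
  int nranks = clique_key.devices().size();
  return NcclApi::CommInitRank(nranks, clique_id, rank);
}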
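The send/recv thunks and RunCollectivePermute now share one pattern: open an NCCL group only when both a send and a receive will be issued, and let NcclApi handle the dtype-to-count conversion that the deleted ToNcclDataTypeAndCountMultiplier code did by hand. A hedged sketch of that pattern, again assuming the signatures above; the helper and its name are illustrative, not part of the patch:

// Hypothetical helper mirroring the grouped send/recv pattern used by the
// collective-permute thunk. Assumes <optional>, <cstdint>, the NcclApi sketch
// above, and the tsl status macros.
absl::Status ExchangeWithPeers(se::DeviceMemoryBase src,
                               se::DeviceMemoryBase dst, PrimitiveType dtype,
                               int64_t count, std::optional<int32_t> target_id,
                               std::optional<int32_t> source_id,
                               NcclApi::NcclCommHandle comm,
                               se::Stream* stream) {
  // Group the calls only when both directions are active so NCCL sees them as
  // one fused operation (mirrors is_nccl_group_needed above).
  const bool group_needed = target_id.has_value() && source_id.has_value();
  if (group_needed) TF_RETURN_IF_ERROR(NcclApi::GroupStart());

  if (target_id) {
    TF_RETURN_IF_ERROR(
        NcclApi::Send(src, dtype, count, *target_id, comm, stream));
  }
  if (source_id) {
    TF_RETURN_IF_ERROR(
        NcclApi::Recv(dst, dtype, count, *source_id, comm, stream));
  }

  if (group_needed) TF_RETURN_IF_ERROR(NcclApi::GroupEnd());
  return absl::OkStatus();
}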
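Finally, communicator health tracking: the background check now receives an absl::Status from NcclApi::CommGetAsyncError instead of decoding ncclResult_t and calling ncclGetErrorString by hand. A simplified stand-alone sketch of such a polling loop follows; the real code keeps every live communicator in the AllCommunicators registry and uses XLA's own threading utilities, so the std::thread driver and the 30-second interval here are illustrative only.

// Assumes <chrono>, <thread>, <vector>, the NcclApi sketch above, and the
// usual absl status and logging headers.

// Poll a set of communicators and abort any that report an async error.
// `comms` stands in for the AllCommunicators registry from nccl_clique.cc.
void PollAsyncErrors(const std::vector<NcclApi::NcclCommHandle>& comms) {
  for (NcclApi::NcclCommHandle comm : comms) {
    if (comm == nullptr) continue;

    absl::Status async_err = NcclApi::CommGetAsyncError(comm);
    if (!async_err.ok()) {
      LOG(ERROR) << "Aborting communicator: " << comm
                 << " due to async NCCL error: " << async_err;
      // Abort tears the communicator down; subsequent collectives fail fast
      // instead of hanging.
      NcclApi::CommAbort(comm).IgnoreError();
    }
  }
}

// Illustrative driver: poll on a detached thread at an arbitrary interval.
void StartHealthCheckThread(std::vector<NcclApi::NcclCommHandle> comms) {
  std::thread([comms = std::move(comms)] {
    for (;;) {
      PollAsyncErrors(comms);
      std::this_thread::sleep_for(std::chrono::seconds(30));
    }
  }).detach();
}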