From e3328feea438de1de4a879f18e88557f56f2f173 Mon Sep 17 00:00:00 2001
From: Alexandros Theodoridis
Date: Fri, 6 Dec 2024 14:16:30 +0000
Subject: [PATCH] Manually migrate bd changes to 2.18 branch

---
 tensorflow/core/framework/device_base.h |   2 +
 tensorflow/core/nccl/nccl_manager.cc    | 824 +++++++++++++++++-------
 tensorflow/core/nccl/nccl_manager.h     |  86 ++-
 3 files changed, 678 insertions(+), 234 deletions(-)

diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 065707fde4b8c2..301d4c639cd7bb 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -165,6 +165,8 @@ class DeviceBase {
     int gpu_id = -1;
   };
 
+  using GpuDeviceInfo = AcceleratorDeviceInfo;
+
   // Does not take ownership.
   void set_tensorflow_accelerator_device_info(
       AcceleratorDeviceInfo* device_info) {
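Reviewer note: the alias above lets code written against the pre-2.16 `GpuDeviceInfo` spelling keep compiling against `AcceleratorDeviceInfo`. A minimal sketch (the function name is hypothetical; `tensorflow_accelerator_device_info()` is the existing accessor):

// Sketch: both spellings now denote the same struct.
void LogGpuId(const tensorflow::DeviceBase* device) {
  const tensorflow::DeviceBase::GpuDeviceInfo* info =  // alias resolves here
      device->tensorflow_accelerator_device_info();
  if (info != nullptr) LOG(INFO) << "gpu_id " << info->gpu_id;
}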
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index cf3ceb670ba717..7b1472ee2f5a42 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -18,7 +18,14 @@ limitations under the License.
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "absl/base/call_once.h"
+// The header names below lost their <...> arguments during extraction; they
+// are reconstructed from the socket/bootstrap code added in this file.
+#include <arpa/inet.h>
+#include <net/if.h>
+#include <netdb.h>
+#include <pthread.h>
+#include <sys/socket.h>
+#include <atomic>
+#include <iostream>
+
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -36,6 +43,30 @@ limitations under the License.
 
 namespace tensorflow {
 
+#if defined(USE_ROCM)
+#elif TENSORFLOW_USE_ROCM
+using se::rocm::ScopedActivateExecutorContext;
+// Local hipify of cuda symbols
+#define cudaStream_t hipStream_t
+#define cudaError_t hipError_t
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetDevice hipGetDevice
+#define cudaSetDevice hipSetDevice
+#define cudaSuccess hipSuccess
+int NcclManager::instance_count = 0;
+#else
+int NcclManager::instance_count = 0;
+using tensorflow::se::cuda::ScopedActivateExecutorContext;
+#endif
+
 using stream_executor::gpu::ScopedActivateContext;
 #if TENSORFLOW_USE_ROCM
 // Local hipify of cuda symbols
@@ -89,9 +120,14 @@ struct NcclManager::NcclStream : public core::RefCounted {
   // signals `cv` to unblock the thread waiting on more collectives.
   mutex mu;
   condition_variable cv;
   // Has (collective, participant_idx) pairs.
+#ifdef USE_TF215
   std::deque<std::pair<const Collective*, int>> pending_launches_
       TF_GUARDED_BY(mu);
   bool shutdown_requested TF_GUARDED_BY(mu) = false;
+#else
+  std::deque<std::pair<const Collective*, int>> pending_launches_
+      GUARDED_BY(mu);
+  bool shutdown_requested GUARDED_BY(mu) = false;
+#endif
 };
 
 struct NcclManager::CommunicatorMember {
@@ -140,12 +176,256 @@ ncclDataType_t ToNcclType(DataType t) {
   }
 }
 
+std::size_t NcclTypeSize(ncclDataType_t t) {
+  switch (t) {
+    case ncclHalf:
+      return 2;
+#ifndef DISABLE_BFLOAT16
+    case ncclBfloat16:
+      return 2;
+#endif
+    case ncclFloat:
+      return sizeof(float);
+    case ncclDouble:
+      return sizeof(double);
+    case ncclInt:
+      return sizeof(int32_t);
+    case ncclInt64:
+      return sizeof(int64_t);
+    default:
+      return sizeof(float);
+  }
+}
+
 void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
   if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
     memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
   }
 }
 
+struct netIf {
+  char prefix[64];
+  int port;
+};
+
+union socketAddress {
+  struct sockaddr sa;
+  struct sockaddr_in sin;
+  struct sockaddr_in6 sin6;
+};
+
+struct ncclBootstrapHandle {
+  uint64_t magic;
+  union socketAddress addr;
+};
+
+int parseStringList(const char* string, struct netIf* ifList, int maxList) {
+  if (!string) return 0;
+
+  // Ignore 'NCCL_COMM_ID='
+  const char* ptr = string + NCCL_COMM_ID_LEN;
+
+  int ifNum = 0;
+  int ifC = 0;
+  char c;
+  do {
+    c = *ptr;
+    if (c == ':') {
+      if (ifC > 0) {
+        ifList[ifNum].prefix[ifC] = '\0';
+        ifList[ifNum].port = atoi(ptr + 1);
+        ifNum++;
+        ifC = 0;
+      }
+      while (c != ',' && c != '\0') c = *(++ptr);
+    } else if (c == ',' || c == '\0') {
+      if (ifC > 0) {
+        ifList[ifNum].prefix[ifC] = '\0';
+        ifList[ifNum].port = -1;
+        ifNum++;
+        ifC = 0;
+      }
+    } else {
+      ifList[ifNum].prefix[ifC] = c;
+      ifC++;
+    }
+    ptr++;
+  } while (ifNum < maxList && c);
+  return ifNum;
+}
+
+ncclResult_t GetNcclUniqueIdFromString(ncclUniqueId* id, const char* comm_id) {
+  memset(id, 0, sizeof(ncclUniqueId));
+  auto handler = (struct ncclBootstrapHandle*)(id);
+  std::string str(comm_id);
+  std::cout << "comm_id :" << comm_id << std::endl;
+  size_t end = str.find('+');
+  handler->magic = std::stoull(str.substr(end + 1, str.size()));
+  auto str_ip_port = str.substr(0, end);
+  const char* ip_port_pair = str_ip_port.c_str();
+  union socketAddress* ua = &(handler->addr);
+
+  if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
+    std::cout << "Net : string is null" << std::endl;
+    return ncclInvalidArgument;
+  }
+
+  std::cout << "Net address :" << ip_port_pair << std::endl;
+  bool ipv6 = ip_port_pair[NCCL_COMM_ID_LEN] == '[';
+  /* Construct the sockaddress structure */
+  if (!ipv6) {
+    struct netIf ni;
+    // Parse the <ip>:<port> string; expect exactly one pair.
+    if (parseStringList(ip_port_pair, &ni, 1) != 1) {
+      std::cout << "Net : No valid <ip>:<port> pair found" << std::endl;
+      return ncclInvalidArgument;
+    }
+    struct addrinfo hints, *p;
+    int rv;
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = AF_UNSPEC;
+    hints.ai_socktype = SOCK_STREAM;
+
+    if ((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
+      std::cout << "Net : error encountered when getting address info : "
+                << gai_strerror(rv) << std::endl;
+      return ncclInvalidArgument;
+    }
+
+    // Use the first result.
+    if (p->ai_family == AF_INET) {
+      struct sockaddr_in& sin = ua->sin;
+      memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
+      sin.sin_family = AF_INET;  // IPv4
+      // inet_pton(AF_INET, ni.prefix, &(sin.sin_addr));  // IP address
+      sin.sin_port = htons(ni.port);  // port
+    } else if (p->ai_family == AF_INET6) {
+      struct sockaddr_in6& sin6 = ua->sin6;
+      memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
+      sin6.sin6_family = AF_INET6;      // IPv6
+      sin6.sin6_port = htons(ni.port);  // port
+      sin6.sin6_flowinfo = 0;           // needed by IPv6, but possibly obsolete
+      sin6.sin6_scope_id = 0;           // should be global scope, set to 0
+    } else {
+      VLOG(0) << "Net : unsupported IP family";
+      return ncclInvalidArgument;
+    }
+
+    freeaddrinfo(p);  // all done with this structure
+
+  } else {
+    // Ignore 'NCCL_COMM_ID='
+    const char* ptr = ip_port_pair + NCCL_COMM_ID_LEN;
+    int i, j = -1, len = strlen(ptr);
+    for (i = 1; i < len; i++) {
+      if (ptr[i] == '%') j = i;
+      if (ptr[i] == ']') break;
+    }
+    if (i == len) {
+      std::cout << "Net : No valid [IPv6]:port pair found" << std::endl;
+      return ncclInvalidArgument;
+    }
+    // If no '%' was found the address has global scope; otherwise link scope.
+    bool global_scope = (j == -1 ? true : false);
+
+    char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
+    memset(ip_str, '\0', sizeof(ip_str));
+    memset(port_str, '\0', sizeof(port_str));
+    memset(if_name, '\0', sizeof(if_name));
+    strncpy(ip_str, ptr + 1, global_scope ? i - 1 : j - 1);
+    strncpy(port_str, ptr + i + 2, len - i - 1);
+    int port = atoi(port_str);
+    if (!global_scope)
+      strncpy(if_name, ptr + j + 1,
+              i - j - 1);  // If not global scope, we need the intf name
+
+    struct sockaddr_in6& sin6 = ua->sin6;
+    sin6.sin6_family = AF_INET6;                     // IPv6
+    inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr));  // IP address
+    sin6.sin6_port = htons(port);                    // port
+    sin6.sin6_flowinfo = 0;  // needed by IPv6, but possibly obsolete
+    sin6.sin6_scope_id =
+        global_scope
+            ? 0
+            : if_nametoindex(
+                  if_name);  // 0 if global scope; intf index if link scope
+  }
+  return ncclSuccess;
+}
+
+const char* socketToString(struct sockaddr* saddr, char* buf) {
+  if (buf == NULL || saddr == NULL) return NULL;
+  if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
+    buf[0] = '\0';
+    return buf;
+  }
+  char host[NI_MAXHOST], service[NI_MAXSERV];
+  (void)getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST,
+                    service, NI_MAXSERV, NI_NUMERICHOST | NI_NUMERICSERV);
+  sprintf(buf, "%s<%s>", host, service);
+  return buf;
+}
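Reviewer note: reconstructed from the parsing code above, the bootstrap string is expected to look like `NCCL_COMM_ID=<ip>:<port>+<magic>`. A minimal sketch of the round trip, assuming the file-local helpers are visible (for example from a test in the same translation unit); the address and magic values are made up:

// Sketch only.
ncclUniqueId id;
const char* comm_id = "NCCL_COMM_ID=10.0.0.1:12345+4242";
if (GetNcclUniqueIdFromString(&id, comm_id) == ncclSuccess) {
  auto* handle = reinterpret_cast<ncclBootstrapHandle*>(&id);
  char buf[SOCKET_NAME_MAXLEN + 1];
  // Prints "10.0.0.1<12345>" and magic 4242.
  std::cout << socketToString(&handle->addr.sa, buf) << " magic "
            << handle->magic << std::endl;
}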
+
+// nccl op sequential
+std::atomic_int NCCL_KERNEL_LAUNCH_SEQ(0);
+// The std::array size was lost in extraction; 64 is an assumed upper bound on
+// seq_launch_len. The element type follows from the assignments below.
+constexpr int32_t kMaxSeqLaunchLen = 64;
+std::array<tensorflow::se::Stream*, kMaxSeqLaunchLen> NCCL_STREAM;
+class IncrSeqHook {
+ public:
+  IncrSeqHook(tensorflow::se::Stream* s, int32_t seq_launch_len,
+              int32_t seq_launch_idx)
+      : seq_launch_idx(seq_launch_idx) {
+    stream = s;
+    seq_launch_next =
+        (seq_launch_idx == seq_launch_len - 1 ? 0 : seq_launch_idx + 1);
+  }
+  ~IncrSeqHook() {
+    NCCL_STREAM[seq_launch_idx] = stream;
+    int32_t expected = seq_launch_idx;
+    if (!NCCL_KERNEL_LAUNCH_SEQ.compare_exchange_strong(expected,
+                                                        seq_launch_next)) {
+      // On failure compare_exchange_strong stores the actual value in
+      // `expected`; the std::atomic itself is not streamable.
+      LOG(FATAL) << "invalid op_seq " << expected << " vs expected "
+                 << seq_launch_idx;
+    }
+    // LOG(INFO) << "reset seq to " << seq_launch_next;
+  }
+
+ private:
+  tensorflow::se::Stream* stream;
+  int32_t seq_launch_idx;
+  int32_t seq_launch_next;
+};
+
+#define WAIT_FOR_KERNEL_LAUNCH_SEQ(stream, seq_launch_len, seq_launch_idx) \
+  std::unique_ptr<IncrSeqHook> __TMP_INCR_SEQ_HOOK__ = nullptr;            \
+  do {                                                                     \
+    if (seq_launch_len) {                                                  \
+      while (seq_launch_idx != NCCL_KERNEL_LAUNCH_SEQ) {                   \
+      }                                                                    \
+      if (seq_launch_idx) {                                                \
+        stream->WaitFor(NCCL_STREAM[seq_launch_idx - 1]).IgnoreError();    \
+      }                                                                    \
+      __TMP_INCR_SEQ_HOOK__ = std::make_unique<IncrSeqHook>(               \
+          stream, seq_launch_len, seq_launch_idx);                         \
+    }                                                                      \
+  } while (0)
+
+void ThreadSetNameOnce(const std::string& name) {
+  thread_local bool name_set = false;
+  if (name_set) {
+    return;
+  }
+  auto ret = pthread_setname_np(pthread_self(), name.c_str());
+  if (ret != 0) {
+    LOG(WARNING) << "ThreadSetName failed for " << name;
+    return;
+  }
+  name_set = true;
+}
+
 }  // namespace
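Reviewer note: a minimal sketch of how the gate above serializes NCCL launches across launcher threads. `stream0`/`stream1` are hypothetical `tensorflow::se::Stream*` values; the real call site is in LoopKernelLaunches below.

// Thread for launch slot 0: spins until the global sequence reaches 0,
// launches, and on scope exit IncrSeqHook publishes stream0 and advances
// the sequence to 1.
{
  WAIT_FOR_KERNEL_LAUNCH_SEQ(stream0, /*seq_launch_len=*/2,
                             /*seq_launch_idx=*/0);
  // ... enqueue NCCL kernel on stream0 ...
}
// Thread for launch slot 1: waits for the sequence to become 1, makes
// stream1 wait on stream0, then wraps the sequence back to 0.
{
  WAIT_FOR_KERNEL_LAUNCH_SEQ(stream1, /*seq_launch_len=*/2,
                             /*seq_launch_idx=*/1);
  // ... enqueue NCCL kernel on stream1 ...
}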
 
 // A `Collective` encapsulates state for a collective instance at one node.
@@ -161,7 +441,8 @@ struct NcclManager::Collective : public core::RefCounted {
   Collective(const string& collective_key_in, DataType data_type_in,
              CollectiveType type_in, ncclRedOp_t reduction_op_in,
              int num_local_devices_in, int num_global_devices_in,
-             const string& communicator_key_in)
+             const string& communicator_key_in, int seq_launch_len = 0,
+             int seq_launch_idx = 0)
       : collective_key(collective_key_in),
         data_type(data_type_in),
         type(type_in),
@@ -169,7 +450,9 @@
         num_local_devices(num_local_devices_in),
         num_global_devices(num_global_devices_in),
         single_node(num_local_devices_in == num_global_devices_in),
-        communicator_key(communicator_key_in) {
+        communicator_key(communicator_key_in),
+        seq_launch_len(seq_launch_len),
+        seq_launch_idx(seq_launch_idx) {
     participants.reserve(num_local_devices_in);
 #if TENSORFLOW_USE_ROCM
     // On ROCm platform, this allows caller to either use the singleton instance
@@ -221,19 +504,37 @@
   uint64 trace_context = 0;
 
   Status status;
+
+  // If set to a value greater than zero, collectives from all streams are
+  // launched one by one according to the launch index.
+  int32_t seq_launch_len = 0;
+  int32_t seq_launch_idx = 0;
 };
 
 NcclManager::NcclManager() {
   VLOG(2) << "New NcclManager " << this;
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
   ++instance_count;
 #endif
+  // Rank topology from the environment; consumed by CreateCommunicator.
+  char* env = getenv("JAGUAR_LOCAL_RANKS");
+  if (env) {
+    local_ranks_ = atoi(env);
+  }
+  env = getenv("JAGUAR_WORKER_INDEX");
+  if (env) {
+    worker_index_ = atoi(env);
+  }
+  env = getenv("JAGUAR_WORKER_COUNT");
+  if (env) {
+    worker_count_ = atoi(env);
+  }
+  global_ranks_ = local_ranks_ * worker_count_;
 }
 
 NcclManager::~NcclManager() {
   VLOG(2) << "~NcclManager " << this;
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
   --instance_count;
 #endif
   for (auto& it : device_to_comm_streams_) {
     for (NcclStream* nccl_stream : it.second) {
       {
@@ -247,7 +548,7 @@ NcclManager::~NcclManager() {
 }
 
 NcclManager* NcclManager::instance() {
   static NcclManager* instance = new NcclManager();
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
   // singleton does not count against total instances
   // see comment above in Collective constructor concerning ROCm platform
   static absl::once_flag once;
@@ -256,12 +557,108 @@ NcclManager* NcclManager::instance() {
   return instance;
 }
 
-string NcclManager::GenerateCommunicatorKey() {
+string NcclManager::GenerateCommunicatorKey(char* nccl_comm_id, int rank,
+                                            bool init_step) {
   ncclUniqueId nccl_id;
   ncclGetUniqueId(&nccl_id);
   return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
 }
 
+int NcclManager::LocalRanks() { return local_ranks_; }
+
+int NcclManager::WorkerIndex() { return worker_index_; }
+
+int NcclManager::WorkerCount() { return worker_count_; }
+
+int NcclManager::GlobalRanks() { return global_ranks_; }
+
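Reviewer note: a hedged sketch of the process environment the constructor above reads, for a two-worker job with one GPU per worker. The values are illustrative; the variable semantics are inferred from the fields they populate.

#include <cstdlib>

// Sketch: what NcclManager's constructor parses.
//   JAGUAR_LOCAL_RANKS  -> local_ranks_  (GPUs handled by this process)
//   JAGUAR_WORKER_INDEX -> worker_index_ (this worker's id)
//   JAGUAR_WORKER_COUNT -> worker_count_ (total workers)
// giving global_ranks_ = local_ranks_ * worker_count_ = 2 here.
setenv("JAGUAR_LOCAL_RANKS", "1", /*overwrite=*/1);
setenv("JAGUAR_WORKER_INDEX", "0", 1);
setenv("JAGUAR_WORKER_COUNT", "2", 1);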
+tensorflow::Status NcclManager::CreateCommunicator(
+    tensorflow::se::StreamExecutor* executor,
+#ifdef USE_TF215
+    const tensorflow::DeviceBase::AcceleratorDeviceInfo* info,
+#else
+    const tensorflow::DeviceBase::GpuDeviceInfo* info,
+#endif
+    const tensorflow::string& communicator_key, int global_rank) {
+  if (LocalRanks() != 1) {
+    return tensorflow::errors::Internal(
+        "NcclManager::CreateCommunicator only supports LocalRanks=1");
+  }
+  tensorflow::mutex_lock l(mu_);
+  if (communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
+    return tensorflow::errors::Internal(
+        "Expected communicator_key of size ", NCCL_UNIQUE_ID_BYTES,
+        " but found size ", communicator_key.size());
+  }
+
+  // This is an instance of a multi-node collective. We have previously
+  // created a NCCL unique id and shared it with all workers. Now we find the
+  // `Communicator` corresponding to this id.
+  for (auto& comm : communicators_) {
+    if (comm->key == communicator_key) {
+      return STATUS_OK;
+    }
+  }
+
+  auto* env = tensorflow::Env::Default();
+  // Create and initialize a new communicator.
+  // Note that this is done under the lock; performance is not expected to
+  // matter as this happens a very small number of times.
+  std::vector<CommunicatorMember> members(LocalRanks());
+  int device_id = info->gpu_id;
+
+  // Create a new communication stream for the device.
+  auto& streams = device_to_comm_streams_[executor];
+  NcclStream* nccl_stream;
+  nccl_stream = new NcclStream();
+  nccl_stream->executor = executor;
+  VLOG(2) << "Create new stream";
+#if USE_ROCM
+  auto stream_or_status = executor->CreateStream();
+  nccl_stream->stream = stream_or_status->get();
+#else
+  nccl_stream->stream.reset(new tensorflow::se::Stream(executor));
+  nccl_stream->stream->Init();
+#endif
+
+  streams.emplace_back(nccl_stream);
+
+  nccl_stream->Ref();
+  env->SchedClosure([this, nccl_stream]() {
+    LoopKernelLaunches(nccl_stream);
+    nccl_stream->Unref();
+  });
+
+  members[0].nccl_stream = nccl_stream;
+
+  ncclComm_t nccl_comm;
+  VLOG(2) << "Create new communicator";
+  // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
+  // group primitives.
+  ncclUniqueId nccl_id;
+  StringToNcclUniqueId(communicator_key, &nccl_id);
+  char line_a[SOCKET_NAME_MAXLEN + 1];
+  union socketAddress& addr = ((struct ncclBootstrapHandle*)(&nccl_id))->addr;
+  VLOG(2) << "Try to init rank " << global_rank << " "
+          << socketToString(&(addr.sa), line_a);
+  // getenv may return nullptr; guard before streaming.
+  const char* comm_id_env = getenv("NCCL_COMM_ID");
+  VLOG(2) << "NCCL_COMM_ID " << (comm_id_env ? comm_id_env : "(unset)");
+  int saved_device = 0;
+  CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
+  NCCL_RETURN_IF_ERROR(ncclGroupStart());
+  CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id));
+  NCCL_RETURN_IF_ERROR(
+      ncclCommInitRank(&nccl_comm, GlobalRanks(), nccl_id, global_rank));
+  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+  CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
+
+  members[0].nccl_comm = nccl_comm;
+
+  communicators_.emplace_back(
+      new Communicator(std::move(members), communicator_key));
+  return STATUS_OK;
+}
+
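Reviewer note: a hedged sketch of the intended call sequence, reconstructed from the code above: one process per GPU (LocalRanks() == 1), with the bootstrap key distributed out of band. `executor`, `dev_info`, and `key` stand in for the caller's state.

// Sketch only. `key` is the NCCL_UNIQUE_ID_BYTES communicator key shared
// with every worker before this call.
tensorflow::Status InitNcclForThisWorker(
    tensorflow::se::StreamExecutor* executor,
    const tensorflow::DeviceBase::AcceleratorDeviceInfo* dev_info,
    const tensorflow::string& key) {
  auto* mgr = tensorflow::NcclManager::instance();
  // With JAGUAR_LOCAL_RANKS=1, the global rank is just the worker index.
  return mgr->CreateCommunicator(executor, dev_info, key, mgr->WorkerIndex());
}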
 Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
                                     NcclManager::Communicator** communicator) {
   // Sort by device ID, executor, and global rank to make ordering of
@@ -279,9 +676,9 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
   });
 
   mutex_lock l(mu_);
-  if (!status_.ok()) {
-    return status_;
-  }
+  // if (!status_.ok()) {
+  //   return status_;
+  // }
 
   if (collective->communicator_key.empty()) {
     // For single-node collectives, when the caller does not specify a
@@ -316,7 +713,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
       }
       if (i == collective->num_local_devices) {
         *communicator = comm.get();
-        return OkStatus();
+        return STATUS_OK;
       }
     }
   }
@@ -341,6 +738,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
     }
   }
 
+  VLOG(0) << "ERROR: NcclManager lazy ncclInitRank should never be called";
   auto* env = Env::Default();
   std::set<NcclStream*> used_streams;
 
@@ -355,16 +753,16 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
     // Find a communication stream to use for the device.
     auto& streams = device_to_comm_streams_[executor];
     NcclStream* nccl_stream = nullptr;
-    for (const auto& s : streams) {
-      if (used_streams.insert(s).second) {
-        nccl_stream = s;
-        break;
-      }
-    }
+    // for (const auto& s : streams) {
+    //   if (used_streams.insert(s).second) {
+    //     nccl_stream = s;
+    //     break;
+    //   }
+    // }
     if (nccl_stream == nullptr) {
       nccl_stream = new NcclStream();
       nccl_stream->executor = executor;
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
       nccl_stream->stream = collective->participants[i]->context->nccl_stream();
 #else
       TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream());
@@ -372,7 +770,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
 #endif
 
       streams.emplace_back(nccl_stream);
-      used_streams.insert(nccl_stream);
+      // used_streams.insert(nccl_stream);
 
       nccl_stream->Ref();
       env->SchedClosure([this, nccl_stream]() {
@@ -386,23 +784,17 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
   }
 
   std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
-  VLOG(2) << "Created nccl Communicator with "
-          << "num_global_devices = " << collective->num_global_devices
-          << " num_local_devices = " << collective->num_local_devices
-          << " communicator_key ="
-          << absl::StrJoin(
-                 std::vector<int>{collective->communicator_key.begin(),
                                  collective->communicator_key.end()},
-                 " ");
+  VLOG(2) << "Create new communicator";
+
 #if NCCL_MAJOR >= 2
   // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
   // group primitives.
   ncclUniqueId nccl_id;
-  if (collective->single_node) {
-    NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
-  } else {
-    StringToNcclUniqueId(collective->communicator_key, &nccl_id);
-  }
+  StringToNcclUniqueId(collective->communicator_key, &nccl_id);
+  char line_a[SOCKET_NAME_MAXLEN + 1];
+  union socketAddress& addr = ((struct ncclBootstrapHandle*)(&nccl_id))->addr;
+  VLOG(2) << "Try to init rank " << socketToString(&(addr.sa), line_a);
+
   int saved_device = 0;
   CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
   NCCL_RETURN_IF_ERROR(ncclGroupStart());
@@ -459,159 +851,103 @@ void NcclManager::AddToAllToAll(std::unique_ptr<Participant> participant,
                  ncclSum /* unused */);
 }
 
-void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
-                                   const Context& context) {
-  participant->root = true;
-  AddParticipant(std::move(participant), context, kBroadcast,
-                 ncclSum /* unused */);
-}
-
-void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
-                                   const Context& context) {
+void NcclManager::AddBroadcast(std::unique_ptr<Participant> participant,
+                               const Context& context) {
   AddParticipant(std::move(participant), context, kBroadcast,
                  ncclSum /* unused */);
 }
 
-void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
-                                const Context& context,
-                                ncclRedOp_t reduction_op) {
-  AddParticipant(std::move(participant), context, kReduce, reduction_op);
-}
-
-void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
-                                const Context& context,
-                                ncclRedOp_t reduction_op) {
-  participant->root = true;
-  AddParticipant(std::move(participant), context, kReduce, reduction_op);
-}
-
-void NcclManager::SignalMultiNodeReady(const string& collective_key) {
-  Collective* to_run = nullptr;
-  {
-    mutex_lock l(mu_);
-    auto collective_it = collectives_.find(collective_key);
-    if (collective_it != collectives_.end()) {
-      Collective* collective = collective_it->second;
-      collective->multi_node_ready = true;
-      if (CheckReady(collective_key, collective)) {
-        to_run = collective;
-      }
-      VLOG(2) << "SignalMultiNodeReady collective " << collective_key
-              << " to_run " << to_run;
-    }
-  }
-
-  if (to_run != nullptr) RunCollective(to_run);
-}
-
 void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
                                  const Context& context,
                                  CollectiveType collective_type,
                                  ncclRedOp_t reduction_op) {
   Collective* to_run = nullptr;
-  DataType data_type;
-  Status nccl_manager_status;
-  if (participant->input != nullptr) {
-    data_type = participant->input->dtype();
-  } else {
-    data_type = participant->output->dtype();
+  // Defaults to float, matching the ToNcclType/NcclTypeSize fallbacks, when
+  // no tensors are supplied.
+  tensorflow::DataType data_type = tensorflow::DT_FLOAT;
+  if (participant->inputs.size() > 0) {
+    data_type = participant->inputs[0]->dtype();
+  } else if (participant->outputs.size() > 0) {
+    data_type = participant->outputs[0]->dtype();
   }
   {
-    mutex_lock l(mu_);
-    nccl_manager_status = status_;
-    if (nccl_manager_status.ok()) {
-      auto collective_it = collectives_.find(context.collective_key);
-      Collective* collective = nullptr;
-      if (collective_it == collectives_.end()) {
-        collective = new Collective(
-            context.collective_key, data_type, collective_type, reduction_op,
-            context.num_local_devices, context.num_global_devices,
-            context.communicator_key);
-        collectives_.emplace(context.collective_key, collective);
-      } else {
-        collective = collective_it->second;
-      }
-
-      // Check `collective` is correct and consistent.
-      if (collective->status.ok() && !collective->single_node &&
-          collective->communicator_key.empty()) {
-        collective->status = errors::Internal(
-            "Collective ", reduction_op,
-            " is multi node with num_local_devices=",
-            collective->num_local_devices,
-            " and num_global_devices=", collective->num_global_devices,
-            " but has an empty communicator_key");
-      }
-      if (collective->status.ok() && collective->communicator_key.size() !=
-                                         context.communicator_key.size()) {
-        collective->status =
-            errors::Internal("Collective ", reduction_op,
-                             " mismatch in member communicator_key with size ",
-                             collective->communicator_key.size(),
-                             " and arg communicator_key with size ",
-                             context.communicator_key.size());
-      }
-      if (collective->status.ok() && collective->type != collective_type) {
-        collective->status = errors::Internal(
-            "Collective ", reduction_op, " previously initialized with type ",
-            collective->type, " but now got type ", collective_type);
-      }
-      if (collective->status.ok() &&
-          collective->num_global_devices != context.num_global_devices) {
-        collective->status =
-            errors::Internal("Collective ", reduction_op,
-                             " previously initialized with num_global_devices ",
-                             collective->num_global_devices, " but now got ",
-                             context.num_global_devices);
-      }
-      if (collective->status.ok() &&
-          collective->num_local_devices != context.num_local_devices) {
-        collective->status =
-            errors::Internal("Collective ", reduction_op,
-                             "previously initialized with num_local_devices ",
-                             collective->num_local_devices, " but now got ",
-                             context.num_local_devices);
-      }
-      if (collective->status.ok() &&
-          collective->participants.size() >= collective->num_local_devices) {
-        collective->status = errors::Internal(
-            "Collective ", reduction_op, " expected ",
-            collective->num_local_devices, " participants but now has ",
-            collective->participants.size(),
-            " with one more participant being added");
-      }
-      if (collective->status.ok() && collective->root_rank >= 0 &&
-          context.source_rank >= 0 &&
-          collective->root_rank != context.source_rank) {
-        collective->status = errors::Internal(
-            "Collective ", collective->collective_key,
-            " already has root_rank ", collective->root_rank,
-            " but new participant has root_rank ", context.source_rank);
-      }
-      if (collective->status.ok() &&
-          !kValidDataTypes.Contains(collective->data_type)) {
-        collective->status = errors::Internal(
-            "Collective ", collective->collective_key,
-            " expected data types compatible with NCCL but instead got ",
-            DataTypeString(collective->data_type));
-      }
-
-      if (context.source_rank >= 0) {
-        collective->root_rank = context.source_rank;
-      }
+    tensorflow::mutex_lock l(mu_);
+    auto collective_it = collectives_.find(context.collective_key);
+    Collective* collective = nullptr;
+    if (collective_it == collectives_.end()) {
+      collective =
+          new Collective(context.collective_key, data_type, collective_type,
+                         reduction_op, context.num_local_devices,
+                         context.num_global_devices, context.communicator_key,
+                         context.seq_launch_len, context.seq_launch_idx);
+      collectives_.emplace(context.collective_key, collective);
+    } else {
+      collective = collective_it->second;
+    }
+    // Check `collective` is correct and consistent.
+    if (collective->status.ok() && !collective->single_node &&
+        collective->communicator_key.empty()) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op,
+          " is multi node with num_local_devices=",
+          collective->num_local_devices,
+          " and num_global_devices=", collective->num_global_devices,
+          " but has an empty communicator_key");
+    }
+    if (collective->status.ok() && collective->communicator_key.size() !=
+                                       context.communicator_key.size()) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op,
+          " mismatch in member communicator_key with size ",
+          collective->communicator_key.size(),
+          " and arg communicator_key with size ",
+          context.communicator_key.size());
+    }
+    if (collective->status.ok() && collective->type != collective_type) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op, " previously initialized with type ",
+          collective->type, " but now got type ", collective_type);
+    }
+    if (collective->status.ok() &&
+        collective->num_global_devices != context.num_global_devices) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op,
+          " previously initialized with num_global_devices ",
+          collective->num_global_devices, " but now got ",
+          context.num_global_devices);
+    }
+    if (collective->status.ok() &&
+        collective->num_local_devices != context.num_local_devices) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op,
+          " previously initialized with num_local_devices ",
+          collective->num_local_devices, " but now got ",
+          context.num_local_devices);
+    }
+    if (collective->status.ok() &&
+        collective->participants.size() >= collective->num_local_devices) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", reduction_op, " expected ",
+          collective->num_local_devices, " participants but now has ",
+          collective->participants.size(),
+          " with one more participant being added");
+    }
+    if (collective->status.ok() && collective->root_rank >= 0 &&
+        context.source_rank >= 0 &&
+        collective->root_rank != context.source_rank) {
+      collective->status = tensorflow::errors::Internal(
+          "Collective ", collective->collective_key, " already has root_rank ",
+          collective->root_rank, " but new participant has root_rank ",
+          context.source_rank);
+    }
 
-      collective->participants.emplace_back(std::move(participant));
-      ++collective->available_participants;
+    if (context.source_rank >= 0) {
+      collective->root_rank = context.source_rank;
+    }
+    collective->participants.emplace_back(std::move(participant));
+    ++collective->available_participants;
 
-      if (CheckReady(context.collective_key, collective)) {
-        to_run = collective;
-      }
+    if (CheckReady(context.collective_key, collective)) {
+      to_run = collective;
     }
   }
-  if (!nccl_manager_status.ok()) {
-    participant->done_callback(nccl_manager_status);
-    return;
-  }
 
   if (to_run != nullptr) RunCollective(to_run);
 }
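Reviewer note: `context.seq_launch_len` and `context.seq_launch_idx` are read above, but the nccl_manager.h hunks later in this patch do not show the matching change to `Context`. The struct presumably gains these fields; an assumed sketch, not part of the visible hunks:

// Assumed shape of the Context extension.
struct Context {
  // ... existing fields: collective_key, num_local_devices,
  // num_global_devices, communicator_key, source_rank ...
  int32_t seq_launch_len = 0;  // >0 enables sequential kernel launches
  int32_t seq_launch_idx = 0;  // this op's slot in the launch sequence
};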
@@ -646,11 +982,11 @@ void NcclManager::RunCollective(Collective* collective) {
     CHECK(nccl_stream != nullptr);
     const int rank = p->global_rank >= 0 ? p->global_rank : i;
 
-    if (p->input != nullptr) {
+    if (p->inputs.size() > 0 && p->inputs[0] != nullptr) {
       // Wait to ensure that the kernel that produces the data in the input
       // tensor has finished running before the nccl kernel runs on the
       // communication stream.
-      status = nccl_stream->stream->WaitFor(p->tensor_stream);
+      nccl_stream->stream->WaitFor(p->tensor_stream).IgnoreError();
     }
     if (p->root) {
       if (collective->root_rank == -1) {
@@ -714,13 +1050,13 @@ size_t ComputeBufferSize(const NcclManager::Participant* p,
 }  // namespace
 
 void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
   se::Stream* comm_stream = nccl_stream->stream;
 #else
   se::Stream* comm_stream = nccl_stream->stream.get();
 #endif
   ScopedActivateContext scoped_context(nccl_stream->executor);
-  cudaStream_t cu_stream = reinterpret_cast<cudaStream_t>(
+  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
       comm_stream->platform_specific_handle().stream);
 
   while (true) {
@@ -750,10 +1086,29 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
       Participant* p = collective->participants[p_idx].get();
       auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
       ncclResult_t nccl_result = ncclSuccess;
+
+      // sequential launch
+      WAIT_FOR_KERNEL_LAUNCH_SEQ(comm_stream, collective->seq_launch_len,
+                                 collective->seq_launch_idx);
+      std::shared_ptr<cpputil::TimeCost> nccl_time_cost(new cpputil::TimeCost);
+      p->event_mgr->ThenExecute(comm_stream,
+                                [nccl_time_cost] { nccl_time_cost->reset(); });
+      std::string group_key = collective->collective_key;
+      std::string metric_key;
+      for (int i = 0; i < group_key.size(); i++) {
+        if (group_key[i] == ';') {
+          metric_key = group_key.substr(0, i);
+          break;
+        }
+      }
+#ifdef COMPILING_JAGUAR
+      // The original call was truncated after *cu_stream; any further
+      // arguments were lost in extraction.
+      GlobalCudaTimerManager()->StartCudaTimer(metric_key, *cu_stream);
+#endif
+
       switch (collective->type) {
         case kAllReduce: {
-          const void* sendbuff = p->input->tensor_data().data();
-          void* recvbuff = const_cast<char*>(p->output->tensor_data().data());
+          const void* sendbuff = p->inputs[0]->tensor_data().data();
+          void* recvbuff =
+              const_cast<char*>(p->outputs[0]->tensor_data().data());
 
           VLOG(2) << "call NcclAllReduce collective_key "
                   << collective->collective_key << " participant " << p_idx
@@ -776,9 +1131,9 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
           const void* sendbuff = nullptr;
           void* recvbuff = nullptr;
           int num_elements = -1;
-          if (p->input) {
-            sendbuff = p->input->tensor_data().data();
-            num_elements = p->input->NumElements();
+          if (p->inputs.size() > 0) {
+            sendbuff = p->inputs[0]->tensor_data().data();
+            num_elements = p->inputs[0]->NumElements();
           }
-          if (p->output) {
-            recvbuff = const_cast<char*>(p->output->tensor_data().data());
+          if (p->outputs.size() > 0) {
+            recvbuff = const_cast<char*>(p->outputs[0]->tensor_data().data());
@@ -793,11 +1148,13 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
             collective->Unref();
             continue;
           }
-          VLOG(2) << "call NcclBroadcast collective_key "
-                  << collective->collective_key << " participant " << p_idx
-                  << " sendbuff " << sendbuff << " recvbuff " << recvbuff
-                  << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
-                  << " cuda_stream " << cu_stream;
+          LOG(INFO) << "call NcclBroadcast collective_key "
+                    << collective->collective_key << " participant " << p_idx
+                    << " sendbuff " << sendbuff << " recvbuff " << recvbuff
+                    << " num_elements " << num_elements
+                    << " collective root_rank " << collective->root_rank
+                    << " nccl_comm " << nccl_comm << " comm_stream "
+                    << comm_stream << " cuda_stream " << cu_stream;
           profiler::AnnotatedTraceMe traceme([&] {
             return profiler::TraceMeEncode(
                 "ncclBroadcast",
@@ -810,7 +1167,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
           break;
         }
         case kReduce: {
-          const void* sendbuff = p->input->tensor_data().data();
+          const void* sendbuff = p->inputs[0]->tensor_data().data();
           void* recvbuff =
-              p->output ? const_cast<char*>(p->output->tensor_data().data())
-                        : nullptr;
+              p->outputs.size() > 0
+                  ? const_cast<char*>(p->outputs[0]->tensor_data().data())
+                  : nullptr;
@@ -826,14 +1183,24 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
           break;
         }
         case kAllGather: {
-          const void* sendbuff = p->input->tensor_data().data();
-          void* recvbuff = const_cast<char*>(p->output->tensor_data().data());
+          const void* sendbuff = nullptr;
+          void* recvbuff = nullptr;
+          int send_num_elements = -1;
+          int recv_num_elements = -1;
+          if (p->inputs.size() > 0) {
+            sendbuff = p->inputs[0]->tensor_data().data();
+            send_num_elements = p->inputs[0]->NumElements();
+          }
+          if (p->outputs.size() > 0) {
+            recvbuff = const_cast<char*>(p->outputs[0]->tensor_data().data());
+            recv_num_elements = p->outputs[0]->NumElements();
+          }
 
           VLOG(2) << "call NcclAllGather collective_key "
                   << collective->collective_key << " participant " << p_idx
                   << " sendbuff " << sendbuff << " sendcount "
-                  << p->input->NumElements() << " recvbuff " << recvbuff
-                  << " recvcount " << p->output->NumElements() << " nccl_comm "
+                  << send_num_elements << " recvbuff " << recvbuff
+                  << " recvcount " << recv_num_elements << " nccl_comm "
                  << nccl_comm << " comm_stream " << comm_stream
                   << " cuda_stream " << cu_stream;
          profiler::AnnotatedTraceMe traceme([&] {
@@ -868,51 +1235,72 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
           break;
         }
         case kAllToAll: {
-          const char* sendbuff = p->input->tensor_data().data();
-          char* recvbuff = const_cast<char*>(p->output->tensor_data().data());
-          size_t count =
-              p->input->NumElements() / collective->participants.size();
-          size_t rank_offset = count * DataTypeSize(collective->data_type);
-
-          VLOG(2) << "call Nccl All to All collective_key "
+          VLOG(2) << "call NcclAlltoAll collective_key "
                   << collective->collective_key << " participant " << p_idx
-                  << " num_participants " << collective->participants.size()
-                  << " sendbuff " << static_cast<const void*>(sendbuff)
-                  << " recvbuff " << static_cast<void*>(recvbuff)
                   << " nccl_comm " << nccl_comm << " comm_stream "
                   << comm_stream << " cuda_stream " << cu_stream;
-          profiler::AnnotatedTraceMe traceme([&] {
-            return profiler::TraceMeEncode(
-                "ncclAllToAll",
-                {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
-                 {"collective_type", "all_to_all"}});
-          });
+          int32_t total_sendcount = 0;
+          int32_t total_recvcount = 0;
+          ncclResult_t tmp_nccl_result = ncclSuccess;
+          char* sendptr =
+              const_cast<char*>(p->inputs[0]->tensor_data().data());
+          char* recvptr =
+              const_cast<char*>(p->outputs[0]->tensor_data().data());
           ncclGroupStart();
-          for (int i = 0; i < collective->participants.size(); ++i) {
-            ncclSend(sendbuff + i * rank_offset, count, data_type,
-                     collective->participants[i]->global_rank, nccl_comm,
-                     cu_stream);
-            ncclRecv(recvbuff + i * rank_offset, count, data_type,
-                     collective->participants[i]->global_rank, nccl_comm,
-                     cu_stream);
+          for (int r = 0; r < collective->num_global_devices; ++r) {
+            const void* sendbuff =
+                sendptr + p->send_offsets[r] * NcclTypeSize(data_type);
+            int32_t sendcount = p->send_counts[r];
+            total_sendcount += sendcount;
+            tmp_nccl_result = ncclSend(sendbuff, sendcount, data_type, r,
+                                       nccl_comm, *cu_stream);
+            if (tmp_nccl_result != ncclSuccess) nccl_result = tmp_nccl_result;
+            void* recvbuff =
+                recvptr + p->recv_offsets[r] * NcclTypeSize(data_type);
+            int32_t recvcount = p->recv_counts[r];
+            total_recvcount += recvcount;
+            tmp_nccl_result = ncclRecv(recvbuff, recvcount, data_type, r,
+                                       nccl_comm, *cu_stream);
+            if (tmp_nccl_result != ncclSuccess) nccl_result = tmp_nccl_result;
+          }
+          // Latch the group-end result too, rather than unconditionally
+          // overwriting the per-call failures latched above.
+          tmp_nccl_result = ncclGroupEnd();
+          if (tmp_nccl_result != ncclSuccess) nccl_result = tmp_nccl_result;
+          if (collective->collective_key.substr(
+                  0, ALLTOALL_BACKWARD_PREFIX_LEN) == "Grad") {
+            MetricAdapter::EmitTimer("nccl.bw_all2all.sendcount",
+                                     total_sendcount);
+            MetricAdapter::EmitTimer("nccl.bw_all2all.recvcount",
+                                     total_recvcount);
+          } else {
+            MetricAdapter::EmitTimer("nccl.fw_all2all.sendcount",
+                                     total_sendcount);
+            MetricAdapter::EmitTimer("nccl.fw_all2all.recvcount",
+                                     total_recvcount);
          }
-          nccl_result = ncclGroupEnd();
           break;
         }
       }
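Reviewer note: the ragged all-to-all above sends `send_counts[r]` elements starting at element offset `send_offsets[r]` to each rank `r`. A minimal sketch of how a caller might derive the offset vectors from per-rank counts (an exclusive prefix sum; the helper name is illustrative):

#include <cstdint>
#include <vector>

// Exclusive prefix sum: offsets[r] = sum of counts[0..r-1].
std::vector<int32_t> CountsToOffsets(const std::vector<int32_t>& counts) {
  std::vector<int32_t> offsets(counts.size(), 0);
  for (size_t r = 1; r < counts.size(); ++r) {
    offsets[r] = offsets[r - 1] + counts[r - 1];
  }
  return offsets;
}

// e.g. counts {3, 1, 4} -> offsets {0, 3, 4}: rank 2's slice starts at
// element 4 of the flat send buffer and spans 4 elements.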
+#ifdef COMPILING_JAGUAR
+      GlobalCudaTimerManager()->StopCudaTimer(metric_key, *cu_stream);
+      GlobalSessionStatus()->worker_status()->nccl_status.push(metric_key);
+#endif
+
       // Run the done_callback when the nccl kernel finishes running.
-      auto done_callback = [collective, p_idx, nccl_result]() {
+      auto done_callback = [collective, p_idx, nccl_result, nccl_time_cost,
+                            metric_key]() {
         VLOG(2) << "done Nccl kernel collective_key "
                 << collective->collective_key << " participant " << p_idx
                 << " ncclResult " << nccl_result;
         if (nccl_result == ncclSuccess) {
-          collective->participants[p_idx]->done_callback(OkStatus());
+          collective->participants[p_idx]->done_callback(STATUS_OK);
+          MetricAdapter::TagkvList tag = {{"key", metric_key}};
+          MetricAdapter::EmitTimer("nccl.runtime",
+                                   nccl_time_cost->get_elapsed(), tag);
         } else {
           // Propagate the error, but note that if other members of the
           // collective did launch their kernels, then they are hanging.
-          collective->participants[p_idx]->done_callback(errors::Unknown(
-              "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
+          collective->participants[p_idx]->done_callback(
+              tensorflow::errors::Unknown("Error invoking NCCL: ",
+                                          ncclGetErrorString(nccl_result)));
         }
         collective->Unref();
       };
diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h
index 0e62013949bc1a..2312f3d20d71ec 100644
--- a/tensorflow/core/nccl/nccl_manager.h
+++ b/tensorflow/core/nccl/nccl_manager.h
@@ -43,6 +43,12 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor.h"
+#include "xla/pjrt/event_pool.h"
+
+#define NCCL_COMM_ID_LEN 13
+#define ALLTOALL_BACKWARD_PREFIX_LEN 4
+#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
+#define NCCL_COMM_ID_MAXLEN (NCCL_COMM_ID_LEN + SOCKET_NAME_MAXLEN)
 
 namespace tensorflow {
 
@@ -59,36 +65,64 @@ class NcclManager {
 
   static NcclManager* instance();
 
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
   static int instance_count;
 #endif
 
   // Calls `ncclGetUniqueId` and returns the id as a string. The returned value
   // may be shared with other participants on different nodes and passed in to
   // multi-node collective invocations.
-  string GenerateCommunicatorKey();
+  string GenerateCommunicatorKey(char* nccl_comm_id, int rank, bool init_step);
+
+  int LocalRanks();
+  int WorkerIndex();
+  int WorkerCount();
+  int GlobalRanks();
 
   // A participant in a Collective.
   struct Participant {
-    Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
-                const DeviceBase::AcceleratorDeviceInfo* info,
-                const Tensor* input, Tensor* output, int global_rank,
+    Participant(tensorflow::se::StreamExecutor* executor,
+                tensorflow::se::Stream* tensor_stream,
+#ifdef USE_TF215
+                const tensorflow::DeviceBase::AcceleratorDeviceInfo* info,
+#else
+                const tensorflow::DeviceBase::GpuDeviceInfo* info,
+#endif
+                std::vector<const Tensor*> input_tensors,
+                std::unique_ptr<se::Event> input_evt,
+                std::vector<Tensor*> outputs,
+                std::vector<int32_t> send_offsets,
+                std::vector<int32_t> recv_offsets,
+                std::vector<int32_t> send_counts,
+                std::vector<int32_t> recv_counts, int global_rank,
                 DoneCallback done_callback)
         : executor(executor),
           tensor_stream(tensor_stream),
          event_mgr(info->event_mgr),
           gpu_device_id(info->gpu_id),
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
          context(static_cast<GPUDeviceContext*>(info->default_context)),
 #endif
-          input(input),
-          output(output),
+          inputs(std::move(input_tensors)),
+          input_event(std::move(input_evt)),
+          outputs(std::move(outputs)),
+          send_offsets(std::move(send_offsets)),
+          recv_offsets(std::move(recv_offsets)),
+          send_counts(std::move(send_counts)),
+          recv_counts(std::move(recv_counts)),
          global_rank(global_rank),
           done_callback(std::move(done_callback)),
-          root(false) {
+          root(false),
+          event_pool(false) {
       DCHECK(executor != nullptr);
       DCHECK(event_mgr != nullptr);
       DCHECK(tensor_stream != nullptr);
+      if (inputs.size() > 0 && inputs[0] != nullptr && !input_event) {
+        absl::StatusOr<xla::EventPool::Handle> handle_event =
+            event_pool.ThenAllocateAndRecordEvent(tensor_stream);
+        input_event =
+            std::unique_ptr<se::Event>(handle_event.value().event());
+      }
     }
 
     // StreamExecutor for the device. Expected to be live for process lifetime.
@@ -109,17 +143,27 @@ class NcclManager {
 
     const int gpu_device_id;
 
-#if TENSORFLOW_USE_ROCM
+#if USE_ROCM
     GPUDeviceContext* const context;
 #endif
 
     // Owned by the caller, who must keep it live until `done_callback` is
     // called. Is NULL for participants that only receive data.
-    const Tensor* input;
+    std::vector<const Tensor*> inputs;
+
+    // Wait on this event rather than synchronizing on the entire stream.
+    // This allows greater concurrency between compute and nccl streams.
+    std::unique_ptr<se::Event> input_event;
 
     // Owned by the caller, who must keep it live until `done_callback` is
     // called. Is NULL for participants that only send data.
-    Tensor* output;
+    std::vector<Tensor*> outputs;
+
+    // Per-rank split vectors; used only for all-to-all.
+    std::vector<int32_t> send_offsets;
+    std::vector<int32_t> recv_offsets;
+    std::vector<int32_t> send_counts;
+    std::vector<int32_t> recv_counts;
 
     // Rank across all devices and all nodes.
     // `global_rank` is not required for single-node collectives.
@@ -132,6 +176,8 @@ class NcclManager {
 
     // True if this is the root of the collective, e.g. source of broadcast.
     bool root;
+
+    xla::EventPool event_pool;
   };
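Reviewer note: a hedged sketch of wiring one rank's all-to-all through the new multi-tensor Participant. `executor`, `compute_stream`, `info`, the tensors, the per-rank count/offset vectors, `rank`, `done`, and `context` are assumed to come from the calling op kernel; `context` is a Context carrying the collective key, device counts, communicator key, and the launch-sequencing fields sketched earlier.

// Sketch only.
auto participant = std::make_unique<tensorflow::NcclManager::Participant>(
    executor, compute_stream, info,
    std::vector<const tensorflow::Tensor*>{&input},
    /*input_evt=*/nullptr,  // Participant records its own event when null
    std::vector<tensorflow::Tensor*>{&output},
    send_offsets, recv_offsets, send_counts, recv_counts,
    /*global_rank=*/rank, std::move(done));
tensorflow::NcclManager::instance()->AddToAllToAll(std::move(participant),
                                                   context);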
 
   // Data that provides context for the collective operation, including the
@@ -181,10 +227,8 @@ class NcclManager {
 
-  // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
-  // to all receivers.
-  void AddBroadcastSend(std::unique_ptr<Participant> participant,
-                        const Context& context);
-  void AddBroadcastRecv(std::unique_ptr<Participant> participant,
-                        const Context& context);
+  // AddBroadcast sends data from the root participant to all receivers.
+  void AddBroadcast(std::unique_ptr<Participant> participant,
+                    const Context& context);
 
   // AddReduceSend and AddReduceRecv combine to send data from all senders
   // to one receiver.
@@ -226,6 +270,7 @@ class NcclManager {
   struct Communicator;
   struct CommunicatorMember;
   struct NcclStream;
+  struct NcclProfiler;
 
   // Gets the `Communicator` object that will be used to enqueue NCCL kernels
   // for `collective`, and returns it via `communicator`.
@@ -268,10 +313,19 @@ class NcclManager {
   absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
       device_to_comm_streams_ TF_GUARDED_BY(mu_);
 
+#ifdef USE_TF215
   std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);
-
+#else
+  std::vector<std::unique_ptr<Communicator>> communicators_;
+#endif
+  // The template argument was lost in extraction; Communicator is assumed
+  // from the surrounding usage.
+  std::unique_ptr<Communicator> mgr_comm_;
   Status status_ TF_GUARDED_BY(mu_);
 
+  int local_ranks_ = -1;
+  int global_ranks_ = -1;
+  int worker_index_ = -1;
+  int worker_count_ = -1;
+
   NcclManager(const NcclManager&) = delete;
   void operator=(const NcclManager&) = delete;
 };