diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index d35fc911c692..a5240aa2b2c5 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -101,8 +101,12 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { ncclUniqueId id; std::memcpy(id.internal, unique_id_bytes.data(), NCCL_UNIQUE_ID_BYTES); NCCL_CALL(ncclCommInitRank(&ctx->global_comm, worker->num_workers, id, worker->worker_id)); - NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, - worker->worker_id % group_size, &ctx->group_comm, NULL)); + if (worker->num_groups == 1) { + ctx->group_comm = ctx->global_comm; + } else { + NCCL_CALL(ncclCommSplit(ctx->global_comm, worker->worker_id / group_size, + worker->worker_id % group_size, &ctx->group_comm, NULL)); + } } void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) {