diff --git a/horovod/common/operations.cc b/horovod/common/operations.cc index 227b31a065..f5af023c84 100644 --- a/horovod/common/operations.cc +++ b/horovod/common/operations.cc @@ -49,6 +49,7 @@ #if HAVE_CUDA #include "ops/msallreduce_cuda_operations.h" +#include "ops/msallreduce_cuda_ring_operations.h" #include "ops/cuda_operations.h" #include "ops/mpi_cuda_operations.h" #endif @@ -157,7 +158,7 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) { #if HOROVOD_GPU_ALLREDUCE == 'M' if (state.msallreduce_enabled == true){ LOG(INFO) << "msallGpureduce enabled."; - msallreduce_ops.push_back(std::shared_ptr(new MsCudaAllreduceOp(&mpi_context, &cuda_context, &state))); + msallreduce_ops.push_back(std::shared_ptr(new MsCudaRingAllreduceOp(&mpi_context, &cuda_context, &state))); } allreduce_ops.push_back(std::shared_ptr( new MPI_CUDAAllreduce(&mpi_context, &cuda_context, &state))); @@ -997,7 +998,7 @@ void BackgroundThreadLoop(HorovodGlobalState& state, MPIContext& ctx) { std::strtol(mpi_threads_disable, nullptr, 10) > 0) { required = MPI_THREAD_SINGLE; } -#if HAVE_MLSL +#if HAVE_MLSLf // MLSL comes with Intel MPI // and needs to initialize MPI with the proper configuration. mlsl_context.Init(); @@ -1068,7 +1069,7 @@ if(state.msallreduce_enabled == true) { } delete[] node_rank; - } + } // TODO parasail new algo end } diff --git a/horovod/common/ops/cuda/msallreduce_cuda_kernels.cu b/horovod/common/ops/cuda/msallreduce_cuda_kernels.cu index 2b62f3adde..0b01dc9645 100644 --- a/horovod/common/ops/cuda/msallreduce_cuda_kernels.cu +++ b/horovod/common/ops/cuda/msallreduce_cuda_kernels.cu @@ -45,6 +45,16 @@ void CudaScaleAddKernel(int count, T* a, const T* b, TACC a_coeff, TACC b_coeff) } } +template +__global__ +void ConvertToFloat(int count, T* a, float* b) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (count > index){ + b[index] = (float) a[index]; + } +} + + void CudaDotProductImpl(int count, const double* device_a, const double* device_b, double* device_normsq_a, double* device_normsq_b, double* device_dot, double& host_normsq_a, double& host_normsq_b, double& host_dot) { diff --git a/horovod/common/ops/msallreduce_cuda_kernels.h b/horovod/common/ops/msallreduce_cuda_kernels.h index 9f3f6a8490..a18d0d5a04 100644 --- a/horovod/common/ops/msallreduce_cuda_kernels.h +++ b/horovod/common/ops/msallreduce_cuda_kernels.h @@ -13,4 +13,33 @@ void CudaScaleAddImpl(int count, double* a_device, const double* b_device, doubl void CudaScaleAddImpl(int count, float* a_device, const float* b_device, double host_a_coeff, double host_b_coeff); -void CudaScaleAddImpl(int count, uint16_t* a_device, const uint16_t* b_device, double host_a_coeff, double host_b_coeff); \ No newline at end of file +void CudaScaleAddImpl(int count, uint16_t* a_device, const uint16_t* b_device, double host_a_coeff, double host_b_coeff); + +template +void MsCudaPairwiseReduce(int count, T* device_a, T* device_b){ + double normsq_a = 0.f; + double normsq_b = 0.f; + double dot = 0.f; + + double* device_normsq_a, * device_normsq_b, * device_dot; + cudaMalloc(&device_normsq_a, sizeof(double)); + cudaMalloc(&device_normsq_b, sizeof(double)); + cudaMalloc(&device_dot, sizeof(double)); + + CudaDotProductImpl(count, device_a, device_b, device_normsq_a, device_normsq_b, device_dot, normsq_a, normsq_b, dot); + + cudaFree(device_normsq_a); + cudaFree(device_normsq_b); + cudaFree(device_dot); + + double a_coeff = 1; + double b_coeff = 1; + if (normsq_a != 0) + a_coeff = 1.0 - dot / normsq_a * 0.5; + if (normsq_b != 0) + b_coeff = 1.0 - dot / normsq_b * 0.5; + + CudaScaleAddImpl(count, device_a, device_b, a_coeff, b_coeff); +} + + diff --git a/horovod/common/ops/msallreduce_cuda_ring_operations.cc b/horovod/common/ops/msallreduce_cuda_ring_operations.cc new file mode 100644 index 0000000000..c9221ae5fb --- /dev/null +++ b/horovod/common/ops/msallreduce_cuda_ring_operations.cc @@ -0,0 +1,568 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// Modifications copyright (C) 2019 Microsoft Corp. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "msallreduce_cuda_ring_operations.h" +#include "msallreduce_cuda_kernels.h" + +namespace horovod { +namespace common { + +using namespace msallreduce; + +#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} + +MsCudaRingAllreduceOp::MsCudaRingAllreduceOp(MPIContext* mpi_context, CUDAContext* cuda_context, HorovodGlobalState* global_state) + : MsAllreduceOp(mpi_context, global_state), mpi_context_(mpi_context), cuda_context_(cuda_context) { + } + +void MsCudaRingAllreduceOp::InitCUDA(const TensorTableEntry& entry, int layerid) { + cuda_context_->ErrorCheck("cudaSetDevice", cudaSetDevice(entry.device)); + + LOG(INFO, global_state_->rank)<<"Checking for existing stream for layer "<streams[global_state_->current_nccl_stream][layerid]; + if (stream == nullptr) { + + std::lock_guard guard(global_state_->mutex); + if (stream == nullptr) { + LOG(INFO, global_state_->rank)<<"Stream is null, creating new stream "<ErrorCheck("cudaDeviceGetStreamPriorityRange", + cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority)); + cuda_context_->ErrorCheck("cudaStreamCreateWithPriority", + cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, greatest_priority)); + } + } + cudaStream_t& device_stream = cuda_context_->streams[global_state_->current_nccl_stream][entry.device]; + if (device_stream == nullptr) { + std::lock_guard guard(global_state_->mutex); + if (device_stream == nullptr) { + LOG(INFO, global_state_->rank)<<"device Stream is null, creating new device stream "<ErrorCheck("cudaDeviceGetStreamPriorityRange", + cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority)); + cuda_context_->ErrorCheck("cudaStreamCreateWithPriority", + cudaStreamCreateWithPriority(&device_stream, cudaStreamNonBlocking, greatest_priority)); + + } + } +} + +Status MsCudaRingAllreduceOp::Execute(std::vector& entries, const Response& response) { + if(entries.size() < 1) { + return Status::OK(); + } + //TODO how do we report statuses? + std::map return_statuses; + int num_reductions = entries.size(); + AllRings all_rings(global_state_->local_rank, global_state_->local_size); + std::deque used_buffer_managers; + std::deque recv_buffers; + LOG(INFO, global_state_->rank)<<"Ready to process "<data(); + + buffer_len = entry.output->size(); + + if(entry.tensor->data() == entry.output->data()) { + LOG(INFO, global_state_->rank)<<"Output and input pointing to same data. Creating temp buffer "<current_nccl_stream, + []() {}, + []() {}, + [](int64_t& size, int64_t& threshold) {return size >= threshold;}); + + if (!status.ok()) { + throw std::logic_error("MsAllreduceOp::Execute_helper: Initialize buffer failed."); + } + auto& buffer = buffer_manager.GetBuffer(entry.device, entry.context->framework(), global_state_->current_nccl_stream); + recv_buffer = const_cast(buffer->AccessData(entry.context)); + } + else { + recv_buffer = (void*) entry.output->data(); + } + recv_buffers.push_back(recv_buffer); + LOG(INFO, global_state_->rank)<<"Begin to process gpu tensor with size "<size()<<" into output buffer with size "<size()<<" "<rank)<<"Begin processing gpu tensor in layer "<dtype(), + global_state_->local_comm, + layerid, + global_state_->local_rank); + } + all_rings.WaitAllMessages(); + // Return used buffer managers to the queue + buffer_managers_.insert(buffer_managers_.end(), used_buffer_managers.begin(), used_buffer_managers.end()); + + int local_rank = 0; + MPI_Comm_rank(global_state_->local_comm, &local_rank); + if (local_rank == 0 && global_state_->rank_log_size != 0) { + std::vector> allreduce_buffers; + + // start device to host copies + for (size_t layerid = 0; layerid < entries.size(); ++layerid) { + auto& entry = entries.at(layerid); + int buffer_len = entry.output->size(); + allreduce_buffers.emplace_back(new char[buffer_len]); + char* buffer_data = allreduce_buffers.at(layerid).get(); + + auto cuda_result = cudaMemcpyAsync( + buffer_data, (void*) entry.tensor->data(), + buffer_len, + cudaMemcpyDeviceToHost, + cuda_context_->streams[global_state_->current_nccl_stream][layerid]); + cuda_context_->ErrorCheck("cudaMemcpyAsync", cuda_result); + } + + for (size_t layerid = 0; layerid < entries.size(); ++layerid) { + auto& entry = entries.at(layerid); + int buffer_len = entry.output->size(); + char* buffer_data = allreduce_buffers.at(layerid).get(); + std::unique_ptr recv_buffer(new char[buffer_len]); + + // wait for this layer to finish copying to host + auto cuda_result = cudaStreamSynchronize(cuda_context_->streams[global_state_->current_nccl_stream][layerid]); + cuda_context_->ErrorCheck("cudaStreamSynchronize", cuda_result); + + MPI_Comm* node_comm = &global_state_->reduction_comms[global_state_->rank_log_size-1]; + switch (entry.output->dtype()) { + case HOROVOD_FLOAT16: + SyncAllreduce( + (uint16_t*) buffer_data, + (uint16_t*) recv_buffer.get(), + buffer_len / sizeof(uint16_t), + *node_comm, + global_state_->reduction_comms, + layerid, + entry, + ComputeDotAndNormSqrdsfp16, + ScaledAddfp16); + break; + case HOROVOD_FLOAT32: + SyncAllreduce( + (float*) buffer_data, + (float*) recv_buffer.get(), + buffer_len / sizeof(float), + *node_comm, + global_state_->reduction_comms, + layerid, + entry, + ComputeDotAndNormSqrds, + ScaledAdd); + break; + case HOROVOD_FLOAT64: + SyncAllreduce( + (double*) buffer_data, + (double*) recv_buffer.get(), + buffer_len / sizeof(double), + *node_comm, + global_state_->reduction_comms, + layerid, + entry, + ComputeDotAndNormSqrds, + ScaledAdd); + break; + default: + throw std::logic_error("MsAllreduceOp::Execute: Unsupported data type."); + } + + // start the copy back to device + cuda_result = cudaMemcpyAsync( + (void*) entry.tensor->data(), buffer_data, + buffer_len, + cudaMemcpyHostToDevice, + cuda_context_->streams[global_state_->current_nccl_stream][layerid]); + cuda_context_->ErrorCheck("cudaMemcpyAsync", cuda_result); + } + + // wait for all copies to device to finish + for (size_t layerid = 0; layerid < entries.size(); ++layerid) { + auto cuda_result = cudaStreamSynchronize(cuda_context_->streams[global_state_->current_nccl_stream][layerid]); + cuda_context_->ErrorCheck("cudaStreamSynchronize", cuda_result); + } + } + + for (size_t layerid = 0; layerid < entries.size(); ++layerid) { + auto& entry = entries.at(layerid); + void* buffer_data; + int buffer_len; + + buffer_data = (void*) entry.tensor->data(); + + buffer_len = entry.output->size(); + + LOG(INFO, global_state_->rank)<<"Begin to process gpu tensor with size "<size()<<" into output buffer with size "<size()<<" "<rank)<<"Begin processing gpu tensor in layer "<dtype(), + global_state_->local_comm, + layerid, + global_state_->local_rank); + } + all_rings.WaitAllMessages(); + for (size_t layerid = 0; layerid < entries.size(); ++layerid) { + auto& entry = entries.at(layerid); + if(entry.tensor->data() != entry.output->data()) { + memcpyUtil(entry, (void *) entry.output->data(), (void *) entry.tensor->data(), (size_t) entry.tensor->size(), layerid); + } + } + + return Status::OK(); +} + +void MsCudaRingAllreduceOp::memcpyUtil(TensorTableEntry entry, void* dest, void* src, size_t buffer_len, int layerid) { + assert(dest != nullptr); + assert(src != nullptr); + LOG(INFO, global_state_->rank)<<"memcpyUtil GPU. "<streams[global_state_->current_nccl_stream][entry.device]); + cuda_context_->ErrorCheck("cudaMemcpyAsync", cuda_result); + auto cuda_sync_result = cudaStreamSynchronize(cuda_context_->streams[global_state_->current_nccl_stream][entry.device]); + cuda_context_->ErrorCheck("cudaStreamSynchronize", cuda_sync_result); +} + +bool MsCudaRingAllreduceOp::Enabled(const ParameterManager& param_manager, + const std::vector& entries, + const Response& response) const { + return entries[0].device != CPU_DEVICE_ID; +} + +namespace msallreduce{ + +void Ring::InitRing(int tmp[], bool _isFat, int rank, int size) { + load = 0; + isFat = _isFat; + for (int i = 0; i < 8; i++) + loop[i] = tmp[i]; + + for (int j = 0; j < size; j++) { // go through allranks + if (rank == loop[j]) { + prevGPU = loop[(j-1+size) % size]; + nextGPU = loop[(j+1+size) % size]; + } + } +} + +int Ring::GetAfterLoad(int message_len) { + if (!isFat) + return 2*(load+message_len); + else + return (load+message_len); +} + +void Ring::AddLoad(int message_len) { + load += message_len; +} + +void Ring::ReduceLoad(int message_len) { + load -= message_len; +} + +Message::Message(MPIContext* mpi_context) + : mpi_context(mpi_context) { +} + +void Message::InitMessage(Ring* _ring, int _rank, int _ring_starter_rank, int _count, void* _grad_buf, void* _recv_buf, DataType _datatype, MPI_Comm _comm, int _tag) { + comm = _comm; + count = _count; + tag = _tag; + ring = _ring; + rank = _rank; + ring_starter_rank = _ring_starter_rank; + leg = 0; + grad_buf = _grad_buf; + recv_buf = _recv_buf; + datatype = _datatype; + Start(); +} + +AllreduceMessage::AllreduceMessage(MPIContext* mpi_context) + : Message(mpi_context) { +} + +void AllreduceMessage::Start() { + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + if (rank == ring_starter_rank) { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } else { + MPI_Irecv(recv_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } +} + +bool AllreduceMessage::Test() { + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + + int flag; + if (leg == 4) + return true; + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + if (flag == 1) { + leg++; + if (leg == 4) { + ring->ReduceLoad(count); + return true; + } + if (leg == 1) { + if (rank == ring_starter_rank) { + MPI_Irecv(grad_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } else { + // call the cuda kernel + switch(datatype) { + case HOROVOD_FLOAT16: + MsCudaPairwiseReduce(count, (uint16_t*)grad_buf, (uint16_t*)recv_buf); + break; + case HOROVOD_FLOAT32: + MsCudaPairwiseReduce(count, (float*)grad_buf, (float*)recv_buf); + break; + case HOROVOD_FLOAT64: + MsCudaPairwiseReduce(count, (double*)grad_buf, (double*)recv_buf); + break; + default: + throw std::logic_error("Message::Test: Unsupported data type."); + } + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } + } else if (leg == 2) { + if (rank == ring_starter_rank) { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } else { + MPI_Irecv(grad_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } + } else if (leg == 3) { + if (rank == ring_starter_rank) { + MPI_Irecv(grad_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } else { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } + } + } + + return false; +} + +ReduceMessage::ReduceMessage(MPIContext* mpi_context) + : Message(mpi_context) { +} + +void ReduceMessage::Start() { + + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + if (rank == ring_starter_rank) { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } else { + MPI_Irecv(recv_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } +} + +bool ReduceMessage::Test() { + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + + int flag; + if (leg == 2) + return true; + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + if (flag == 1) { + leg++; + if (leg == 2) { + ring->ReduceLoad(count); + return true; + } + if (leg == 1) { + if (rank == ring_starter_rank) { + MPI_Irecv(grad_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + } else { + // call the cuda kernel + switch(datatype) { + case HOROVOD_FLOAT16: + MsCudaPairwiseReduce(count, (uint16_t*)grad_buf, (uint16_t*)recv_buf); + break; + case HOROVOD_FLOAT32: + MsCudaPairwiseReduce(count, (float*)grad_buf, (float*)recv_buf); + break; + case HOROVOD_FLOAT64: + MsCudaPairwiseReduce(count, (double*)grad_buf, (double*)recv_buf); + break; + default: + throw std::logic_error("Message::Test: Unsupported data type."); + } + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } + } + } + return false; +} + +BroadcastMessage::BroadcastMessage(MPIContext* mpi_context) + : Message(mpi_context) { +} + +void BroadcastMessage::Start() { + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + if (rank == ring_starter_rank) { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + leg = 1; + } else { + MPI_Irecv(grad_buf, count, mpi_datatype, ring->prevGPU, tag, comm, &req); + if (ring->nextGPU == ring_starter_rank) + leg = 1; + } +} + +bool BroadcastMessage::Test() { + auto mpi_datatype = mpi_context->GetMPIDataType(datatype); + + int flag; + if (leg == 2) + return true; + MPI_Test(&req, &flag, MPI_STATUS_IGNORE); + if (flag == 1) { + leg++; + if (leg == 2) { + ring->ReduceLoad(count); + return true; + } + if (leg == 1) { + if (rank != ring_starter_rank) { + MPI_Isend(grad_buf, count, mpi_datatype, ring->nextGPU, tag, comm, &req); + } + } + } + return false; +} + +AllRings::~AllRings() { + for (int i = 0; i < messages.size(); i++) + delete messages[i]; + delete[] rings; +} + +AllRings::AllRings(int rank, int size) { + rings = new Ring[num_rings]; + { + // fat ring 1 + int tmp[8] = {0, 3, 2, 1, 5, 6, 7, 4}; + rings[0].InitRing(tmp, true, rank, size); + } + { + // fat ring 2 + int tmp[8] = {0, 4, 7, 6, 5, 1, 2, 3}; + rings[1].InitRing(tmp, true, rank, size); + } + { + // skinny ring 1 + int tmp[8] = {0, 2, 6, 4, 5, 7, 3, 1}; + rings[2].InitRing(tmp, false, rank, size); + } + { + // skinny ring 2 + int tmp[8] = {0, 1, 3, 7, 5, 4, 6, 2}; + rings[3].InitRing(tmp, false, rank, size); + } +}; + +Ring* AllRings::PickRing(int count) { + int min_load = (1<<30); // INF + Ring* ret_ring = NULL; + for (int i = 0; i < num_rings; i++) { + Ring* ring = &rings[i]; + int cur_ring_after_load = ring->GetAfterLoad(count); + if (cur_ring_after_load < min_load) { + ret_ring = ring; + min_load = cur_ring_after_load; + } + } + ret_ring->AddLoad(count); + assert(ret_ring != NULL); + return ret_ring; +} + +void AllRings::InitMessageInRing(Message* message, void* grad_buf, void* recv_buf, int size, DataType datatype, MPI_Comm comm, int grad_tag, int rank) { + int count = -1; + switch(datatype) { + case HOROVOD_FLOAT16: + count = size / sizeof(uint16_t); + break; + case HOROVOD_FLOAT32: + count = size / sizeof(float); + break; + case HOROVOD_FLOAT64: + count = size / sizeof(double); + break; + default: + throw std::logic_error("AllRings::InitMessageInRing: Unsupported data type."); + } + messages.push_back(message); + Ring* ring = PickRing(count); + message->InitMessage(ring, rank, grad_tag % 8, count, grad_buf, recv_buf, datatype, comm, grad_tag); +} + +void AllRings::WaitAllMessages() { + + bool all_done = false; + while (!all_done) { + all_done = true; + for (auto& message : messages) { + if (!message->Test()) + all_done = false; + } + } + for (int i = 0; i < messages.size(); i++) + delete messages[i]; + messages.clear(); +} + +} +} +} diff --git a/horovod/common/ops/msallreduce_cuda_ring_operations.h b/horovod/common/ops/msallreduce_cuda_ring_operations.h new file mode 100644 index 0000000000..341b87a320 --- /dev/null +++ b/horovod/common/ops/msallreduce_cuda_ring_operations.h @@ -0,0 +1,122 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// Modifications copyright (C) 2019 Microsoft Corp. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef HOROVOD_MSALLREDUCE_CUDA_RING_OPERATIONS_H +#define HOROVOD_MSALLREDUCE_CUDA_RING_OPERATIONS_H + +#include +#include + +#include "msallreduce_operations.h" +#include "cuda_operations.h" +#include "cuda_fp16.h" + +namespace horovod { +namespace common { + +class MsCudaRingAllreduceOp : public MsAllreduceOp { + public: + MsCudaRingAllreduceOp(MPIContext* mpi_context, CUDAContext* cuda_context, + HorovodGlobalState* global_state); + + bool Enabled(const ParameterManager& param_manager, + const std::vector& entries, + const Response& response) const override; + + Status Execute(std::vector& entries, + const Response& response) override; + + protected: + struct MPIContext* mpi_context_; + struct CUDAContext* cuda_context_; + std::deque buffer_managers_; + + void InitCUDA(const TensorTableEntry& entry, int layerid); + void memcpyUtil(TensorTableEntry entry, void* dest, void* src, size_t buffer_len, int layerid) override; +}; + +namespace msallreduce { + +struct Ring { + int loop[8]; + int nextGPU; + int prevGPU; + int load; + bool isFat; + + void InitRing(int tmp[], bool _isFat, int rank, int size); + int GetAfterLoad(int message_len); + void AddLoad(int message_len); + void ReduceLoad(int message_len); +}; + +struct Message { + MPIContext* mpi_context; + MPI_Comm comm; + MPI_Request req; + Ring* ring; + int rank; + int ring_starter_rank; + int leg; // number of legs in the ring has been done + void* grad_buf; + void* recv_buf; + DataType datatype; + int tag; + int count; + + Message(MPIContext* mpi_context); + void InitMessage(Ring* _ring, int rank, int _ring_starter_rank, int _count, void* _grad_buf, void* _recv_buf, DataType _datatype, MPI_Comm _comm, int _tag); + virtual bool Test() = 0; +protected: + virtual void Start() = 0; +}; + +struct AllreduceMessage : public Message { + AllreduceMessage(MPIContext* mpi_context); + virtual bool Test(); +protected: + virtual void Start(); +}; + +struct ReduceMessage : public Message { + ReduceMessage(MPIContext* mpi_context); + virtual bool Test(); +protected: + virtual void Start(); +}; + +struct BroadcastMessage : public Message { + BroadcastMessage(MPIContext* mpi_context); + virtual bool Test(); +protected: + virtual void Start(); +}; + +struct AllRings { + int num_rings = 4; + Ring* rings; + std::vector messages; + + ~AllRings(); + AllRings(int rank, int size); + Ring* PickRing(int count); + void InitMessageInRing(Message* message, void* grad_buf, void* recv_buf, int size, DataType datatype, MPI_Comm comm, int grad_tag, int rank); + void WaitAllMessages(); +}; + +} // namespace msallreduce +} // namespace common +} // namespace horovod +#endif // HOROVOD_MSALLREDUCE_CUDA_RING_OPERATIONS_H diff --git a/setup.py b/setup.py index ce2bde13bc..5f0d17f541 100644 --- a/setup.py +++ b/setup.py @@ -664,7 +664,8 @@ def get_common_options(build_ext): INCLUDES += ['horovod/common/ops/cuda'] SOURCES += ['horovod/common/ops/cuda_operations.cc', 'horovod/common/ops/mpi_cuda_operations.cc', - 'horovod/common/ops/msallreduce_cuda_operations.cc'] + 'horovod/common/ops/msallreduce_cuda_operations.cc', + 'horovod/common/ops/msallreduce_cuda_ring_operations.cc'] LIBRARY_DIRS += cuda_lib_dirs LIBRARIES += ['cudart', 'cublas'] @@ -806,7 +807,8 @@ def build_mx_extension(build_ext, options): options['INCLUDES'] += cuda_include_dirs options['SOURCES'] += ['horovod/common/ops/cuda_operations.cc', 'horovod/common/ops/mpi_cuda_operations.cc', - 'horovod/common/ops/msallreduce_cuda_operations.cc'] + 'horovod/common/ops/msallreduce_cuda_operations.cc', + 'horovod/common/ops/msallreduce_cuda_ring_operations.cc'] options['LIBRARY_DIRS'] += cuda_lib_dirs options['LIBRARIES'] += ['cudart', 'cublas'] diff --git a/test/test_msallreduce.py b/test/test_msallreduce.py index 8a1d4d9e0a..352118b843 100644 --- a/test/test_msallreduce.py +++ b/test/test_msallreduce.py @@ -48,50 +48,60 @@ def evaluate(self, tensors): return sess.run(tensors) - def test_horovod_multiple_allreduce_cpu(self): - """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" - hvd.init() - size = hvd.size() - - rank0_tensors = [np.asarray([[1.0, 2.0], [3.0, 4.0]]), np.asarray([[9.0, 10.0], [11.0, 12.0]])] - rank1_tensors = [np.asarray([[5.0, 6.0], [7.0, 8.0]]), np.asarray([[13.0, 14.0], [15.0, 16.0]])] + # def test_horovod_multiple_allreduce_cpu(self): + # """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + # hvd.init() + # size = hvd.size() + # rank0_tensors = [np.asarray([[1.0, 2.0], [3.0, 4.0]]), np.asarray([[9.0, 10.0], [11.0, 12.0]])] + # rank1_tensors = [np.asarray([[5.0, 6.0], [7.0, 8.0]]), np.asarray([[13.0, 14.0], [15.0, 16.0]])] - expected = [] - for a,b in zip(rank0_tensors, rank1_tensors): - answer = parasail_reference_operation(a, b) - expected.append(answer) + # expected = [] + # for a,b in zip(rank0_tensors, rank1_tensors): + # answer = parasail_reference_operation(a, b) + # expected.append(answer) - for dtype in [tf.float16, tf.float32, tf.float64]: - with tf.device("/cpu:0"): - tensors = map(tf.constant, rank0_tensors if hvd.rank() == 0 else rank1_tensors) - # cast to the corresponding dtype - tensors = map(lambda tensor: tf.cast(tensor, dtype), tensors) - # and away we go: do reduction - reduced_tensors = [ - self.evaluate(hvd.allreduce(tensor, average=False, allreduce_type=AllreduceType.MsAllreduce)) - for tensor in tensors - ] - # cast expected result to the type of the tensorflow values - np_type = dtype.as_numpy_dtype - tmp = [t.astype(np_type) for t in expected] - self.assertAllClose(tmp, reduced_tensors) + # for dtype in [tf.float16, tf.float32, tf.float64]: + # with tf.device("/cpu:0"): + # tensors = map(tf.constant, rank0_tensors if hvd.rank() == 0 else rank1_tensors) + # # cast to the corresponding dtype + # tensors = map(lambda tensor: tf.cast(tensor, dtype), tensors) + # # and away we go: do reduction + # reduced_tensors = [ + # self.evaluate(hvd.allreduce(tensor, average=False, allreduce_type=AllreduceType.MsAllreduce)) + # for tensor in tensors + # ] + # # cast expected result to the type of the tensorflow values + # np_type = dtype.as_numpy_dtype + # tmp = [t.astype(np_type) for t in expected] + # self.assertAllClose(tmp, reduced_tensors) def test_horovod_multiple_allreduce_gpu(self): """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" hvd.init() size = hvd.size() - rank0_tensors = [np.asarray([[1.0, 2.0], [3.0, 4.0]]), np.asarray([[9.0, 10.0], [11.0, 12.0]])] - rank1_tensors = [np.asarray([[5.0, 6.0], [7.0, 8.0]]), np.asarray([[13.0, 14.0], [15.0, 16.0]])] + all_tensors = [] + for i in range(8): + # all_tensors.append([np.asarray([[(1.0), (2.0)], [(3.0), (4.0)]]), np.asarray([[(5.0), (6.0)], [(7.0), (8.0)]])]) + # all_tensors.append([np.asarray([[(1.0+i), (2.0+i)], [(3.0+i), (4.0+i)]]), np.asarray([[(5.0+i), (6.0+i)], [(7.0+i), (8.0+i)]])]) + all_tensors.append([np.asarray([(1.0+i), (1.0+i)])]) + # all_tensors.append([np.asarray([[(1.0+i)*(i==0), (2.0+i)*(i==1)], [(3.0+i)*(i==2), (4.0+i)*(i==3)]]), np.asarray([[(5.0+i)*(i==0), (6.0+i)*(i==1)], [(7.0+i)*(i==2), (8.0+i)*(i==3)]])]) - expected = [] - for a,b in zip(rank0_tensors, rank1_tensors): - answer = parasail_reference_operation(a, b) - expected.append(answer) + + # rank0_tensors = [np.asarray([[1.0, 2.0], [3.0, 4.0]]), np.asarray([[9.0, 10.0], [11.0, 12.0]])] + # rank1_tensors = [np.asarray([[1.0, 2.0], [3.0, 4.0]]), np.asarray([[9.0, 10.0], [11.0, 12.0]])] + # rank0_tensors = [np.asarray([[9.0, 10.0], [11.0, 12.0]])] + # rank1_tensors = [np.asarray([[9.0, 10.0], [11.0, 12.0]])] + + expected = all_tensors[0] + for i in [3, 2, 1, 5, 6, 7, 4]: + answer0 = parasail_reference_operation(expected[0], all_tensors[i][0]) + expected = [answer0] rank_num = hvd.local_rank() - for dtype in [tf.float32, tf.float16]: + for dtype in [tf.float32]: with tf.device("/gpu:{}".format(rank_num)): - tensors = map(tf.constant, rank0_tensors if hvd.rank() == 0 else rank1_tensors) + tensors = map(tf.constant, all_tensors[hvd.rank()]) + # tensors = map(tf.constant, rank0_tensors if hvd.rank() == 0 else rank1_tensors) # cast to the corresponding dtype tensors = map(lambda tensor: tf.cast(tensor, dtype), tensors) # and away we go: do reduction @@ -104,77 +114,77 @@ def test_horovod_multiple_allreduce_gpu(self): tmp = [t.astype(np_type) for t in expected] self.assertAllClose(tmp, reduced_tensors) - def test_horovod_multiple_large_tensors_allreduce_cpu(self): - """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" - hvd.init() - size = hvd.size() - base_dim = [16,32,64] - dim_multipliers = [1, 4, 8, 16, 32, 64] - #for multiplier in dim_multipliers: - multiplier = dim_multipliers[5] - true_dim = base_dim.copy() - true_dim[2] = true_dim[2] * multiplier - start_time = datetime.utcnow() - rep = 1 - tensor_count = 100 - with tf.device("/cpu:0"): - tf.set_random_seed(1234) + # def test_horovod_multiple_large_tensors_allreduce_cpu(self): + # """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + # hvd.init() + # size = hvd.size() + # base_dim = [16,32,64] + # dim_multipliers = [1, 4, 8, 16, 32, 64] + # #for multiplier in dim_multipliers: + # multiplier = dim_multipliers[5] + # true_dim = base_dim.copy() + # true_dim[2] = true_dim[2] * multiplier + # start_time = datetime.utcnow() + # rep = 1 + # tensor_count = 100 + # with tf.device("/cpu:0"): + # tf.set_random_seed(1234) - for _ in range(rep): - summed = [] - for _ in range(tensor_count): - tensor = tf.random_uniform( - true_dim, -100, 100, dtype=tf.float32) - summed.append(hvd.allreduce(tensor, average=False)) - result_sum = self.evaluate(summed) - #print(result_sum) - end_time = datetime.utcnow() - time_delta = end_time - start_time - tensor_size = np.prod(true_dim) / 256 - print("{} {}K tensors {} Cycles took {}".format(tensor_count, tensor_size, rep, time_delta.total_seconds())) - - def test_horovod_single_large_tensor_allreduce_cpu(self): - """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" - hvd.init() - size = hvd.size() - base_dim = [16,32,64] - dim_multipliers = [1, 4, 8, 16, 32, 64] - #for multiplier in dim_multipliers: - multiplier = dim_multipliers[5] - true_dim = base_dim.copy() - true_dim[2] = true_dim[2] * multiplier - with tf.device("/cpu:0"): - tf.set_random_seed(1234) - tensor = tf.random_uniform( - true_dim, -100, 100, dtype=tf.float32) - start_time = datetime.utcnow() + # for _ in range(rep): + # summed = [] + # for _ in range(tensor_count): + # tensor = tf.random_uniform( + # true_dim, -100, 100, dtype=tf.float32) + # summed.append(hvd.allreduce(tensor, average=False)) + # result_sum = self.evaluate(summed) + # #print(result_sum) + # end_time = datetime.utcnow() + # time_delta = end_time - start_time + # tensor_size = np.prod(true_dim) / 256 + # print("{} {}K tensors {} Cycles took {}".format(tensor_count, tensor_size, rep, time_delta.total_seconds())) + + # def test_horovod_single_large_tensor_allreduce_cpu(self): + # """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + # hvd.init() + # size = hvd.size() + # base_dim = [16,32,64] + # dim_multipliers = [1, 4, 8, 16, 32, 64] + # #for multiplier in dim_multipliers: + # multiplier = dim_multipliers[5] + # true_dim = base_dim.copy() + # true_dim[2] = true_dim[2] * multiplier + # with tf.device("/cpu:0"): + # tf.set_random_seed(1234) + # tensor = tf.random_uniform( + # true_dim, -100, 100, dtype=tf.float32) + # start_time = datetime.utcnow() - for _ in range(100): - summed = 0 - summed = hvd.allreduce(tensor, average=False) - result_sum = self.evaluate(summed) - #print(result_sum) - end_time = datetime.utcnow() - time_delta = end_time - start_time - tensor_size = np.prod(true_dim) / 256 - print("{}K tensor Cycle took {}".format(tensor_size,time_delta.total_seconds())) - - def test_horovod_single_allreduce_cpu(self): - """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" - hvd.init() - size = hvd.size() - with tf.device("/cpu:0"): - if hvd.rank() == 0: - tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]]) - else: - tensor = tf.constant([[5.0, 6.0], [7.0, 8.0]]) - summed = hvd.allreduce(tensor, average=False) - diff = self.evaluate(summed) - print(diff) - - def test_horovod_multithread_init(self): - """Test thread pool init""" - hvd.init() + # for _ in range(100): + # summed = 0 + # summed = hvd.allreduce(tensor, average=False) + # result_sum = self.evaluate(summed) + # #print(result_sum) + # end_time = datetime.utcnow() + # time_delta = end_time - start_time + # tensor_size = np.prod(true_dim) / 256 + # print("{}K tensor Cycle took {}".format(tensor_size,time_delta.total_seconds())) + + # def test_horovod_single_allreduce_cpu(self): + # """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + # hvd.init() + # size = hvd.size() + # with tf.device("/cpu:0"): + # if hvd.rank() == 0: + # tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + # else: + # tensor = tf.constant([[5.0, 6.0], [7.0, 8.0]]) + # summed = hvd.allreduce(tensor, average=False) + # diff = self.evaluate(summed) + # print(diff) + + # def test_horovod_multithread_init(self): + # """Test thread pool init""" + # hvd.init() if __name__ == '__main__': tf.test.main()