Skip to content

Commit

Permalink
Merge pull request #11 from Tixxx/saemal/msallreducecudakernels
Browse files Browse the repository at this point in the history
Saemal/msallreducecudakernels
  • Loading branch information
Tixxx authored Sep 5, 2019
2 parents a3d5910 + 17e8d9c commit 44fd7f8
Show file tree
Hide file tree
Showing 7 changed files with 850 additions and 108 deletions.
7 changes: 4 additions & 3 deletions horovod/common/operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

#if HAVE_CUDA
#include "ops/msallreduce_cuda_operations.h"
#include "ops/msallreduce_cuda_ring_operations.h"
#include "ops/cuda_operations.h"
#include "ops/mpi_cuda_operations.h"
#endif
Expand Down Expand Up @@ -157,7 +158,7 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
#if HOROVOD_GPU_ALLREDUCE == 'M'
if (state.msallreduce_enabled == true){
LOG(INFO) << "msallGpureduce enabled.";
msallreduce_ops.push_back(std::shared_ptr<AllreduceOp>(new MsCudaAllreduceOp(&mpi_context, &cuda_context, &state)));
msallreduce_ops.push_back(std::shared_ptr<AllreduceOp>(new MsCudaRingAllreduceOp(&mpi_context, &cuda_context, &state)));
}
allreduce_ops.push_back(std::shared_ptr<AllreduceOp>(
new MPI_CUDAAllreduce(&mpi_context, &cuda_context, &state)));
Expand Down Expand Up @@ -997,7 +998,7 @@ void BackgroundThreadLoop(HorovodGlobalState& state, MPIContext& ctx) {
std::strtol(mpi_threads_disable, nullptr, 10) > 0) {
required = MPI_THREAD_SINGLE;
}
#if HAVE_MLSL
#if HAVE_MLSL
// MLSL comes with Intel MPI
// and needs to initialize MPI with the proper configuration.
mlsl_context.Init();
Expand Down Expand Up @@ -1068,7 +1069,7 @@ if(state.msallreduce_enabled == true) {
}

delete[] node_rank;
}
}
// TODO parasail new algo end
}

Expand Down
10 changes: 10 additions & 0 deletions horovod/common/ops/cuda/msallreduce_cuda_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ void CudaScaleAddKernel(int count, T* a, const T* b, TACC a_coeff, TACC b_coeff)
}
}

// Element-wise widening copy: b[i] = float(a[i]) for every i in [0, count).
// Launch with any 1-D grid/block configuration; threads whose flat index
// falls past `count` are masked out by the bounds check below.
// `a` is read-only and must not alias `b` (__restrict__ lets the compiler
// use the read-only data path and reorder loads).
template<typename T>
__global__
void ConvertToFloat(int count, const T* __restrict__ a, float* __restrict__ b) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  // Guard the grid tail: count is rarely a multiple of blockDim.x.
  if (index < count) {
    b[index] = static_cast<float>(a[index]);
  }
}


void CudaDotProductImpl(int count, const double* device_a, const double* device_b,
double* device_normsq_a, double* device_normsq_b, double* device_dot, double& host_normsq_a, double& host_normsq_b, double& host_dot) {

Expand Down
31 changes: 30 additions & 1 deletion horovod/common/ops/msallreduce_cuda_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,33 @@ void CudaScaleAddImpl(int count, double* a_device, const double* b_device, doubl

void CudaScaleAddImpl(int count, float* a_device, const float* b_device, double host_a_coeff, double host_b_coeff);

void CudaScaleAddImpl(int count, uint16_t* a_device, const uint16_t* b_device, double host_a_coeff, double host_b_coeff);
void CudaScaleAddImpl(int count, uint16_t* a_device, const uint16_t* b_device, double host_a_coeff, double host_b_coeff);

// Pairwise MS-allreduce step for two same-length device buffers.
// Computes normsq(a), normsq(b) and dot(a, b) on the device, then performs
// the in-place scaled accumulation
//     a = a_coeff * a + b_coeff * b
// with a_coeff = 1 - dot / (2 * normsq_a) and b_coeff = 1 - dot / (2 * normsq_b).
// A coefficient stays at 1 when the corresponding squared norm is zero,
// which avoids division by zero for an all-zero buffer.
//
// NOTE(review): the scratch buffer is freed immediately after
// CudaDotProductImpl returns, so this assumes that call synchronizes the
// device before handing back the host-side scalars — confirm in its
// implementation. CUDA API return codes are not checked here (matching the
// surrounding code); consider propagating cudaMalloc failures.
template<typename T>
void MsCudaPairwiseReduce(int count, T* device_a, T* device_b){
  double normsq_a = 0.0;
  double normsq_b = 0.0;
  double dot = 0.0;

  // One allocation holding all three device scalars instead of three
  // separate round trips through the (synchronous, slow) CUDA allocator.
  double* device_scratch = nullptr;
  cudaMalloc(&device_scratch, 3 * sizeof(double));
  double* device_normsq_a = device_scratch;
  double* device_normsq_b = device_scratch + 1;
  double* device_dot      = device_scratch + 2;

  CudaDotProductImpl(count, device_a, device_b, device_normsq_a,
                     device_normsq_b, device_dot, normsq_a, normsq_b, dot);

  cudaFree(device_scratch);

  double a_coeff = 1.0;
  double b_coeff = 1.0;
  if (normsq_a != 0)
    a_coeff = 1.0 - dot / normsq_a * 0.5;
  if (normsq_b != 0)
    b_coeff = 1.0 - dot / normsq_b * 0.5;

  CudaScaleAddImpl(count, device_a, device_b, a_coeff, b_coeff);
}


Loading

0 comments on commit 44fd7f8

Please sign in to comment.