Skip to content

Commit

Permalink
Merge pull request #2 from NVIDIA/complex_grad_dist
Browse files Browse the repository at this point in the history
Enable support for complex gradient reduction in distributed cases.
  • Loading branch information
azrael417 authored Sep 5, 2023
2 parents e06613d + ff7cd7f commit fb8c793
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions src/csrc/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,14 @@ void Comm::allreduce(torch::Tensor& tensor, bool average) const {
CHECK_CUDA(cudaStreamWaitEvent(stream, event));

auto count = torch::numel(tensor);
CHECK_NCCL(ncclAllReduce(tensor.data_ptr(), tensor.data_ptr(), count, get_nccl_dtype(tensor),
ncclDataType_t nccl_dtype;
if (torch::is_complex(tensor)) {
nccl_dtype = get_nccl_dtype(torch::view_as_real(tensor));
count *= 2;
} else {
nccl_dtype = get_nccl_dtype(tensor);
}
CHECK_NCCL(ncclAllReduce(tensor.data_ptr(), tensor.data_ptr(), count, nccl_dtype,
(average) ? ncclAvg : ncclSum, nccl_comm, stream));

CHECK_CUDA(cudaEventRecord(event, stream));
Expand All @@ -114,7 +121,14 @@ void Comm::allreduce(const std::vector<torch::Tensor>& tensors, bool average) co
}

auto count = torch::numel(t);
CHECK_NCCL(ncclAllReduce(t.data_ptr(), t.data_ptr(), count, get_nccl_dtype(t), (average) ? ncclAvg : ncclSum,
ncclDataType_t nccl_dtype;
if (torch::is_complex(t)) {
nccl_dtype = get_nccl_dtype(torch::view_as_real(t));
count *= 2;
} else {
nccl_dtype = get_nccl_dtype(t);
}
CHECK_NCCL(ncclAllReduce(t.data_ptr(), t.data_ptr(), count, nccl_dtype, (average) ? ncclAvg : ncclSum,
nccl_comm, stream));
}
CHECK_NCCL(ncclGroupEnd());
Expand All @@ -140,13 +154,20 @@ void Comm::broadcast(torch::Tensor& tensor, int root) const {
THROW_INVALID_USAGE("broadcast only supports GPU tensors for now.");
}
auto count = torch::numel(tensor);
ncclDataType_t nccl_dtype;
if (torch::is_complex(tensor)) {
nccl_dtype = get_nccl_dtype(torch::view_as_real(tensor));
count *= 2;
} else {
nccl_dtype = get_nccl_dtype(tensor);
}

auto torch_stream = c10::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
CHECK_CUDA(cudaEventRecord(event, torch_stream));
CHECK_CUDA(cudaStreamWaitEvent(stream, event));

CHECK_NCCL(
ncclBroadcast(tensor.data_ptr(), tensor.data_ptr(), count, get_nccl_dtype(tensor), root, nccl_comm, stream));
ncclBroadcast(tensor.data_ptr(), tensor.data_ptr(), count, nccl_dtype, root, nccl_comm, stream));

CHECK_CUDA(cudaEventRecord(event, stream));
CHECK_CUDA(cudaStreamWaitEvent(torch_stream, event));
Expand Down

0 comments on commit fb8c793

Please sign in to comment.