Skip to content

Commit

Permalink
Simplify/improve barrier in AllReduce6
Browse files Browse the repository at this point in the history
  • Loading branch information
roshandathathri committed Jun 20, 2024
1 parent 34f4d9d commit 3362e25
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions python/mscclpp_benchmark/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,26 @@ extern "C" __global__ void __launch_bounds__(1024, 1)

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900

// Cross-device barrier followed by a grid-wide sync on the local device.
// Every thread of every block on every participating device must call this.
//
//   semaphores            - one handle per peer device; only entries
//                           [0, num_ranks - 1) are touched, indexed by thread id
//   thread_id, block_id   - caller's thread index within its block and the
//                           index of that block
//   num_threads_per_block - not read by this function; callers must still
//                           guarantee num_threads_per_block >= num_ranks so
//                           that block 0 has one thread for every peer
//   num_blocks            - value forwarded to deviceSyncer.sync()
//   num_ranks             - number of participating devices, self included
//
// No explicit fence is issued here; any memory ordering comes from signal(),
// wait() and deviceSyncer.sync() themselves - TODO confirm against their docs.
// NOTE(review): block 0 signals its peers without first waiting for the other
// blocks on this device - confirm no caller relies on this barrier to publish
// writes made by those other blocks.
__device__ void barrier(
    mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
    int thread_id, int block_id, int num_threads_per_block, int num_blocks,
    int num_ranks) {
  // There is no semaphore for the local device, hence one fewer than num_ranks.
  const int num_peers = num_ranks - 1;

  // Inter-device phase: block 0 dedicates one thread to each peer device.
  const bool handles_peer = (block_id == 0) && (thread_id < num_peers);
  if (handles_peer) {
    mscclpp::SmDevice2DeviceSemaphoreDeviceHandle& peer = semaphores[thread_id];
    peer.signal();
    peer.wait();
  }

  // Intra-device phase: hold every thread of every block on this device until
  // all of them (including the peer-handling threads above) have arrived.
  deviceSyncer.sync(num_blocks);
}

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
Expand All @@ -796,17 +816,13 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_threads_per_block = blockDim.x;
int num_blocks = gridDim.x;

if (tid == 0 && bid == 0) {
__threadfence_system();
}
if (bid == 0) {
if (tid < nranks - 1) {
semaphores[tid].signal();
semaphores[tid].wait();
}
}
deviceSyncer.sync(gridDim.x);
// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
// before reading them in this kernel
barrier(semaphores, tid, bid, num_threads_per_block, num_blocks, nranks);

int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;
Expand All @@ -815,22 +831,14 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
int my_step = blockDim.x * gridDim.x * 4;

for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val;
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}

deviceSyncer.sync(gridDim.x);
if (tid == 0 && bid == 0) {
__threadfence_system();
}

if (bid == 0) {
if (tid < nranks - 1) {
semaphores[tid].signal();
semaphores[tid].wait();
}
}
deviceSyncer.sync(gridDim.x);
// end with a barrier to ensure all devices can now read their values
// from their own memory (that is part of the multicast memory)
// after writing them in this kernel
barrier(semaphores, tid, bid, num_threads_per_block, num_blocks, nranks);
}
#endif

0 comments on commit 3362e25

Please sign in to comment.