[Bugfix] Fix potentially unsafe custom allreduce synchronization #8558

Merged 4 commits on Sep 24, 2024
128 changes: 74 additions & 54 deletions csrc/custom_all_reduce.cuh
@@ -6,6 +6,7 @@
#include <cuda_runtime.h>

#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
@@ -23,17 +24,23 @@

namespace vllm {

constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
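// Illustrative sketch (not part of this diff): the barrier below only ever
// compares counters for equality, so the well-defined uint32_t wraparound is
// harmless. After the maximum value, the next expected counter is simply 0 on
// every rank.
static_assert(static_cast<FlagType>(std::numeric_limits<FlagType>::max() + 1u) == 0u,
              "FlagType wraps around to 0, so equality handshakes keep matching");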
struct Signal {
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs: a peer GPU block may
// arrive at the second sync point while the current GPU block hasn't passed
// the first sync point. The peer could then write counter+1 while the current
// block is still busy-waiting on counter. We alternate between the two
// counter arrays to avoid this race.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
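
// Illustrative sketch (not part of this diff) of the parity indexing described
// above: barrier number `val` uses counter set `val % 2`, so a peer that has
// already reached barrier val+1 writes into a different slot than the one this
// rank is still spinning on for barrier val. The helper name is hypothetical;
// the real indexing lives in multi_gpu_barrier below.
__device__ __forceinline__ FlagType* peer_counter_slot(Signal* sg, FlagType val,
                                                       int block, int rank) {
  return &sg->peer_counter[val % 2][block][rank];
}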

struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };

struct __align__(16) RankSignals { volatile Signal* signals[8]; };
struct __align__(16) RankSignals { Signal* signals[8]; };

// like std::array, but aligned
template <typename T, int sz>
@@ -123,47 +130,60 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]);
}
__syncthreads();
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
}

static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
return flag;
}

static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
return flag;
}
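
// Illustrative sketch (not part of this diff) of how the helpers above are
// meant to pair: the release store publishes all earlier writes, and the
// matching acquire load orders everything read after it. The volatile variants
// give no such ordering and are used where the barrier has no data to publish
// (the start barrier and the final barrier). `producer_side`, `consumer_side`,
// and `payload` are hypothetical names used only for illustration.
static DINLINE void producer_side(float* payload, FlagType* flag, FlagType val) {
  payload[0] = 1.0f;           // data written before the signal
  st_flag_release(flag, val);  // release: the write above becomes visible first
}

static DINLINE float consumer_side(const float* payload, FlagType* flag,
                                   FlagType val) {
  while (ld_flag_acquire(flag) != val);  // acquire: orders the read below
  return payload[0];                     // observes the producer's write
}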

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. It might be that the hardware provides a stronger guarantee than
// the memory model.
if constexpr (!final_sync) __threadfence_system();
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we get true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]);
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr =
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val);
}
}
if constexpr (!final_sync) __syncthreads();
if constexpr (is_start || need_fence) __syncthreads();
}
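
// Illustrative mapping (not part of this diff) from the old start_sync/end_sync
// helpers to the new template parameters, matching how the call sites below
// were updated. The wrapper names are hypothetical.
template <int ngpus>
DINLINE void start_sync_compat(const RankSignals& sg, Signal* self_sg, int rank) {
  multi_gpu_barrier<ngpus, /*is_start=*/true>(sg, self_sg, rank);
}

template <int ngpus, bool final_sync = false>
DINLINE void end_sync_compat(const RankSignals& sg, Signal* self_sg, int rank) {
  // A non-final end sync must order the preceding writes, so it needs the
  // fence; the final sync does not, because nothing is read after it.
  multi_gpu_barrier<ngpus, /*is_start=*/false, /*need_fence=*/!final_sync>(
      sg, self_sg, rank);
}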

template <typename P, int ngpus, typename A>
@@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) {
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
start_sync<ngpus>(sg, self_sg, rank);
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

template <typename P>
DINLINE P* get_tmp_buf(volatile Signal* sg) {
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) {
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
@@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1)
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
end_sync<ngpus>(sg, self_sg, rank);
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
@@ -437,6 +455,8 @@ class CustomAllreduce {
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
14 changes: 9 additions & 5 deletions csrc/custom_all_reduce_test.cu
@@ -1,15 +1,15 @@
/**
* This is a standalone test for custom allreduce.
* To compile, make sure you have MPI and NCCL installed in your system.
* export MPI_HOME=XXX
* export MPI_HOME=xxx
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
*
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
*
* To run:
* mpirun -np 8 ./custom_all_reduce_test
* mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
*/
#include <cuda.h>
#include <curand_kernel.h>
@@ -302,15 +302,19 @@ int main(int argc, char** argv) {

bool performance_test = true;
cudaProfilerStart();
// for (int threads : {256, 512}) {
// Uncomment to scan through different block size configs.
// for (int threads : {256, 512, 1024}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
// performance_test);
// }
// }
// Scan through different sizes to test performance.
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}

cudaProfilerStop();
MPICHECK(MPI_Finalize());
return EXIT_SUCCESS;
}