diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index 632b579c55afa..1ed49b8aa9cae 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -6,7 +6,6 @@
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -24,23 +23,17 @@
 
 namespace vllm {
 
-constexpr int kMaxBlocks = 36;
-// Counter may overflow, but it's fine since unsigned int overflow is
-// well-defined behavior.
-using FlagType = uint32_t;
+constexpr int kMaxBlocks = 64;
+// note: we don't want to use atomics for signals because peer atomics are not
+// supported on PCIe links
 struct Signal {
-  alignas(128) FlagType self_counter[kMaxBlocks][8];
-  // Two sets of peer counters are needed for two syncs. The reason is that
-  // it's possible for peer GPU block to arrive at the second sync point while
-  // the current GPU block haven't passed the first sync point. Thus, peer GPU
-  // may write counter+1 while current GPU is busy waiting for counter. We use
-  // alternating counter array to avoid this possibility.
-  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
+  alignas(128) uint32_t start[kMaxBlocks][8];
+  alignas(128) uint32_t end[kMaxBlocks][8];
 };
 
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 
-struct __align__(16) RankSignals { Signal* signals[8]; };
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
@@ -130,60 +123,47 @@ DINLINE O downcast(array_t val) {
   }
 }
 
-static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
-  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
-               "l"(flag_addr));
-}
-
-static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
-  FlagType flag;
-  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-  return flag;
-}
-
-static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
-  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
-}
-
-static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
-  FlagType flag;
-  asm volatile("ld.volatile.global.u32 %0, [%1];"
-               : "=r"(flag)
-               : "l"(flag_addr));
-  return flag;
+// This function is meant to be used as the first synchronization in the all
+// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
+// prior memory accesses. Note: volatile writes will not be reordered against
+// other volatile writes.
+template <int ngpus>
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->end[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
+    // wait until we get true from all ranks
+    while (!self_sg->start[blockIdx.x][threadIdx.x]);
+  }
+  __syncthreads();
 }
 
-// is_start: whether this is the very first synchronization barrier.
-// need_fence: whether a memory fence is needed. If true, a release-acquire
-// semantic is used to enforce memory access order before and after this
-// barrier.
-template <int ngpus, bool is_start, bool need_fence = false>
-DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
-                               int rank) {
-  if constexpr (!is_start) __syncthreads();
-  static_assert(
-      !(is_start && need_fence));  // Start barrier shouldn't need fence.
+// This function is meant to be used as the second or the final synchronization
+// barrier in the all reduce kernel. If it's the final synchronization barrier,
+// we don't need to make any visibility guarantees for prior memory accesses.
+template <int ngpus, bool final_sync = false>
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
+  __syncthreads();
+  // eliminate the case that prior writes are not visible after signals become
+  // visible. Note that I did not manage to make this happen through a lot of
+  // testing. Might be the case that hardware provides stronger guarantee than
+  // the memory model.
+  if constexpr (!final_sync) __threadfence_system();
   if (threadIdx.x < ngpus) {
-    // Increment the counter. Technically we only need one counter, but we use
-    // multiple per block to eliminate the need to share the counter via smem.
-    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
-    // Write the expected counter value to peer and wait for correct value from
-    // peer.
-    auto peer_counter_ptr =
-        &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
-    auto self_counter_ptr =
-        &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
-    if constexpr (need_fence) {
-      st_flag_release(peer_counter_ptr, val);
-      while (ld_flag_acquire(self_counter_ptr) != val);
-    } else {
-      st_flag_volatile(peer_counter_ptr, val);
-      while (ld_flag_volatile(self_counter_ptr) != val);
-    }
+    // reset flag for next time
+    self_sg->start[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
+    // wait until we get true from all ranks
+    while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
-  if constexpr (is_start || need_fence) __syncthreads();
+  if constexpr (!final_sync) __syncthreads();
 }
 
 template <typename P, int ngpus, typename A>
@@ -198,31 +178,33 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
   // for all ranks, ensuring bitwise identical results
   auto dp = *_dp;
-  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
     ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
   }
-  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
+  end_sync<ngpus, true>(sg, self_sg, rank);
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(Signal* sg) {
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -240,12 +222,12 @@ __global__ void __launch_bounds__(512, 1)
     tmps[i] = get_tmp_buf<P>(sg.signals[target]);
   }
   auto tmp_out = tmps[0];
-  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // stage 1: reduce scatter
   for (int idx = start + tid; idx < end; idx += stride) {
     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
   }
-  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
+  end_sync<ngpus>(sg, self_sg, rank);
 
   // stage 2: allgather. Note: it's important to match the tid between
   // the two stages, because visibility across devices is only guaranteed
@@ -455,8 +437,6 @@ class CustomAllreduce {
 #define KL(ngpus, name)                                                       \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
-  // TODO(hanzhi713): Threshold is different for A100 and H100.
-  // Add per device threshold.
 #define REDUCE_CASE(ngpus)                                                    \
   case ngpus: {                                                               \
     if (world_size_ == 2) {                                                   \
diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu
index c8b5d0a013f63..f7868233076cd 100644
--- a/csrc/custom_all_reduce_test.cu
+++ b/csrc/custom_all_reduce_test.cu
@@ -1,15 +1,15 @@
 /**
  * This is a standalone test for custom allreduce.
  * To compile, make sure you have MPI and NCCL installed in your system.
- * export MPI_HOME=xxx
+ * export MPI_HOME=XXX
  * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
- * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
+ * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
  *
  * Warning: this C++ test is not designed to be very readable and was used
  * during the rapid prototyping process.
  *
  * To run:
- * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test
+ * mpirun -np 8 ./custom_all_reduce_test
  */
 #include 
 #include 
@@ -302,19 +302,15 @@ int main(int argc, char** argv) {
   bool performance_test = true;
   cudaProfilerStart();
-  // Uncomment to scan through different block size configs.
-  // for (int threads : {256, 512, 1024}) {
+  // for (int threads : {256, 512}) {
   //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
-  //     run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
-  //               performance_test);
+  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
-  // Scan through different sizes to test performance.
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
 
   cudaProfilerStop();
-  MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }
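
Note (reviewer addition, not part of the patch): the start_sync/end_sync pair added above is a fence-then-flag handshake. Before raising its end flag in every peer's Signal, a rank publishes its prior writes with __threadfence_system() (skipped only when final_sync is true, since no peer reads data written after the last barrier), then spins on the volatile flags written by its peers. The sketch below is a minimal single-GPU illustration of that ordering pattern; the kernel name fence_then_flag_demo, the 42 payload, and the use of two blocks as stand-ins for two ranks are invented for the example, and a plain __threadfence() replaces __threadfence_system() because no peer device is involved.

#include <cstdio>
#include <cuda_runtime.h>

// Two blocks stand in for two ranks: block 0 produces, block 1 consumes.
__global__ void fence_then_flag_demo(volatile int* flag, int* payload,
                                     int* out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *payload = 42;       // write the data first
    __threadfence();     // make the data visible before the flag
    *flag = 1;           // then raise the flag (volatile, not cached)
  } else if (blockIdx.x == 1 && threadIdx.x == 0) {
    while (*flag == 0);  // spin until the producer raises the flag
    __threadfence();     // keep the payload read after the flag read
    *out = *payload;     // safe to read the payload now
  }
}

int main() {
  int *flag, *payload, *out;
  cudaMalloc(&flag, sizeof(int));
  cudaMalloc(&payload, sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));
  fence_then_flag_demo<<<2, 32>>>(flag, payload, out);
  cudaDeviceSynchronize();
  int host_out = 0;
  cudaMemcpy(&host_out, out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("consumer read %d\n", host_out);  // expect 42
  cudaFree(flag);
  cudaFree(payload);
  cudaFree(out);
  return 0;
}

Built with something like nvcc -O2 -arch=native fence_then_flag_demo.cu, the consumer block should print 42. Dropping the producer-side fence removes the guarantee that the payload is visible before the flag, which is exactly the hazard the comment inside end_sync is guarding against.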