From 3cc29ca06405d1f954da4ee9a6b78fcdd1ef1241 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Mon, 18 Dec 2023 05:00:45 +0000 Subject: [PATCH 01/53] a working impl & squash commits --- csrc/fast_allreduce.cu | 107 ++++ csrc/fast_allreduce.cuh | 504 ++++++++++++++++++ csrc/fast_allreduce_test.cu | 241 +++++++++ csrc/ops.h | 15 + csrc/pybind.cpp | 11 + requirements.txt | 1 + setup.py | 1 + tests/kernels/test_fast_ar.py | 67 +++ .../parallel_utils/communication_op.py | 40 +- .../parallel_utils/fast_allreduce.py | 93 ++++ vllm/worker/model_runner.py | 8 +- 11 files changed, 1084 insertions(+), 4 deletions(-) create mode 100644 csrc/fast_allreduce.cu create mode 100644 csrc/fast_allreduce.cuh create mode 100644 csrc/fast_allreduce_test.cu create mode 100644 tests/kernels/test_fast_ar.py create mode 100644 vllm/model_executor/parallel_utils/fast_allreduce.py diff --git a/csrc/fast_allreduce.cu b/csrc/fast_allreduce.cu new file mode 100644 index 0000000000000..ab3423fd7e4ff --- /dev/null +++ b/csrc/fast_allreduce.cu @@ -0,0 +1,107 @@ +#include +#include +#include +#include + +#include "fast_allreduce.cuh" + +// fake pointer type +using fptr_t = uint64_t; +static_assert(sizeof(void *) == sizeof(fptr_t)); + +fptr_t prepare_buffer(fptr_t ptr, const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink) { + int world_size = offsets.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + cudaIpcMemHandle_t ipc_handles[8]; + for (int i = 0; i < world_size; i++) { + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + } + return (fptr_t) new vllm::FastAllreduce( + reinterpret_cast(ptr), ipc_handles, offsets, rank, + full_nvlink); +} + +void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { + auto fa = reinterpret_cast(_fa); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + switch (inp.scalar_type()) { + case at::ScalarType::Float: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + inp.numel()); + break; + } + case at::ScalarType::Half: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + inp.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), inp.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "Fast allreduce only supports float32, float16 and bfloat16"); + } +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +int meta_size() { return sizeof(vllm::Metadata); } + +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); +} + +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto sz = fa->graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::vector 
handles(handle_sz * sz, 0); + std::vector offsets(sz); + for (int i = 0; i < sz; i++) { + auto ptr = fa->graph_unreg_buffers_[i]; + void *base_ptr; + // note: must share the base address of each allocation, or we get wrong address + auto _err = cuPointerGetAttribute( + &base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr); + if (_err != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&handles[i * handle_sz], + base_ptr)); + offsets[i] = ((char *)ptr) - ((char *)base_ptr); + } + return std::make_pair(handles, offsets); +} + +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); +} diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh new file mode 100644 index 0000000000000..4bbfae8686c47 --- /dev/null +++ b/csrc/fast_allreduce.cuh @@ -0,0 +1,504 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +struct Signal { + alignas(64) union { + uint64_t flag; + unsigned char data[8]; + } start; + alignas(64) union { + uint64_t flag; + unsigned char data[8]; + } end; +}; + +struct Metadata { + alignas(128) Signal sg; + alignas(128) int counter; +}; +static_assert(offsetof(Metadata, counter) == 128); +static_assert(sizeof(Metadata) == 256); + +struct __align__(16) RankData { void *__restrict__ ptrs[8]; }; + +struct RankSignals { + volatile Signal *signals[8]; +}; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half &assign_add(half &a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float &assign_add(float &a, float b) { return a += b; } + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O 
downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// compute flag at compile time +__host__ __device__ constexpr uint64_t compute_flag(int ngpus) { + auto m = std::numeric_limits::max(); + return m >> ((8 - ngpus) * 8); +} + +template +__device__ __forceinline__ void start_sync(const RankSignals &sg, + volatile Metadata *meta, int rank) { + constexpr auto FLAG = compute_flag(ngpus); + if (blockIdx.x == 0) { + if (threadIdx.x < ngpus) + // simultaneously write to the corresponding byte to all other ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start.data[rank] = 255; + else if (threadIdx.x == 32) + // reset + meta->sg.end.flag = 0; + } + if (threadIdx.x == 0) { + while (meta->sg.start.flag != FLAG) + ; + } + __syncthreads(); +} + +template +__device__ __forceinline__ void end_sync(const RankSignals &sg, + volatile Metadata *meta, int rank) { + constexpr auto FLAG = compute_flag(ngpus); + __syncthreads(); + __shared__ int num; + if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1); + __syncthreads(); + + // Only the last completing block can perform the end synchronization + // This can ensures when the final busy wait ends, all ranks must have + // finished reading each other's buffer, and the kernel can exit. + if (num == gridDim.x - 1) { + if (threadIdx.x == 32) { + // reset in a different warp + meta->counter = 0; + meta->sg.start.flag = 0; + } else if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding byte to all other ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end.data[rank] = 255; + } + if constexpr (final_sync) { + if (threadIdx.x == 0) { + while (meta->sg.end.flag != FLAG) + ; + } + } + } + if constexpr (!final_sync) { + if (threadIdx.x == 0) { + while (meta->sg.end.flag != FLAG) + ; + } + __syncthreads(); + } +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData *_dp, RankSignals sg, + volatile Metadata *meta, T *__restrict__ result, + int rank, int size) { + auto dp = *_dp; + start_sync(sg, meta, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + using P = typename packed_t::P; + using A = typename packed_t::A; + A tmp = upcast(((P *)dp.ptrs[0])[idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(((P *)dp.ptrs[i])[idx])); + } + ((P *)result)[idx] = downcast
<P>
(tmp); + } + end_sync(sg, meta, rank); +} + +template +DINLINE P *get_tmp_buf(volatile Signal *sg) { + return (P *)(((Metadata *)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData *_dp, RankSignals sg, + volatile Metadata *meta, T *__restrict__ result, + int rank, int size) { + auto dp = *_dp; + start_sync(sg, meta, rank); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + P *ptrs[ngpus]; + P *tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (P *)dp.ptrs[target]; + tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); + } + auto tmp_out = tmps[0]; + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + tmp_out[idx - start] = downcast
<P>
(tmp); + } + end_sync(sg, meta, rank); + + // stage 2: allgather + for (int idx = tid; idx < part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int dst_idx = i * part + idx; + ((P *)result)[dst_idx] = tmps[i][idx]; + } + } + // process the last larger partition + int remaining = size - part * ngpus; + if (tid < remaining) { + int dst_idx = tid + part * ngpus; + ((P *)result)[dst_idx] = get_tmp_buf
<P>
(sg.signals[ngpus - 1])[part + tid]; + } + + // faster than this + // for (int idx = tid; idx < size; idx += stride) { + // int target_rank = idx / part; + // if (target_rank == ngpus) target_rank -= 1; + // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part]; + // } +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg, + volatile Metadata *meta, + T *__restrict__ result, int rank, + int size) { + auto dp = *_dp; + start_sync(sg, meta, rank); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + // do the actual reduction + auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); + constexpr int hg = ngpus / 2; + P *ptrs[hg]; + { + int start = rank - rank % hg; +#pragma unroll + for (int i = 0; i < hg; i++) { + ptrs[i] = (P *)dp.ptrs[i + start]; + } + } + for (int idx = tid; idx < size; idx += stride) { + A tmp = {0.0f, 0.0f}; +#pragma unroll + for (int i = 0; i < hg; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + tmp_out[idx] = downcast
<P>
(tmp); + } + + end_sync(sg, meta, rank); + + auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); + // do the actual reduction + for (int idx = tid; idx < size; idx += stride) { + auto tmp = tmp_out[idx]; + packed_assign_add(tmp, src[idx]); + ((P *)result)[idx] = tmp; + } +} +class FastAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + // below are device pointers + RankSignals sg_; + std::unordered_map buffers_; + Metadata *meta_; + + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + std::vector ipc_handles_; + + /** + * meta is a pointer to device metadata and acutual data. + * + * There's a total of sizeof(Metadata) of prefix before the actual data, + * so meta + 1 points to actual allreduce buffer. + */ + FastAllreduce(Metadata *meta, const cudaIpcMemHandle_t *handles, + const std::vector &offsets, int rank, + bool full_nvlink = true) + : rank_(rank), + world_size_(offsets.size()), + meta_(meta), + full_nvlink_(full_nvlink) { + for (int i = 0; i < world_size_; i++) { + Metadata *rank_meta; + if (i != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[i]; + rank_meta = (Metadata *)handle; + } else { + rank_meta = meta_; + } + sg_.signals[i] = &rank_meta->sg; + } + size_t rank_data_sz = 16 * 1024 * 1024; + CUDACHECK(cudaMalloc(&d_rank_data_base_, rank_data_sz)); + d_rank_data_end_ = d_rank_data_base_ + rank_data_sz / sizeof(RankData); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + void register_buffer(const std::vector &handles, + const std::vector &offsets, void *self) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + if (i != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle( + (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()), + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[self] = d_data; + } + + void register_graph_buffers( + const std::vector &handles, + const std::vector> &offsets) { + auto sz = graph_unreg_buffers_.size(); + check_rank_data_capacity(sz); + for (int i = 0; i < sz; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + RankData rd; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle( + (void **)&handle, + *((cudaIpcMemHandle_t *)&handles[j] + [i * sizeof(cudaIpcMemHandle_t)]), + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_++, &rd, sizeof(RankData), + cudaMemcpyHostToDevice)); + } + graph_unreg_buffers_.clear(); + } + + // note: 512, 36 is good for most cases + template + void allreduce(cudaStream_t stream, T *input, T *output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "fast allreduce currently requires input length to be multiple of " + + std::to_string(d)); + + RankData *ptrs; + cudaStreamCaptureStatus status; + 
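    // During CUDA graph capture the input buffer has not been IPC-registered
    // yet, so reserve the next RankData slot and record the pointer in
    // graph_unreg_buffers_; register_graph_buffers() fills that slot in after
    // capture completes. Outside of capture, the buffer must already have been
    // registered via register_buffer().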
CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name \ + <<>>(ptrs, sg_, meta_, output, rank_, size); +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } else { \ + KL(ngpus, cross_device_reduce_half_butterfly); \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "Fast allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~FastAllreduce() { + for (auto ptr : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; + +} // namespace vllm diff --git a/csrc/fast_allreduce_test.cu b/csrc/fast_allreduce_test.cu new file mode 100644 index 0000000000000..79d6f78ea99af --- /dev/null +++ b/csrc/fast_allreduce_test.cu @@ -0,0 +1,241 @@ +#include +#include +#include +#include + +#include +#include + +#include "cuda_profiler_api.h" +#include "fast_allreduce.cuh" +#include "mpi.h" +#include "nccl.h" + +#define MPICHECK(cmd) \ + do { \ + int e = cmd; \ + if (e != MPI_SUCCESS) { \ + printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ + ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +__global__ void dummy_kernel() { + for (int i = 0; i < 500; i++) __nanosleep(1000000); // 500ms +} + +template +__global__ void set_data(T *data, int size, int myRank) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + data[idx] = myRank * 0.11f; + } +} + +template +__global__ void convert_data(const T *data1, const T *data2, double *fdata1, + double *fdata2, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + fdata1[idx] = data1[idx]; + fdata2[idx] = data2[idx]; + } +} + +__global__ void init_rand(curandState_t *state, int size, int nRanks) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + for (int i = 0; i < nRanks; i++) { + curand_init(i + 1, idx, 0, &state[idx * nRanks + i]); + } + } +} + +template +__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, + int myRank, int nRanks, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + double sum = 0.0; + for (int i = 0; i < nRanks; i++) { + double val = curand_uniform_double(&state[idx * nRanks + i]) * 4; + T hval = val; // downcast first + sum 
+= static_cast(hval); + if (i == myRank) data[idx] = hval; + } + ground_truth[idx] = sum; + } +} + +template +void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, + int data_size) { + T *result; + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); + CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); + + cudaIpcMemHandle_t self_data_handle; + cudaIpcMemHandle_t data_handles[8]; + vllm::Metadata *buffer; + T *self_data_copy; + CUDACHECK( + cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); + CUDACHECK(cudaMemset(buffer, 0, + 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); + CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); + CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); + + MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, MPI_COMM_WORLD)); + + std::vector offsets(nRanks, 0); + vllm::FastAllreduce fa(buffer, data_handles, offsets, myRank); + auto *self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Metadata) + data_size * sizeof(T)); + // hack buffer registration + { + std::vector handles; + handles.reserve(nRanks); + for (int i = 0; i < nRanks; i++) { + char *begin = (char *)&data_handles[i]; + char *end = (char *)&data_handles[i + 1]; + handles.emplace_back(begin, end); + } + std::vector offsets( + nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T)); + fa.register_buffer(handles, offsets, self_data); + } + + double *verification_buffer; + CUDACHECK(cudaMallocHost(&verification_buffer, data_size * sizeof(double))); + curandState_t *states; + CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); + init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); + gen_data<<<108, 1024, 0, stream>>>(states, self_data, verification_buffer, + myRank, nRanks, data_size); + CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + cudaEvent_t start, stop; + CUDACHECK(cudaEventCreate(&start)); + CUDACHECK(cudaEventCreate(&stop)); + + ncclDataType_t ncclDtype; + if (std::is_same::value) { + ncclDtype = ncclFloat16; + } else if (std::is_same::value) { + ncclDtype = ncclBfloat16; + } else { + ncclDtype = ncclFloat; + } + + dummy_kernel<<<1, 1, 0, stream>>>(); + constexpr int warmup_iters = 5; + constexpr int num_iters = 25; + // warmup + for (int i = 0; i < warmup_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, + stream)); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, + stream)); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + float allreduce_ms = 0; + cudaEventElapsedTime(&allreduce_ms, start, stop); + + // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>(); + // set_data<<<16, 1024, 0, stream>>>(self_data, data_size, myRank); + + dummy_kernel<<<1, 1, 0, stream>>>(); + // warm up + for (int i = 0; i < warmup_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, block_limit); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, block_limit); + } + CUDACHECK(cudaEventRecord(stop, stream)); + 
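  // Ensure the stop event has actually been recorded before reading the
  // elapsed time of the fast allreduce iterations below.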
CUDACHECK(cudaStreamSynchronize(stream)); + + float duration_ms = 0; + cudaEventElapsedTime(&duration_ms, start, stop); + if (myRank == 0) + printf( + "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " + "time:%.2fus\n", + myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, + duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); + + // And wait for all the queued up work to complete + CUDACHECK(cudaStreamSynchronize(stream)); + + NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, + ncclSum, comm, stream)); + + double *nccl_result, *my_result; + CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); + CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); + + convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, + my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + long double nccl_diffs = 0.0; + long double my_diffs = 0.0; + for (int j = 0; j < data_size; j++) { + nccl_diffs += abs(nccl_result[j] - verification_buffer[j]); + my_diffs += abs(my_result[j] - verification_buffer[j]); + } + if (myRank == 0) + std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + << " me: " << my_diffs / data_size << std::endl; + + // cudaFree(result); + // CUDACHECK(cudaStreamDestroy(stream)); + CUDACHECK(cudaFree(states)); +} + +int main(int argc, char **argv) { + int nRanks, myRank; + MPICHECK(MPI_Init(&argc, &argv)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); + MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); + CUDACHECK(cudaSetDevice(myRank)); + ncclUniqueId id; + ncclComm_t comm; + if (myRank == 0) ncclGetUniqueId(&id); + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPI_COMM_WORLD)); + NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + + cudaProfilerStart(); + // for (int threads : {256, 512}) { + // for (int block_limit = 16; block_limit < 112; block_limit += 4) { + // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // } + // } + for (int sz = 512; sz <= (4 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 36, sz); + } + + cudaProfilerStop(); + return EXIT_SUCCESS; +} diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da1417..c168c4a73b3c4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -89,3 +89,18 @@ torch::Tensor gptq_gemm( void gptq_shuffle( torch::Tensor q_weight, torch::Tensor q_perm); + + +using fptr_t = uint64_t; +fptr_t prepare_buffer(fptr_t ptr, const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink); +void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void dispose(fptr_t _fa); +int meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f337..c46026e05ee54 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -81,4 +81,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "get_device_attribute", &get_device_attribute, "Gets the specified device attribute."); + + pybind11::module fast_ar = m.def_submodule("fast_ar", "fast allreduce"); + fast_ar.def("prepare_buffer", &prepare_buffer, "prepare_buffer"); + fast_ar.def("allreduce", &allreduce, "allreduce"); + fast_ar.def("dispose", &dispose, "dispose"); + fast_ar.def("meta_size", &meta_size, "meta_size"); + 
fast_ar.def("register_buffer", ®ister_buffer, "register_buffer"); + fast_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, + "get_graph_buffer_ipc_meta"); + fast_ar.def("register_graph_buffers", ®ister_graph_buffers, + "register_graph_buffers"); } diff --git a/requirements.txt b/requirements.txt index 92ba0a716c45c..aeeeaaf303d1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. aioprometheus[starlette] +pynvml == 11.5.0 diff --git a/setup.py b/setup.py index 45a18776798fb..054622d22a988 100644 --- a/setup.py +++ b/setup.py @@ -220,6 +220,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/layernorm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/cuda_utils_kernels.cu", + "csrc/fast_allreduce.cu", "csrc/pybind.cpp", ] diff --git a/tests/kernels/test_fast_ar.py b/tests/kernels/test_fast_ar.py new file mode 100644 index 0000000000000..1d2edf36ab43e --- /dev/null +++ b/tests/kernels/test_fast_ar.py @@ -0,0 +1,67 @@ +""" +Run this test like this: +torchrun --standalone --nnodes=1 --nproc-per-node=4 tests/kernels/test_fast_ar.py +""" +import torch +import os +import random +import torch.distributed as dist +from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce +from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel + +os.environ["MAX_JOBS"] = "16" +rank = int(os.environ["RANK"]) +local_rank = int(os.environ["LOCAL_RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) +torch.cuda.set_device(local_rank) +initialize_model_parallel(world_size) + +test_count = 8 +if rank == 0: + test_sizes = [random.randint(1024, 2048*1024) for i in range(test_count)] + for i, v in enumerate(test_sizes): + test_sizes[i] -= test_sizes[i] % 8 +else: + test_sizes = [0] * test_count +dist.broadcast_object_list(test_sizes, src=0) + + +def test_fast_ar(sz: int, dtype): + fa = FastAllreduce(rank, world_size) + # use integers so result matches NCCL exactly + inp1 = torch.ones(sz, dtype=dtype, device=torch.cuda.current_device()) * random.randint(1, 32) + inp2 = torch.ones(sz, dtype=dtype, device=torch.cuda.current_device()) * random.randint(1, 32) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + out1 = fa.all_reduce(inp1) + out2 = fa.all_reduce(inp2) + # the input buffer is immediately modified to test synchronization + dist.all_reduce(inp1) + dist.all_reduce(inp2) + fa.register_graph_buffers() + graph.replay() + torch.cuda.synchronize() + + assert torch.allclose(out1, inp1) + assert torch.allclose(out2, inp2) + if rank == 0: + print("passed", sz, dtype) + + +def test_manual_registration(): + sz = 1024 + fa = FastAllreduce(rank, world_size) + inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) + fa.register_buffer(inp) + out = fa.all_reduce(inp) + assert torch.allclose(out, inp * world_size) + + +if __name__ == "__main__": + print(test_sizes) + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for sz in test_sizes: + test_fast_ar(sz, dtype) + test_manual_registration() diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index b1d5f5b9fb88e..01cf52251af43 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -3,10 +3,33 @@ from 
vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, + get_tensor_model_parallel_rank ) +from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce -def tensor_model_parallel_all_reduce(input_): +fa_handle = None +is_capturing = False + + +def init_fast_ar(): + global fa_handle + world_size = get_tensor_model_parallel_world_size() + if world_size > 1: + fa_handle = FastAllreduce(get_tensor_model_parallel_rank(), world_size) + + +def begin_capture(): + global is_capturing + is_capturing = True + + +def end_capture(): + global is_capturing + is_capturing = False + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor): """All-reduce the input tensor across model parallel group. NOTE: This operation is applied in-place on the input tensor. @@ -14,9 +37,20 @@ def tensor_model_parallel_all_reduce(input_): # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ - # All-reduce. + # fast allreduce only works with IPC pre-registered buffer. + # This is only handled when captured with cuda graph + if is_capturing and fa_handle is not None: + if torch.cuda.is_current_stream_capturing(): + if fa_handle.should_fast_ar(input_): + return fa_handle.all_reduce(input_) + else: + if fa_handle.should_fast_ar(input_): + # if warm up, mimic the allocation pattern + # since fast allreduce is out-of-place + return torch.empty_like(input_) + torch.distributed.all_reduce(input_, - group=get_tensor_model_parallel_group()) + group=get_tensor_model_parallel_group()) return input_ diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py new file mode 100644 index 0000000000000..c2afcc721ea4a --- /dev/null +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -0,0 +1,93 @@ +import os +import torch +import torch.distributed as dist +import pynvml +from vllm.logger import init_logger +from vllm._C import fast_ar + +logger = init_logger(__name__) + + +# query if the set of gpus are fully connected by nvlink (1 hop) +def full_nvlink(rank, world_size): + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(rank) + for i in range(world_size): + if i != rank: + try: + link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) + if not link_state: + return False + except pynvml.NVMLError as error: + logger.info( + f"NVLink detection failed with message \"{str(error)}\". 
" + "This is normal if your machine has no NVLink equipped") + return False + pynvml.nvmlShutdown() + return True + + +class FastAllreduce: + + def __init__(self, rank, world_size, max_size=8192*1024) -> None: + self.meta = torch.zeros(fast_ar.meta_size() + max_size, dtype=torch.uint8, device=rank) + self.max_size = max_size + self.world_size = world_size + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = full_nvlink(rank, world_size) + self._ptr = fast_ar.prepare_buffer(self.meta.data_ptr(), handles, + offsets, rank, + self.full_nvlink) + self.fast_cond = self.full_nvlink or world_size <= 2 + self.is_capturing = False + + def _get_ipc_meta(self, inp: torch.Tensor): + data = inp.storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + all_data = [None] * self.world_size + dist.all_gather_object(all_data, shard_data) + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0]) + offsets.append(all_data[i][1]) + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + fast_ar.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = fast_ar.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + fast_ar.register_graph_buffers(self._ptr, handles, offsets) + + def should_fast_ar(self, inp: torch.Tensor): + inp_size = inp.numel() * torch.finfo(inp.dtype).bits // 8 + if self.fast_cond: + return inp_size <= self.max_size + # 4 pcie gpus use 2 stage AR, and is only faster than NCCL + # when size <= 512k + return self.world_size <= 4 and inp_size <= 512 * 1024 + + def all_reduce(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + fast_ar.allreduce(self._ptr, inp, out) + return out + + def close(self): + if self._ptr: + fast_ar.dispose(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 276ef0708847a..38461c2c5e8e9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,6 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +import vllm.model_executor.parallel_utils.communication_op as comm_op logger = init_logger(__name__) @@ -406,7 +407,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - + + comm_op.init_fast_ar() + comm_op.begin_capture() # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): @@ -430,6 +433,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner + comm_op.end_capture() + if comm_op.fa_handle is not None: + comm_op.fa_handle.register_graph_buffers() end_time = time.perf_counter() elapsed_time = end_time - start_time From 89f8b974c8499a053d204a6135c16e087f1968da Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 05:04:02 +0000 Subject: [PATCH 02/53] add missing cuda free --- csrc/fast_allreduce.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index 4bbfae8686c47..f9a782444c3af 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -335,7 +335,7 @@ class FastAllreduce { std::unordered_map buffers_; Metadata *meta_; - RankData *d_rank_data_base_, *d_rank_data_end_; + RankData *d_rank_data_start_, *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; std::vector ipc_handles_; @@ -367,7 +367,8 @@ class FastAllreduce { sg_.signals[i] = &rank_meta->sg; } size_t rank_data_sz = 16 * 1024 * 1024; - CUDACHECK(cudaMalloc(&d_rank_data_base_, rank_data_sz)); + CUDACHECK(cudaMalloc(&d_rank_data_start_, rank_data_sz)); + d_rank_data_base_ = d_rank_data_start_; d_rank_data_end_ = d_rank_data_base_ + rank_data_sz / sizeof(RankData); } @@ -498,6 +499,7 @@ class FastAllreduce { for (auto ptr : ipc_handles_) { CUDACHECK(cudaIpcCloseMemHandle(ptr)); } + CUDACHECK(cudaFree(d_rank_data_start_)); } }; From 53ed0f9d807b385ab5e2764171aedb8f615a2e1d Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 05:07:51 +0000 Subject: [PATCH 03/53] link driver --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 054622d22a988..beddbe5df2536 100644 --- a/setup.py +++ b/setup.py @@ -235,6 +235,7 @@ def get_torch_arch_list() -> Set[str]: "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, + libraries=["cuda"] ) ext_modules.append(vllm_extension) From 39000b828cea4c9d3d2a4dac388fb3b1e8c26108 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 05:44:13 +0000 Subject: [PATCH 04/53] add more notes --- csrc/fast_allreduce.cuh | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index f9a782444c3af..e503d781afb09 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -169,7 +169,7 @@ __device__ __forceinline__ void end_sync(const RankSignals &sg, // Only the last completing block can perform the end synchronization // This can ensures when the final busy wait ends, all ranks must have - // finished reading each other's buffer, and the kernel can exit. + // finished reading each other's buffer. if (num == gridDim.x - 1) { if (threadIdx.x == 32) { // reset in a different warp @@ -180,6 +180,8 @@ __device__ __forceinline__ void end_sync(const RankSignals &sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end.data[rank] = 255; } + // if this is the final sync, only one block needs it + // because kernel exit can serve as sync if constexpr (final_sync) { if (threadIdx.x == 0) { while (meta->sg.end.flag != FLAG) @@ -294,9 +296,12 @@ __global__ void __launch_bounds__(512, 1) int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; using A = typename packed_t::A; - // do the actual reduction auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); constexpr int hg = ngpus / 2; + // Actually not quite half butterfly. + // This is an all-to-all within each group containing half of the ranks + // followed by cross-group add. Equivalent to half butterfly when there + // are 4 GPUs, a common case for PCIe cards like T4 and A10. P *ptrs[hg]; { int start = rank - rank % hg; @@ -340,10 +345,10 @@ class FastAllreduce { std::vector ipc_handles_; /** - * meta is a pointer to device metadata and acutual data. + * meta is a pointer to device metadata and temporary buffer for allreduce. * * There's a total of sizeof(Metadata) of prefix before the actual data, - * so meta + 1 points to actual allreduce buffer. + * so meta + 1 points to actual temporary buffer. */ FastAllreduce(Metadata *meta, const cudaIpcMemHandle_t *handles, const std::vector &offsets, int rank, From f4dc283ef308ac43425f7339ae830d96bd9c5624 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 05:46:37 +0000 Subject: [PATCH 05/53] add todo --- csrc/fast_allreduce.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index e503d781afb09..2903d76e48a51 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -258,6 +258,8 @@ __global__ void __launch_bounds__(512, 1) } tmp_out[idx - start] = downcast
<P>
(tmp); } + // Maybe TODO: replace this with per-block release-acquire + // can save about 1-2us (not a lot though) end_sync(sg, meta, rank); // stage 2: allgather From 2f494545654f78d6fa3a02092459726ae05797c9 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 23:43:13 +0000 Subject: [PATCH 06/53] add flag and format --- setup.py | 16 ++++++-------- tests/kernels/test_fast_ar.py | 18 ++++++++++----- vllm/config.py | 22 ++++++++++++------- vllm/engine/arg_utils.py | 5 +++++ vllm/engine/llm_engine.py | 1 + vllm/entrypoints/llm.py | 3 +++ .../parallel_utils/communication_op.py | 9 +++----- .../parallel_utils/fast_allreduce.py | 12 +++++----- vllm/worker/model_runner.py | 5 +++-- 9 files changed, 54 insertions(+), 37 deletions(-) diff --git a/setup.py b/setup.py index beddbe5df2536..56d195b28d916 100644 --- a/setup.py +++ b/setup.py @@ -228,15 +228,13 @@ def get_torch_arch_list() -> Set[str]: vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu") -vllm_extension = CUDAExtension( - name="vllm._C", - sources=vllm_extension_sources, - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, - libraries=["cuda"] -) +vllm_extension = CUDAExtension(name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + libraries=["cuda"]) ext_modules.append(vllm_extension) diff --git a/tests/kernels/test_fast_ar.py b/tests/kernels/test_fast_ar.py index 1d2edf36ab43e..5a90936c6f500 100644 --- a/tests/kernels/test_fast_ar.py +++ b/tests/kernels/test_fast_ar.py @@ -19,19 +19,23 @@ test_count = 8 if rank == 0: - test_sizes = [random.randint(1024, 2048*1024) for i in range(test_count)] + test_sizes = [random.randint(1024, 2048 * 1024) for i in range(test_count)] for i, v in enumerate(test_sizes): - test_sizes[i] -= test_sizes[i] % 8 + test_sizes[i] -= v % 8 else: test_sizes = [0] * test_count dist.broadcast_object_list(test_sizes, src=0) -def test_fast_ar(sz: int, dtype): +def test_fast_ar(sz: int, dtype): fa = FastAllreduce(rank, world_size) # use integers so result matches NCCL exactly - inp1 = torch.ones(sz, dtype=dtype, device=torch.cuda.current_device()) * random.randint(1, 32) - inp2 = torch.ones(sz, dtype=dtype, device=torch.cuda.current_device()) * random.randint(1, 32) + inp1 = torch.ones(sz, dtype=dtype, + device=torch.cuda.current_device()) * random.randint( + 1, 32) + inp2 = torch.ones(sz, dtype=dtype, + device=torch.cuda.current_device()) * random.randint( + 1, 32) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): @@ -53,7 +57,9 @@ def test_fast_ar(sz: int, dtype): def test_manual_registration(): sz = 1024 fa = FastAllreduce(rank, world_size) - inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) + inp = torch.ones(sz, + dtype=torch.float32, + device=torch.cuda.current_device()) fa.register_buffer(inp) out = fa.all_reduce(inp) assert torch.allclose(out, inp * world_size) diff --git a/vllm/config.py b/vllm/config.py index 353189f6e3381..5486f0b07f708 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -335,20 +335,26 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. + disable_fast_allreduce: Only applicable if enforce_eage is False and using tensor parallelism. 
Whether to disable fast allreduce path """ - def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: bool, - max_parallel_loading_workers: Optional[int] = None, - ) -> None: + def __init__(self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, + disable_fast_allreduce=False) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers - + self.disable_fast_allreduce = disable_fast_allreduce + if not disable_fast_allreduce and (is_hip() + or pipeline_parallel_size > 1): + self.disable_fast_allreduce = False + logger.info( + "Fast allreduce automatically disabled. Not supported on HIP and pipeline parallel" + ) self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: self.worker_use_ray = True diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7a571ceefbc85..aabcadfbc787c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,6 +35,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + disable_fast_allreduce = False def __post_init__(self): if self.tokenizer is None: @@ -200,6 +201,10 @@ def add_cli_args( help='maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') + parser.add_argument('--disable_fast_allreduce', + type=int, + default=EngineArgs.disable_fast_allreduce, + help='See ParallelConfig') return parser @classmethod diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d91ab1430735c..8d24d6306d094 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -83,6 +83,7 @@ def __init__( f"download_dir={model_config.download_dir!r}, " f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"disable_fast_allreduce={parallel_config.disable_fast_allreduce}, " f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed})") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0700298b03a3d..f28c66f3e9d9b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -63,6 +63,7 @@ class LLM: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. 
+ disable_fast_allreduce: See ParallelConfig """ def __init__( @@ -81,6 +82,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, + disable_fast_allreduce=False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -100,6 +102,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + disable_fast_allreduce=disable_fast_allreduce, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 01cf52251af43..644df5be46060 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,13 +1,10 @@ import torch from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank -) + get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, + get_tensor_model_parallel_rank) from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce - fa_handle = None is_capturing = False @@ -50,7 +47,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor): return torch.empty_like(input_) torch.distributed.all_reduce(input_, - group=get_tensor_model_parallel_group()) + group=get_tensor_model_parallel_group()) return input_ diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index c2afcc721ea4a..3bf703f4279e3 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -1,4 +1,3 @@ -import os import torch import torch.distributed as dist import pynvml @@ -29,15 +28,16 @@ def full_nvlink(rank, world_size): class FastAllreduce: - def __init__(self, rank, world_size, max_size=8192*1024) -> None: - self.meta = torch.zeros(fast_ar.meta_size() + max_size, dtype=torch.uint8, device=rank) + def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: + self.meta = torch.zeros(fast_ar.meta_size() + max_size, + dtype=torch.uint8, + device=rank) self.max_size = max_size self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink(rank, world_size) self._ptr = fast_ar.prepare_buffer(self.meta.data_ptr(), handles, - offsets, rank, - self.full_nvlink) + offsets, rank, self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 self.is_capturing = False @@ -48,7 +48,7 @@ def _get_ipc_meta(self, inp: torch.Tensor): data[3], # offset of base ptr ) return self._gather_ipc_meta(shard_data) - + def _gather_ipc_meta(self, shard_data): all_data = [None] * self.world_size dist.all_gather_object(all_data, shard_data) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 38461c2c5e8e9..d1eddb3a7c35c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -407,8 +407,9 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - - comm_op.init_fast_ar() + + if not self.model_config.disable_fast_allreduce: + comm_op.init_fast_ar() comm_op.begin_capture() # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
From 16447e54f334754c74b06e32de8d9289efef0a6b Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 23:45:57 +0000 Subject: [PATCH 07/53] fix --- vllm/engine/arg_utils.py | 2 +- vllm/worker/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aabcadfbc787c..48c3a5fdcc8fe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,7 +35,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 - disable_fast_allreduce = False + disable_fast_allreduce: bool = False def __post_init__(self): if self.tokenizer is None: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d1eddb3a7c35c..1a3ff0f1932ec 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -408,7 +408,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - if not self.model_config.disable_fast_allreduce: + if not self.parallel_config.disable_fast_allreduce: comm_op.init_fast_ar() comm_op.begin_capture() # NOTE: Capturing the largest batch size first may help reduce the From 60a51f232859dc9fbcdacee8d4ffc352cd6e41e6 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 19 Dec 2023 23:54:46 +0000 Subject: [PATCH 08/53] fix arg passing --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 48c3a5fdcc8fe..116570b4707a0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -232,7 +232,8 @@ def create_engine_configs( parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, - self.max_parallel_loading_workers) + self.max_parallel_loading_workers, + self.disable_fast_allreduce) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, From 2150a90efdce0e9cd4ce1ef633b9385249d8eb2e Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 20 Dec 2023 03:42:17 +0000 Subject: [PATCH 09/53] trailing comma --- vllm/config.py | 14 ++++++++------ vllm/entrypoints/llm.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3586099fede47..52111f502bf39 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -334,12 +334,14 @@ class ParallelConfig: disable_fast_allreduce: Only applicable if enforce_eage is False and using tensor parallelism. 
Whether to disable fast allreduce path """ - def __init__(self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: bool, - max_parallel_loading_workers: Optional[int] = None, - disable_fast_allreduce=False) -> None: + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, + disable_fast_allreduce: bool = False, + ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f28c66f3e9d9b..f017e03ff9f5c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -82,7 +82,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_fast_allreduce=False, + disable_fast_allreduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: From cd898ba3def4b74c43f561955323c8e8a5f1eee3 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 20 Dec 2023 04:32:27 +0000 Subject: [PATCH 10/53] use pytest for fast allreduce test --- setup.py | 16 +++--- tests/distributed/comm_utils.py | 36 ++++++++++++ tests/distributed/test_comm_ops.py | 30 +--------- tests/distributed/test_fast_allreduce.py | 63 ++++++++++++++++++++ tests/kernels/test_fast_ar.py | 73 ------------------------ 5 files changed, 110 insertions(+), 108 deletions(-) create mode 100644 tests/distributed/comm_utils.py create mode 100644 tests/distributed/test_fast_allreduce.py delete mode 100644 tests/kernels/test_fast_ar.py diff --git a/setup.py b/setup.py index 4a53d03db6316..3f99642613b08 100644 --- a/setup.py +++ b/setup.py @@ -228,13 +228,15 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") -vllm_extension = CUDAExtension(name="vllm._C", - sources=vllm_extension_sources, - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, - libraries=["cuda"]) +vllm_extension = CUDAExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + libraries=["cuda"], +) ext_modules.append(vllm_extension) diff --git a/tests/distributed/comm_utils.py b/tests/distributed/comm_utils.py new file mode 100644 index 0000000000000..d9ed67c5aafa2 --- /dev/null +++ b/tests/distributed/comm_utils.py @@ -0,0 +1,36 @@ +"""Test the communication operators. + +Run `pytest tests/distributed/test_comm_ops.py --forked`. 
+""" +from multiprocessing import Process, set_start_method +import torch + +from vllm.config import ParallelConfig +from vllm.engine.ray_utils import get_open_port +from vllm.worker.worker import _init_distributed_environment + + +def init_test_distributed_environment(pipeline_parallel_size: int, + tensor_parallel_size: int, rank: int, + distributed_init_port: str): + parallel_config = ParallelConfig(pipeline_parallel_size, + tensor_parallel_size, + worker_use_ray=True) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + torch.cuda.set_device(rank) + _init_distributed_environment(parallel_config, rank, + distributed_init_method) + + +def multi_process_tensor_parallel(tensor_parallel_size, test_target): + set_start_method("spawn", force=True) + distributed_init_port = get_open_port() + processes = [] + for rank in range(tensor_parallel_size): + p = Process(target=test_target, + args=(tensor_parallel_size, rank, distributed_init_port)) + p.start() + processes.append(p) + for p in processes: + p.join() + assert all(p.exitcode == 0 for p in processes) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 733c7395811ef..1f56eedc16809 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -2,30 +2,14 @@ Run `pytest tests/distributed/test_comm_ops.py --forked`. """ -from multiprocessing import Process, set_start_method - import pytest import torch -from vllm.config import ParallelConfig -from vllm.engine.ray_utils import get_open_port from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather, ) -from vllm.worker.worker import _init_distributed_environment - - -def init_test_distributed_environment(pipeline_parallel_size: int, - tensor_parallel_size: int, rank: int, - distributed_init_port: str): - parallel_config = ParallelConfig(pipeline_parallel_size, - tensor_parallel_size, - worker_use_ray=True) - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - torch.cuda.set_device(rank) - _init_distributed_environment(parallel_config, rank, - distributed_init_method) +from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @@ -70,14 +54,4 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.parametrize("test_target", [all_reduce_test_worker, all_gather_test_worker]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - set_start_method("spawn", force=True) - distributed_init_port = get_open_port() - processes = [] - for rank in range(tensor_parallel_size): - p = Process(target=test_target, - args=(tensor_parallel_size, rank, distributed_init_port)) - p.start() - processes.append(p) - for p in processes: - p.join() - assert all(p.exitcode == 0 for p in processes) + multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py new file mode 100644 index 0000000000000..836d17c94f5ea --- /dev/null +++ b/tests/distributed/test_fast_allreduce.py @@ -0,0 +1,63 @@ +import torch +import random +import torch.distributed as dist +from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce +import pytest + +from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel + +random.seed(42) +test_sizes = 
[random.randint(1024, 2048 * 1024) for i in range(4)] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +def graph_registration(world_size, rank, distributed_init_port): + init_test_distributed_environment(1, world_size, rank, + distributed_init_port) + for sz in test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + fa = FastAllreduce(rank, world_size) + # use integers so result matches NCCL exactly + inp1 = torch.ones( + sz, dtype=dtype, + device=torch.cuda.current_device()) * random.randint(1, 16) + inp2 = torch.ones( + sz, dtype=dtype, + device=torch.cuda.current_device()) * random.randint(1, 16) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + out1 = fa.all_reduce(inp1) + out2 = fa.all_reduce(inp2) + # the input buffer is immediately modified to test synchronization + dist.all_reduce(inp1) + dist.all_reduce(inp2) + fa.register_graph_buffers() + graph.replay() + torch.cuda.synchronize() + + assert torch.allclose(out1, inp1) + assert torch.allclose(out2, inp2) + + +def manual_registration(world_size, rank, distributed_init_port): + init_test_distributed_environment(1, world_size, rank, + distributed_init_port) + sz = 1024 + fa = FastAllreduce(rank, world_size) + inp = torch.ones(sz, + dtype=torch.float32, + device=torch.cuda.current_device()) + fa.register_buffer(inp) + out = fa.all_reduce(inp) + assert torch.allclose(out, inp * world_size) + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +@pytest.mark.parametrize("tensor_parallel_size", [2, 4]) +@pytest.mark.parametrize("test_target", + [manual_registration, graph_registration]) +def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): + multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/tests/kernels/test_fast_ar.py b/tests/kernels/test_fast_ar.py deleted file mode 100644 index 5a90936c6f500..0000000000000 --- a/tests/kernels/test_fast_ar.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -Run this test like this: -torchrun --standalone --nnodes=1 --nproc-per-node=4 tests/kernels/test_fast_ar.py -""" -import torch -import os -import random -import torch.distributed as dist -from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce -from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel - -os.environ["MAX_JOBS"] = "16" -rank = int(os.environ["RANK"]) -local_rank = int(os.environ["LOCAL_RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) -torch.cuda.set_device(local_rank) -initialize_model_parallel(world_size) - -test_count = 8 -if rank == 0: - test_sizes = [random.randint(1024, 2048 * 1024) for i in range(test_count)] - for i, v in enumerate(test_sizes): - test_sizes[i] -= v % 8 -else: - test_sizes = [0] * test_count -dist.broadcast_object_list(test_sizes, src=0) - - -def test_fast_ar(sz: int, dtype): - fa = FastAllreduce(rank, world_size) - # use integers so result matches NCCL exactly - inp1 = torch.ones(sz, dtype=dtype, - device=torch.cuda.current_device()) * random.randint( - 1, 32) - inp2 = torch.ones(sz, dtype=dtype, - device=torch.cuda.current_device()) * random.randint( - 1, 32) - torch.cuda.synchronize() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - out1 = fa.all_reduce(inp1) - out2 = fa.all_reduce(inp2) - # the input buffer is immediately modified to test synchronization - dist.all_reduce(inp1) - 
dist.all_reduce(inp2) - fa.register_graph_buffers() - graph.replay() - torch.cuda.synchronize() - - assert torch.allclose(out1, inp1) - assert torch.allclose(out2, inp2) - if rank == 0: - print("passed", sz, dtype) - - -def test_manual_registration(): - sz = 1024 - fa = FastAllreduce(rank, world_size) - inp = torch.ones(sz, - dtype=torch.float32, - device=torch.cuda.current_device()) - fa.register_buffer(inp) - out = fa.all_reduce(inp) - assert torch.allclose(out, inp * world_size) - - -if __name__ == "__main__": - print(test_sizes) - for dtype in [torch.float32, torch.float16, torch.bfloat16]: - for sz in test_sizes: - test_fast_ar(sz, dtype) - test_manual_registration() From 15672a9eb5ccceb4cb1e91d855a4d593fa6f335e Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Thu, 21 Dec 2023 03:54:43 +0000 Subject: [PATCH 11/53] small refactor --- csrc/fast_allreduce.cu | 29 ++------ csrc/fast_allreduce.cuh | 68 ++++++++++++++----- csrc/fast_allreduce_test.cu | 18 +++-- csrc/ops.h | 7 +- csrc/pybind.cpp | 2 +- .../parallel_utils/fast_allreduce.py | 11 ++- 6 files changed, 85 insertions(+), 50 deletions(-) diff --git a/csrc/fast_allreduce.cu b/csrc/fast_allreduce.cu index ab3423fd7e4ff..0c4168e706fbb 100644 --- a/csrc/fast_allreduce.cu +++ b/csrc/fast_allreduce.cu @@ -9,9 +9,10 @@ using fptr_t = uint64_t; static_assert(sizeof(void *) == sizeof(fptr_t)); -fptr_t prepare_buffer(fptr_t ptr, const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink) { +fptr_t init_fast_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -28,8 +29,8 @@ fptr_t prepare_buffer(fptr_t ptr, const std::vector &handles, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::FastAllreduce( - reinterpret_cast(ptr), ipc_handles, offsets, rank, - full_nvlink); + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { @@ -81,23 +82,7 @@ void register_buffer(fptr_t _fa, torch::Tensor &t, std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); - auto sz = fa->graph_unreg_buffers_.size(); - auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::vector handles(handle_sz * sz, 0); - std::vector offsets(sz); - for (int i = 0; i < sz; i++) { - auto ptr = fa->graph_unreg_buffers_[i]; - void *base_ptr; - // note: must share the base address of each allocation, or we get wrong address - auto _err = cuPointerGetAttribute( - &base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr); - if (_err != CUDA_SUCCESS) - throw std::runtime_error("failed to get pointer attr"); - CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&handles[i * handle_sz], - base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); - } - return std::make_pair(handles, offsets); + return fa->get_graph_buffer_ipc_meta(); } void register_graph_buffers(fptr_t _fa, const std::vector &handles, diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index 2903d76e48a51..de3956d964901 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -300,10 +300,10 @@ __global__ void __launch_bounds__(512, 1) using A = typename packed_t::A; auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); constexpr int hg = ngpus / 2; - // Actually not quite half butterfly. + // Actually not quite half butterfly. // This is an all-to-all within each group containing half of the ranks // followed by cross-group add. Equivalent to half butterfly when there - // are 4 GPUs, a common case for PCIe cards like T4 and A10. + // are 4 GPUs, a common case for PCIe cards like T4 and A10. P *ptrs[hg]; { int start = rank - rank % hg; @@ -342,7 +342,8 @@ class FastAllreduce { std::unordered_map buffers_; Metadata *meta_; - RankData *d_rank_data_start_, *d_rank_data_base_, *d_rank_data_end_; + // stores the registered device pointers from all ranks + RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; std::vector ipc_handles_; @@ -351,14 +352,20 @@ class FastAllreduce { * * There's a total of sizeof(Metadata) of prefix before the actual data, * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. Any required buffers + * are passed in from the constructor */ - FastAllreduce(Metadata *meta, const cudaIpcMemHandle_t *handles, + FastAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t *handles, const std::vector &offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), + full_nvlink_(full_nvlink), meta_(meta), - full_nvlink_(full_nvlink) { + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { Metadata *rank_meta; if (i != rank_) { @@ -373,10 +380,28 @@ class FastAllreduce { } sg_.signals[i] = &rank_meta->sg; } - size_t rank_data_sz = 16 * 1024 * 1024; - CUDACHECK(cudaMalloc(&d_rank_data_start_, rank_data_sz)); - d_rank_data_base_ = d_rank_data_start_; - d_rank_data_end_ = d_rank_data_base_ + rank_data_sz / sizeof(RankData); + } + + std::pair, std::vector> + get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::vector handles(handle_sz * num_buffers, 0); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void *base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char *)ptr) - ((char *)base_ptr); + } + return std::make_pair(handles, offsets); } void check_rank_data_capacity(size_t num = 1) { @@ -409,14 +434,22 @@ class FastAllreduce { buffers_[self] = d_data; } + // note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. 
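  // Caller-side contract (per the Python wrapper in this series): after CUDA
  // graph capture ends, each rank exports its captured pointers with
  // get_graph_buffer_ipc_meta(), the per-rank (handles, offsets) metadata is
  // exchanged across ranks (e.g. via torch.distributed.all_gather_object), and
  // the gathered lists are passed to this function so each rank can open its
  // peers' IPC handles.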
void register_graph_buffers( const std::vector &handles, const std::vector> &offsets) { - auto sz = graph_unreg_buffers_.size(); - check_rank_data_capacity(sz); - for (int i = 0; i < sz; i++) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - RankData rd; + auto &rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { char *handle; @@ -432,9 +465,11 @@ class FastAllreduce { rd.ptrs[j] = self_ptr; } } - CUDACHECK(cudaMemcpy(d_rank_data_base_++, &rd, sizeof(RankData), - cudaMemcpyHostToDevice)); } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; graph_unreg_buffers_.clear(); } @@ -479,7 +514,7 @@ class FastAllreduce { (world_size_ <= 8 && bytes < 256 * 1024)) { \ KL(ngpus, cross_device_reduce_1stage); \ } else { \ - KL(ngpus, cross_device_reduce_2stage); \ + KL(ngpus, cross_device_reduce_1stage); \ } \ } else { \ KL(ngpus, cross_device_reduce_half_butterfly); \ @@ -506,7 +541,6 @@ class FastAllreduce { for (auto ptr : ipc_handles_) { CUDACHECK(cudaIpcCloseMemHandle(ptr)); } - CUDACHECK(cudaFree(d_rank_data_start_)); } }; diff --git a/csrc/fast_allreduce_test.cu b/csrc/fast_allreduce_test.cu index 79d6f78ea99af..6a5924c79d046 100644 --- a/csrc/fast_allreduce_test.cu +++ b/csrc/fast_allreduce_test.cu @@ -101,8 +101,12 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); + void *rank_data; + size_t rank_data_sz = 16 * 1024 * 1024; + CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); - vllm::FastAllreduce fa(buffer, data_handles, offsets, myRank); + vllm::FastAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, + myRank); auto *self_data = reinterpret_cast(reinterpret_cast(buffer) + sizeof(vllm::Metadata) + data_size * sizeof(T)); @@ -208,9 +212,15 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size << " me: " << my_diffs / data_size << std::endl; - // cudaFree(result); - // CUDACHECK(cudaStreamDestroy(stream)); + CUDACHECK(cudaFree(result)); + CUDACHECK(cudaFree(self_data_copy)); + CUDACHECK(cudaFree(rank_data)); + CUDACHECK(cudaFree(buffer)); CUDACHECK(cudaFree(states)); + CUDACHECK(cudaFreeHost(verification_buffer)); + CUDACHECK(cudaFreeHost(nccl_result)); + CUDACHECK(cudaFreeHost(my_result)); + CUDACHECK(cudaStreamDestroy(stream)); } int main(int argc, char **argv) { @@ -233,7 +243,7 @@ int main(int argc, char **argv) { // } // } for (int sz = 512; sz <= (4 << 20); sz *= 2) { - run(myRank, nRanks, comm, 512, 36, sz); + run(myRank, nRanks, comm, 512, 36, sz + 8 * 50); } cudaProfilerStop(); diff --git a/csrc/ops.h b/csrc/ops.h index c168c4a73b3c4..86a91e9ca0064 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,9 +92,10 @@ void gptq_shuffle( using fptr_t = uint64_t; -fptr_t prepare_buffer(fptr_t ptr, const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); +fptr_t init_fast_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink); void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); void dispose(fptr_t _fa); int meta_size(); 
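For orientation, the API declared in ops.h above is driven from the Python-side FastAllreduce wrapper that this series adds. Below is a minimal sketch of the eager (non-CUDA-graph) path, assuming torch.distributed and the tensor-parallel group are already initialized on every rank, as the tests later in this series do; the comments on wrapper internals summarize the __init__ shown in the fast_allreduce.py hunk that follows.

    import torch
    from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce

    def eager_fast_allreduce(rank: int, world_size: int) -> torch.Tensor:
        # Allocates the meta/rank_data tensors, exchanges IPC handles across
        # ranks, and calls fast_ar.init_fast_ar(...) internally.
        fa = FastAllreduce(rank, world_size)
        inp = torch.ones(1024, dtype=torch.float32, device="cuda")
        # Eager (non-graph) calls require the input buffer to be IPC-registered
        # first; this is a collective step across the tensor-parallel group.
        fa.register_buffer(inp)
        # Every rank calls this; the result equals inp * world_size.
        return fa.all_reduce(inp)

The CUDA graph path skips the explicit register_buffer call: pointers seen during capture are recorded and then registered in bulk through register_graph_buffers once capture finishes.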
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index c46026e05ee54..c96ed6bfaa50f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -83,7 +83,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Gets the specified device attribute."); pybind11::module fast_ar = m.def_submodule("fast_ar", "fast allreduce"); - fast_ar.def("prepare_buffer", &prepare_buffer, "prepare_buffer"); + fast_ar.def("init_fast_ar", &init_fast_ar, "init_fast_ar"); fast_ar.def("allreduce", &allreduce, "allreduce"); fast_ar.def("dispose", &dispose, "dispose"); fast_ar.def("meta_size", &meta_size, "meta_size"); diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index 3bf703f4279e3..d5d4d5567f9d2 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -28,16 +28,21 @@ def full_nvlink(rank, world_size): class FastAllreduce: + # max_size: max supported allreduce size def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: + # buffers memory are owned by this Python class and passed to C++ self.meta = torch.zeros(fast_ar.meta_size() + max_size, dtype=torch.uint8, - device=rank) + device="cuda") + self.rank_data = torch.empty(16 * 1024 * 1024, + dtype=torch.uint8, + device="cuda") self.max_size = max_size self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink(rank, world_size) - self._ptr = fast_ar.prepare_buffer(self.meta.data_ptr(), handles, - offsets, rank, self.full_nvlink) + self._ptr = fast_ar.init_fast_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 self.is_capturing = False From ad7d2208685dee4ad3b77b471f57a972916b4c42 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Thu, 21 Dec 2023 05:18:35 +0000 Subject: [PATCH 12/53] cleanup code add verify correctness --- csrc/fast_allreduce.cuh | 68 +++++++++++++++++-------------------- csrc/fast_allreduce_test.cu | 11 +++++- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index de3956d964901..dd48b20ca9ce6 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -38,7 +38,7 @@ struct Metadata { static_assert(offsetof(Metadata, counter) == 128); static_assert(sizeof(Metadata) == 256); -struct __align__(16) RankData { void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct RankSignals { volatile Signal *signals[8]; @@ -198,24 +198,34 @@ __device__ __forceinline__ void end_sync(const RankSignals &sg, } } +template +DINLINE P packed_reduce(const P *ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast
<P>
(tmp); +} + template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData *_dp, RankSignals sg, volatile Metadata *meta, T *__restrict__ result, int rank, int size) { - auto dp = *_dp; + using P = typename packed_t::P; + using A = typename packed_t::A; + const P *ptrs[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (P *)_dp->ptrs[target]; + } start_sync(sg, meta, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - using P = typename packed_t::P; - using A = typename packed_t::A; - A tmp = upcast(((P *)dp.ptrs[0])[idx]); -#pragma unroll - for (int i = 1; i < ngpus; i++) { - packed_assign_add(tmp, upcast(((P *)dp.ptrs[i])[idx])); - } - ((P *)result)[idx] = downcast
<P>
(tmp); + ((P *)result)[idx] = packed_reduce(ptrs, idx); } end_sync(sg, meta, rank); } @@ -230,9 +240,6 @@ __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData *_dp, RankSignals sg, volatile Metadata *meta, T *__restrict__ result, int rank, int size) { - auto dp = *_dp; - start_sync(sg, meta, rank); - int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -240,23 +247,19 @@ __global__ void __launch_bounds__(512, 1) int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; - P *ptrs[ngpus]; + const P *ptrs[ngpus]; P *tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (P *)dp.ptrs[target]; + ptrs[i] = (const P *)_dp->ptrs[target]; tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); } auto tmp_out = tmps[0]; + start_sync(sg, meta, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { - A tmp = upcast(ptrs[0][idx]); -#pragma unroll - for (int i = 1; i < ngpus; i++) { - packed_assign_add(tmp, upcast(ptrs[i][idx])); - } - tmp_out[idx - start] = downcast
<P>
(tmp); + tmp_out[idx - start] = packed_reduce(ptrs, idx); } // Maybe TODO: replace this with per-block release-acquire // can save about 1-2us (not a lot though) @@ -266,7 +269,7 @@ __global__ void __launch_bounds__(512, 1) for (int idx = tid; idx < part; idx += stride) { #pragma unroll for (int i = 0; i < ngpus; i++) { - int dst_idx = i * part + idx; + int dst_idx = ((rank + i) % ngpus) * part + idx; ((P *)result)[dst_idx] = tmps[i][idx]; } } @@ -291,9 +294,6 @@ __global__ void __launch_bounds__(512, 1) volatile Metadata *meta, T *__restrict__ result, int rank, int size) { - auto dp = *_dp; - start_sync(sg, meta, rank); - int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -304,23 +304,18 @@ __global__ void __launch_bounds__(512, 1) // This is an all-to-all within each group containing half of the ranks // followed by cross-group add. Equivalent to half butterfly when there // are 4 GPUs, a common case for PCIe cards like T4 and A10. - P *ptrs[hg]; + const P *ptrs[hg]; { int start = rank - rank % hg; #pragma unroll for (int i = 0; i < hg; i++) { - ptrs[i] = (P *)dp.ptrs[i + start]; + ptrs[i] = (const P *)_dp->ptrs[i + start]; } } + start_sync(sg, meta, rank); for (int idx = tid; idx < size; idx += stride) { - A tmp = {0.0f, 0.0f}; -#pragma unroll - for (int i = 0; i < hg; i++) { - packed_assign_add(tmp, upcast(ptrs[i][idx])); - } - tmp_out[idx] = downcast
<P>
(tmp); + tmp_out[idx] = packed_reduce(ptrs, idx); } - end_sync(sg, meta, rank); auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); @@ -514,7 +509,7 @@ class FastAllreduce { (world_size_ <= 8 && bytes < 256 * 1024)) { \ KL(ngpus, cross_device_reduce_1stage); \ } else { \ - KL(ngpus, cross_device_reduce_1stage); \ + KL(ngpus, cross_device_reduce_2stage); \ } \ } else { \ KL(ngpus, cross_device_reduce_half_butterfly); \ @@ -543,5 +538,6 @@ class FastAllreduce { } } }; - +template void FastAllreduce::allreduce(cudaStream_t, half *, half *, int, + int, int); } // namespace vllm diff --git a/csrc/fast_allreduce_test.cu b/csrc/fast_allreduce_test.cu index 6a5924c79d046..a7c51f340c9a0 100644 --- a/csrc/fast_allreduce_test.cu +++ b/csrc/fast_allreduce_test.cu @@ -31,7 +31,7 @@ } while (0) __global__ void dummy_kernel() { - for (int i = 0; i < 500; i++) __nanosleep(1000000); // 500ms + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms } template @@ -202,6 +202,15 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, my_result, data_size); CUDACHECK(cudaStreamSynchronize(stream)); + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 1e-2) { + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f\n", myRank, + j, nccl_result[j], my_result[j]); + break; + } + } + long double nccl_diffs = 0.0; long double my_diffs = 0.0; for (int j = 0; j < data_size; j++) { From da5772ea03bff3b688b8f08651af4cb363dd09f8 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Thu, 21 Dec 2023 06:12:49 +0000 Subject: [PATCH 13/53] improve test robustness --- tests/distributed/test_fast_allreduce.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 836d17c94f5ea..43c62ef264324 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -7,7 +7,7 @@ from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel random.seed(42) -test_sizes = [random.randint(1024, 2048 * 1024) for i in range(4)] +test_sizes = [random.randint(1024, 2048 * 1024) for i in range(8)] for i, v in enumerate(test_sizes): test_sizes[i] -= v % 8 @@ -19,12 +19,14 @@ def graph_registration(world_size, rank, distributed_init_port): for dtype in [torch.float32, torch.float16, torch.bfloat16]: fa = FastAllreduce(rank, world_size) # use integers so result matches NCCL exactly - inp1 = torch.ones( - sz, dtype=dtype, - device=torch.cuda.current_device()) * random.randint(1, 16) - inp2 = torch.ones( - sz, dtype=dtype, - device=torch.cuda.current_device()) * random.randint(1, 16) + inp1 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): From 7644f7c561a3d3b6c1602fa34df27d3819a020e3 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 26 Dec 2023 15:47:14 -0800 Subject: [PATCH 14/53] Apply suggestions from code review Co-authored-by: Woosuk Kwon --- tests/distributed/comm_utils.py | 6 +----- tests/distributed/test_fast_allreduce.py | 6 ++++-- vllm/config.py | 6 ++++-- vllm/engine/arg_utils.py | 4 ++-- vllm/model_executor/parallel_utils/communication_op.py | 6 +++--- vllm/model_executor/parallel_utils/fast_allreduce.py | 5 +++-- vllm/worker/model_runner.py | 2 +- 7 files changed, 18 insertions(+), 17 deletions(-) diff --git 
a/tests/distributed/comm_utils.py b/tests/distributed/comm_utils.py index d9ed67c5aafa2..49300dec91914 100644 --- a/tests/distributed/comm_utils.py +++ b/tests/distributed/comm_utils.py @@ -1,12 +1,8 @@ -"""Test the communication operators. - -Run `pytest tests/distributed/test_comm_ops.py --forked`. -""" from multiprocessing import Process, set_start_method import torch from vllm.config import ParallelConfig -from vllm.engine.ray_utils import get_open_port +from vllm.utils import get_open_port from vllm.worker.worker import _init_distributed_environment diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 43c62ef264324..0a71dddf4c043 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -1,8 +1,10 @@ -import torch import random + +import pytest +import torch import torch.distributed as dist + from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce -import pytest from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel diff --git a/vllm/config.py b/vllm/config.py index 52111f502bf39..bd3415f970569 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -331,7 +331,9 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. - disable_fast_allreduce: Only applicable if enforce_eage is False and using tensor parallelism. Whether to disable fast allreduce path + disable_fast_allreduce: Disable the custom all-reduce kernel and fall back to NCCL. + Note that the custom kernel is only used with CUDA graph and never used in eager + mode. """ def __init__( @@ -349,7 +351,7 @@ def __init__( self.disable_fast_allreduce = disable_fast_allreduce if not disable_fast_allreduce and (is_hip() or pipeline_parallel_size > 1): - self.disable_fast_allreduce = False + self.disable_fast_allreduce = True logger.info( "Fast allreduce automatically disabled. Not supported on HIP and pipeline parallel" ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index abc732fc1fc1b..eefed87167953 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -203,8 +203,8 @@ def add_cli_args( help='maximum context length covered by CUDA ' 'graphs. 
When a sequence has context length ' 'larger than this, we fall back to eager mode.') - parser.add_argument('--disable_fast_allreduce', - type=int, + parser.add_argument('--disable-fast-allreduce', + action='store_true', default=EngineArgs.disable_fast_allreduce, help='See ParallelConfig') return parser diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 644df5be46060..703f793f088d1 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -9,19 +9,19 @@ is_capturing = False -def init_fast_ar(): +def init_fast_ar() -> None: global fa_handle world_size = get_tensor_model_parallel_world_size() if world_size > 1: fa_handle = FastAllreduce(get_tensor_model_parallel_rank(), world_size) -def begin_capture(): +def begin_capture() -> None: global is_capturing is_capturing = True -def end_capture(): +def end_capture() -> None: global is_capturing is_capturing = False diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index d5d4d5567f9d2..d8b9774dfa5b2 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -1,8 +1,9 @@ +import pynvml import torch import torch.distributed as dist -import pynvml -from vllm.logger import init_logger + from vllm._C import fast_ar +from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 77d1149787760..a1e42dc44b3aa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -import vllm.model_executor.parallel_utils.communication_op as comm_op +from vllm.model_executor.parallel_utils import communication_op as comm_op from vllm.utils import in_wsl logger = init_logger(__name__) From 4096c9da61edc54499cc546a386279ec71586bfc Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 27 Dec 2023 00:37:14 +0000 Subject: [PATCH 15/53] use context manager --- tests/distributed/test_fast_allreduce.py | 45 +++++++-------- .../parallel_utils/communication_op.py | 28 ++-------- .../parallel_utils/fast_allreduce.py | 51 ++++++++++++++++- vllm/worker/model_runner.py | 56 +++++++++---------- 4 files changed, 99 insertions(+), 81 deletions(-) diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 0a71dddf4c043..4522627d7b6b2 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist -from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce - +from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar +from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_reduce from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel random.seed(42) @@ -19,28 +19,25 @@ def graph_registration(world_size, rank, distributed_init_port): distributed_init_port) for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - fa = FastAllreduce(rank, world_size) - # use integers so result matches NCCL exactly - 
inp1 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - torch.cuda.synchronize() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - out1 = fa.all_reduce(inp1) - out2 = fa.all_reduce(inp2) - # the input buffer is immediately modified to test synchronization - dist.all_reduce(inp1) - dist.all_reduce(inp2) - fa.register_graph_buffers() + with fast_ar.capture(enable=True): + # use integers so result matches NCCL exactly + inp1 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + out1 = tensor_model_parallel_all_reduce(inp1) + out2 = tensor_model_parallel_all_reduce(inp2) + # the input buffer is immediately modified to test synchronization + dist.all_reduce(inp1) + dist.all_reduce(inp2) graph.replay() - torch.cuda.synchronize() - assert torch.allclose(out1, inp1) assert torch.allclose(out2, inp2) @@ -49,7 +46,7 @@ def manual_registration(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) sz = 1024 - fa = FastAllreduce(rank, world_size) + fa = fast_ar.FastAllreduce(rank, world_size) inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 703f793f088d1..7e5c277cc44c2 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,29 +1,8 @@ import torch from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, - get_tensor_model_parallel_rank) -from vllm.model_executor.parallel_utils.fast_allreduce import FastAllreduce - -fa_handle = None -is_capturing = False - - -def init_fast_ar() -> None: - global fa_handle - world_size = get_tensor_model_parallel_world_size() - if world_size > 1: - fa_handle = FastAllreduce(get_tensor_model_parallel_rank(), world_size) - - -def begin_capture() -> None: - global is_capturing - is_capturing = True - - -def end_capture() -> None: - global is_capturing - is_capturing = False + get_tensor_model_parallel_world_size, get_tensor_model_parallel_group) +from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar def tensor_model_parallel_all_reduce(input_: torch.Tensor): @@ -36,7 +15,8 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor): return input_ # fast allreduce only works with IPC pre-registered buffer. 
# This is only handled when captured with cuda graph - if is_capturing and fa_handle is not None: + if fast_ar.is_capturing(): + fa_handle = fast_ar.get_handle() if torch.cuda.is_current_stream_capturing(): if fa_handle.should_fast_ar(input_): return fa_handle.all_reduce(input_) diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index d8b9774dfa5b2..83be875630830 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -1,15 +1,61 @@ +from contextlib import contextmanager import pynvml import torch import torch.distributed as dist +from typing import Optional from vllm._C import fast_ar from vllm.logger import init_logger +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) logger = init_logger(__name__) +_FA_HANDLE = None +_IS_CAPTURING = False + + +def init_fast_ar() -> None: + global _FA_HANDLE + world_size = get_tensor_model_parallel_world_size() + if world_size > 1: + _FA_HANDLE = FastAllreduce(get_tensor_model_parallel_rank(), + world_size) + + +def begin_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = True + + +def end_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = False + + +def is_capturing() -> bool: + return _IS_CAPTURING and _FA_HANDLE is not None + + +def get_handle() -> Optional["FastAllreduce"]: + return _FA_HANDLE + + +@contextmanager +def capture(enable: bool): + if enable: + init_fast_ar() + try: + begin_capture() + yield + finally: + end_capture() + if enable: + get_handle().register_graph_buffers() + # query if the set of gpus are fully connected by nvlink (1 hop) -def full_nvlink(rank, world_size): +def _is_full_nvlink(rank, world_size): pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(rank) for i in range(world_size): @@ -41,11 +87,10 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.max_size = max_size self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) - self.full_nvlink = full_nvlink(rank, world_size) + self.full_nvlink = _is_full_nvlink(rank, world_size) self._ptr = fast_ar.init_fast_ar(self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 - self.is_capturing = False def _get_ipc_meta(self, inp: torch.Tensor): data = inp.storage()._share_cuda_() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a1e42dc44b3aa..50d25781473e7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.model_executor.parallel_utils import communication_op as comm_op +from vllm.model_executor.parallel_utils import fast_allreduce from vllm.utils import in_wsl logger = init_logger(__name__) @@ -420,35 +420,31 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() - if not self.parallel_config.disable_fast_allreduce: - comm_op.init_fast_ar() - comm_op.begin_capture() - # NOTE: Capturing the largest batch size first may help reduce the - # memory usage of CUDA graph. 
- for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): - # Create dummy input_metadata. - input_metadata = InputMetadata( - prompt_lens=[], - slot_mapping=slot_mapping[:batch_size], - max_context_len=self.max_context_len_to_capture, - context_lens=context_lens[:batch_size], - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) - - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - kv_caches, - input_metadata, - memory_pool=self.graph_memory_pool, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner - comm_op.end_capture() - if comm_op.fa_handle is not None: - comm_op.fa_handle.register_graph_buffers() + with fast_allreduce.capture( + enable=not self.parallel_config.disable_fast_allreduce): + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): + # Create dummy input_metadata. + input_metadata = InputMetadata( + prompt_lens=[], + slot_mapping=slot_mapping[:batch_size], + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner end_time = time.perf_counter() elapsed_time = end_time - start_time From af015e7884449466c35500619982365ffeac4268 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 27 Dec 2023 00:52:19 +0000 Subject: [PATCH 16/53] add p2p check --- tests/distributed/test_fast_allreduce.py | 3 +- .../parallel_utils/fast_allreduce.py | 29 +++++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 4522627d7b6b2..90d2ab8440bec 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -46,7 +46,8 @@ def manual_registration(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) sz = 1024 - fa = fast_ar.FastAllreduce(rank, world_size) + fast_ar.init_fast_ar() + fa = fast_ar.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index 83be875630830..59d49f0934ca3 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -17,10 +17,10 @@ def init_fast_ar() -> None: global _FA_HANDLE + rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() - if world_size > 1: - _FA_HANDLE = FastAllreduce(get_tensor_model_parallel_rank(), - world_size) + if world_size > 1 and _can_p2p(rank, world_size): + _FA_HANDLE = FastAllreduce(rank, world_size) def begin_capture() -> None: @@ -73,6 +73,29 @@ def _is_full_nvlink(rank, world_size): return True +def _can_p2p(rank, world_size): + pynvml.nvmlInit() + handle1 = pynvml.nvmlDeviceGetHandleByIndex(rank) + for i in range(world_size): + if i != rank: + handle2 = pynvml.nvmlDeviceGetHandleByIndex(rank) + try: + p2p_status = 
pynvml.nvmlDeviceGetP2PStatus( + handle1, handle2, pynvml.NVML_P2P_CAPS_INDEX_READ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + logger.info( + f"P2P is not supported between device {i} and {rank}. " + "Fast allreduce will be disabled") + return False + except pynvml.NVMLError as error: + logger.info( + f"P2P detection failed with message \"{str(error)}\". " + "Fast allreduce will be disabled") + return False + pynvml.nvmlShutdown() + return True + + class FastAllreduce: # max_size: max supported allreduce size From 7f11bc58b4496cd973269d1ed7273645f5f5a10b Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 27 Dec 2023 01:00:13 +0000 Subject: [PATCH 17/53] address review --- csrc/fast_allreduce_test.cu | 9 +++++++++ vllm/config.py | 11 +++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/csrc/fast_allreduce_test.cu b/csrc/fast_allreduce_test.cu index a7c51f340c9a0..eabef83c8cf96 100644 --- a/csrc/fast_allreduce_test.cu +++ b/csrc/fast_allreduce_test.cu @@ -1,3 +1,12 @@ +/** + * This is a standalone test for fast allreduce. + * To compile, make sure you have MPI and NCCL installed in your system. + * export MPI_HOME=XXX + * nvcc -O2 -arch=native -std=c++17 fast_allreduce_test.cu -o fast_allreduce_test -lnccl -I${MPI_HOME}/include -lmpi + * + * To run: + * mpirun -np 8 ./fast_allreduce_test +*/ #include #include #include diff --git a/vllm/config.py b/vllm/config.py index bd3415f970569..9b69d81ebcd18 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -349,12 +349,6 @@ def __init__( self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_fast_allreduce = disable_fast_allreduce - if not disable_fast_allreduce and (is_hip() - or pipeline_parallel_size > 1): - self.disable_fast_allreduce = True - logger.info( - "Fast allreduce automatically disabled. Not supported on HIP and pipeline parallel" - ) self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: self.worker_use_ray = True @@ -364,6 +358,11 @@ def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if is_hip() or self.pipeline_parallel_size > 1: + self.disable_fast_allreduce = True + logger.info( + "Fast allreduce automatically disabled. 
Not supported on HIP and pipeline parallel" + ) class SchedulerConfig: From 47653b5964434c09a8db059637330e7aa9fd6ea8 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 27 Dec 2023 04:32:48 +0000 Subject: [PATCH 18/53] do not reinit --- vllm/model_executor/parallel_utils/fast_allreduce.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index 59d49f0934ca3..17f3232ac77cc 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -17,6 +17,8 @@ def init_fast_ar() -> None: global _FA_HANDLE + if _FA_HANDLE is not None: + return rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() if world_size > 1 and _can_p2p(rank, world_size): From 78546de03a59e3e0f0bc7a6c5b6037c6158dc1f3 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Thu, 4 Jan 2024 07:03:37 +0000 Subject: [PATCH 19/53] format --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9a0a42f869523..3aecc36462f2c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -521,7 +521,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. with fast_allreduce.capture( - enable=not self.parallel_config.disable_fast_allreduce): + enable=not self.parallel_config.disable_fast_allreduce): for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): # Create dummy input_metadata. input_metadata = InputMetadata( From c8d41b32fae27b878baceaf87fb3442ac873066a Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Mon, 15 Jan 2024 07:22:01 +0000 Subject: [PATCH 20/53] fix tests and format --- tests/distributed/comm_utils.py | 24 +++++++++---------- tests/distributed/test_comm_ops.py | 2 +- tests/distributed/test_fast_allreduce.py | 3 +++ .../parallel_utils/fast_allreduce.py | 2 +- vllm/worker/model_runner.py | 2 +- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/distributed/comm_utils.py b/tests/distributed/comm_utils.py index 49300dec91914..ad13205f95cb0 100644 --- a/tests/distributed/comm_utils.py +++ b/tests/distributed/comm_utils.py @@ -1,5 +1,4 @@ -from multiprocessing import Process, set_start_method -import torch +import ray from vllm.config import ParallelConfig from vllm.utils import get_open_port @@ -13,20 +12,21 @@ def init_test_distributed_environment(pipeline_parallel_size: int, tensor_parallel_size, worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - torch.cuda.set_device(rank) _init_distributed_environment(parallel_config, rank, distributed_init_method) def multi_process_tensor_parallel(tensor_parallel_size, test_target): - set_start_method("spawn", force=True) + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
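    # Ray is initialized here, and each test worker below is launched as a
    # Ray remote task (the test targets are decorated with num_gpus=1), so
    # every worker runs in its own process with its own GPU.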
+ ray.init() + distributed_init_port = get_open_port() - processes = [] + refs = [] for rank in range(tensor_parallel_size): - p = Process(target=test_target, - args=(tensor_parallel_size, rank, distributed_init_port)) - p.start() - processes.append(p) - for p in processes: - p.join() - assert all(p.exitcode == 0 for p in processes) + refs.append( + test_target.remote(tensor_parallel_size, rank, + distributed_init_port)) + ray.get(refs) + + ray.shutdown() \ No newline at end of file diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 96befdf87b782..c3bc285214663 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -57,4 +57,4 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.parametrize("test_target", [all_reduce_test_worker, all_gather_test_worker]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) + multi_process_tensor_parallel(tensor_parallel_size, test_target) \ No newline at end of file diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 90d2ab8440bec..90e04bc2cab42 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -3,6 +3,7 @@ import pytest import torch import torch.distributed as dist +import ray from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_reduce @@ -14,6 +15,7 @@ test_sizes[i] -= v % 8 +@ray.remote(num_gpus=1, max_calls=1) def graph_registration(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) @@ -42,6 +44,7 @@ def graph_registration(world_size, rank, distributed_init_port): assert torch.allclose(out2, inp2) +@ray.remote(num_gpus=1, max_calls=1) def manual_registration(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index 17f3232ac77cc..79326150383d3 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -52,7 +52,7 @@ def capture(enable: bool): yield finally: end_capture() - if enable: + if enable and get_handle() is not None: get_handle().register_graph_buffers() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7bb500b492fe4..43ed692375ef2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -531,7 +531,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. with fast_allreduce.capture( - enable=not self.parallel_config.disable_fast_allreduce): + enable=not self.parallel_config.disable_fast_allreduce): for batch_size in reversed(batch_size_capture_list): # Create dummy input_metadata. 
input_metadata = InputMetadata( From 364c06e74f5d4f0f29fe6ba1d24fadc411a20d2d Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Mon, 15 Jan 2024 07:38:13 +0000 Subject: [PATCH 21/53] add a few more comments --- csrc/fast_allreduce.cuh | 14 ++++++------- csrc/fast_allreduce_test.cu | 41 +++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/csrc/fast_allreduce.cuh b/csrc/fast_allreduce.cuh index dd48b20ca9ce6..f507ad52308b5 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/fast_allreduce.cuh @@ -215,17 +215,15 @@ __global__ void __launch_bounds__(512, 1) int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; - const P *ptrs[ngpus]; -#pragma unroll - for (int i = 0; i < ngpus; i++) { - int target = (rank + i) % ngpus; - ptrs[i] = (P *)_dp->ptrs[target]; - } + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; start_sync(sg, meta, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = packed_reduce(ptrs, idx); + ((P *)result)[idx] = + packed_reduce((const P **)&dp.ptrs[0], idx); } end_sync(sg, meta, rank); } @@ -319,7 +317,7 @@ __global__ void __launch_bounds__(512, 1) end_sync(sg, meta, rank); auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); - // do the actual reduction + // do the cross group reduction for (int idx = tid; idx < size; idx += stride) { auto tmp = tmp_out[idx]; packed_assign_add(tmp, src[idx]); diff --git a/csrc/fast_allreduce_test.cu b/csrc/fast_allreduce_test.cu index eabef83c8cf96..6161448148ea6 100644 --- a/csrc/fast_allreduce_test.cu +++ b/csrc/fast_allreduce_test.cu @@ -1,12 +1,16 @@ /** - * This is a standalone test for fast allreduce. + * This is a standalone test for fast allreduce. * To compile, make sure you have MPI and NCCL installed in your system. * export MPI_HOME=XXX - * nvcc -O2 -arch=native -std=c++17 fast_allreduce_test.cu -o fast_allreduce_test -lnccl -I${MPI_HOME}/include -lmpi - * + * nvcc -O2 -arch=native -std=c++17 fast_allreduce_test.cu -o + * fast_allreduce_test -lnccl -I${MPI_HOME}/include -lmpi + * + * Warning: this C++ test is not designed to be very readable and was used + * during the rapid prototyping process. + * * To run: * mpirun -np 8 ./fast_allreduce_test -*/ + */ #include #include #include @@ -99,6 +103,17 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t data_handles[8]; vllm::Metadata *buffer; T *self_data_copy; + /** + * Allocate IPC buffer + * + * The first section is a temporary buffer for storing intermediate allreduce + * results, if a particular algorithm requires it. The second section is for + * the input to the allreduce. The actual API takes the input pointer as an + * argument (that is, they can and usually should be allocated separately). + * But since the input pointers and the temporary buffer all require IPC + * registration, they are allocated and registered together in the test for + * convenience. + */ CUDACHECK( cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); CUDACHECK(cudaMemset(buffer, 0, @@ -133,12 +148,12 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *verification_buffer; - CUDACHECK(cudaMallocHost(&verification_buffer, data_size * sizeof(double))); + double *ground_truth; + CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); curandState_t *states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); - gen_data<<<108, 1024, 0, stream>>>(states, self_data, verification_buffer, + gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, nRanks, data_size); CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream)); @@ -214,8 +229,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, for (unsigned long j = 0; j < data_size; j++) { auto diff = abs(nccl_result[j] - my_result[j]); if (diff >= 1e-2) { - printf("Rank %d: Verification mismatch at %lld: %f != (my) %f\n", myRank, - j, nccl_result[j], my_result[j]); + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", myRank, + j, nccl_result[j], my_result[j], ground_truth[j]); break; } } @@ -223,8 +238,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, long double nccl_diffs = 0.0; long double my_diffs = 0.0; for (int j = 0; j < data_size; j++) { - nccl_diffs += abs(nccl_result[j] - verification_buffer[j]); - my_diffs += abs(my_result[j] - verification_buffer[j]); + nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + my_diffs += 
abs(my_result[j] - ground_truth[j]); } if (myRank == 0) std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size @@ -235,7 +250,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaFree(rank_data)); CUDACHECK(cudaFree(buffer)); CUDACHECK(cudaFree(states)); - CUDACHECK(cudaFreeHost(verification_buffer)); + CUDACHECK(cudaFreeHost(ground_truth)); CUDACHECK(cudaFreeHost(nccl_result)); CUDACHECK(cudaFreeHost(my_result)); CUDACHECK(cudaStreamDestroy(stream)); @@ -260,7 +275,7 @@ int main(int argc, char **argv) { // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); // } // } - for (int sz = 512; sz <= (4 << 20); sz *= 2) { + for (int sz = 512; sz <= (32 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 50); } From 8884de85b1b538c3550a3dcfd9a0c7ab1c3d37f6 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Fri, 19 Jan 2024 12:12:46 +0000 Subject: [PATCH 22/53] use untyped storage --- vllm/model_executor/parallel_utils/fast_allreduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/fast_allreduce.py index 79326150383d3..3691135c8b210 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/fast_allreduce.py @@ -118,7 +118,7 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.fast_cond = self.full_nvlink or world_size <= 2 def _get_ipc_meta(self, inp: torch.Tensor): - data = inp.storage()._share_cuda_() + data = inp.untyped_storage()._share_cuda_() shard_data = ( data[1], # ipc handle to base ptr data[3], # offset of base ptr From cce7c9828cbf9008f28642c87c7a811b33c7c298 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Fri, 19 Jan 2024 12:33:33 +0000 Subject: [PATCH 23/53] move test utils --- tests/distributed/test_comm_ops.py | 2 +- tests/distributed/test_fast_allreduce.py | 2 +- tests/distributed/comm_utils.py => vllm/test_utils.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename tests/distributed/comm_utils.py => vllm/test_utils.py (100%) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index c3bc285214663..e11f102328b7f 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -10,7 +10,7 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather, ) -from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel +from vllm.test_utils import init_test_distributed_environment, multi_process_tensor_parallel @ray.remote(num_gpus=1, max_calls=1) diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_fast_allreduce.py index 90e04bc2cab42..43935806b8742 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_fast_allreduce.py @@ -7,7 +7,7 @@ from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_reduce -from tests.distributed.comm_utils import init_test_distributed_environment, multi_process_tensor_parallel +from vllm.test_utils import init_test_distributed_environment, multi_process_tensor_parallel random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for i in range(8)] diff --git a/tests/distributed/comm_utils.py b/vllm/test_utils.py similarity index 100% rename from tests/distributed/comm_utils.py rename to vllm/test_utils.py From 
6b85e421bf4b98e500c4d39a3ef165251b9600f5 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 06:30:52 +0000 Subject: [PATCH 24/53] rename to custom all reduce --- ...fast_allreduce.cu => custom_all_reduce.cu} | 6 +-- ...st_allreduce.cuh => custom_all_reduce.cuh} | 4 +- ...duce_test.cu => custom_all_reduce_test.cu} | 10 ++--- csrc/ops.h | 2 +- csrc/pybind.cpp | 16 ++++---- setup.py | 2 +- ...allreduce.py => test_custom_all_reduce.py} | 8 ++-- vllm/config.py | 10 ++--- vllm/engine/arg_utils.py | 6 +-- vllm/engine/llm_engine.py | 2 +- vllm/entrypoints/llm.py | 6 +-- .../parallel_utils/communication_op.py | 16 ++++---- ...fast_allreduce.py => custom_all_reduce.py} | 38 +++++++++---------- vllm/worker/model_runner.py | 6 +-- 14 files changed, 66 insertions(+), 66 deletions(-) rename csrc/{fast_allreduce.cu => custom_all_reduce.cu} (94%) rename csrc/{fast_allreduce.cuh => custom_all_reduce.cuh} (99%) rename csrc/{fast_allreduce_test.cu => custom_all_reduce_test.cu} (97%) rename tests/distributed/{test_fast_allreduce.py => test_custom_all_reduce.py} (93%) rename vllm/model_executor/parallel_utils/{fast_allreduce.py => custom_all_reduce.py} (84%) diff --git a/csrc/fast_allreduce.cu b/csrc/custom_all_reduce.cu similarity index 94% rename from csrc/fast_allreduce.cu rename to csrc/custom_all_reduce.cu index 0c4168e706fbb..9e8387ef87c97 100644 --- a/csrc/fast_allreduce.cu +++ b/csrc/custom_all_reduce.cu @@ -3,13 +3,13 @@ #include #include -#include "fast_allreduce.cuh" +#include "custom_all_reduce.cuh" // fake pointer type using fptr_t = uint64_t; static_assert(sizeof(void *) == sizeof(fptr_t)); -fptr_t init_fast_ar(torch::Tensor &meta, torch::Tensor &rank_data, +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, const std::vector &handles, const std::vector &offsets, int rank, bool full_nvlink) { @@ -61,7 +61,7 @@ void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { #endif default: throw std::runtime_error( - "Fast allreduce only supports float32, float16 and bfloat16"); + "custom allreduce only supports float32, float16 and bfloat16"); } } diff --git a/csrc/fast_allreduce.cuh b/csrc/custom_all_reduce.cuh similarity index 99% rename from csrc/fast_allreduce.cuh rename to csrc/custom_all_reduce.cuh index f507ad52308b5..0602065ecdd71 100644 --- a/csrc/fast_allreduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -473,7 +473,7 @@ class FastAllreduce { auto d = packed_t::P::size; if (size % d != 0) throw std::runtime_error( - "fast allreduce currently requires input length to be multiple of " + + "custom allreduce currently requires input length to be multiple of " + std::to_string(d)); RankData *ptrs; @@ -522,7 +522,7 @@ class FastAllreduce { REDUCE_CASE(8) default: throw std::runtime_error( - "Fast allreduce only supports num gpus in (2,4,6,8). Actual num " + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " "gpus = " + std::to_string(world_size_)); } diff --git a/csrc/fast_allreduce_test.cu b/csrc/custom_all_reduce_test.cu similarity index 97% rename from csrc/fast_allreduce_test.cu rename to csrc/custom_all_reduce_test.cu index 6161448148ea6..ce4338ddd1e27 100644 --- a/csrc/fast_allreduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** - * This is a standalone test for fast allreduce. + * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. 
* export MPI_HOME=XXX - * nvcc -O2 -arch=native -std=c++17 fast_allreduce_test.cu -o - * fast_allreduce_test -lnccl -I${MPI_HOME}/include -lmpi + * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o + * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun -np 8 ./fast_allreduce_test + * mpirun -np 8 ./custom_all_reduce_test */ #include #include @@ -20,7 +20,7 @@ #include #include "cuda_profiler_api.h" -#include "fast_allreduce.cuh" +#include "custom_all_reduce.cuh" #include "mpi.h" #include "nccl.h" diff --git a/csrc/ops.h b/csrc/ops.h index 86a91e9ca0064..a966dc47ebd35 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,7 +92,7 @@ void gptq_shuffle( using fptr_t = uint64_t; -fptr_t init_fast_ar(torch::Tensor &meta, torch::Tensor &rank_data, +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, const std::vector &handles, const std::vector &offsets, int rank, bool full_nvlink); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index c96ed6bfaa50f..5d406c3f48933 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -82,14 +82,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &get_device_attribute, "Gets the specified device attribute."); - pybind11::module fast_ar = m.def_submodule("fast_ar", "fast allreduce"); - fast_ar.def("init_fast_ar", &init_fast_ar, "init_fast_ar"); - fast_ar.def("allreduce", &allreduce, "allreduce"); - fast_ar.def("dispose", &dispose, "dispose"); - fast_ar.def("meta_size", &meta_size, "meta_size"); - fast_ar.def("register_buffer", ®ister_buffer, "register_buffer"); - fast_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, + pybind11::module custom_ar = m.def_submodule("custom_ar", "fast allreduce"); + custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); + custom_ar.def("allreduce", &allreduce, "allreduce"); + custom_ar.def("dispose", &dispose, "dispose"); + custom_ar.def("meta_size", &meta_size, "meta_size"); + custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta"); - fast_ar.def("register_graph_buffers", ®ister_graph_buffers, + custom_ar.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); } diff --git a/setup.py b/setup.py index d6cbecda91041..74f049ab7d6b5 100644 --- a/setup.py +++ b/setup.py @@ -252,7 +252,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", - "csrc/fast_allreduce.cu", + "csrc/custom_all_reduce.cu", "csrc/pybind.cpp", ] diff --git a/tests/distributed/test_fast_allreduce.py b/tests/distributed/test_custom_all_reduce.py similarity index 93% rename from tests/distributed/test_fast_allreduce.py rename to tests/distributed/test_custom_all_reduce.py index 43935806b8742..3b61035fb4f47 100644 --- a/tests/distributed/test_fast_allreduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -5,7 +5,7 @@ import torch.distributed as dist import ray -from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar +from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_reduce from vllm.test_utils import init_test_distributed_environment, multi_process_tensor_parallel @@ -21,7 +21,7 @@ def 
graph_registration(world_size, rank, distributed_init_port): distributed_init_port) for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with fast_ar.capture(enable=True): + with custom_ar.capture(enable=True): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -49,8 +49,8 @@ def manual_registration(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) sz = 1024 - fast_ar.init_fast_ar() - fa = fast_ar.get_handle() + custom_ar.init_custom_ar() + fa = custom_ar.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) diff --git a/vllm/config.py b/vllm/config.py index 5f84fef1de3d3..6a2a41784aa9b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -325,7 +325,7 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. - disable_fast_allreduce: Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. Note that the custom kernel is only used with CUDA graph and never used in eager mode. """ @@ -336,13 +336,13 @@ def __init__( tensor_parallel_size: int, worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, - disable_fast_allreduce: bool = False, + disable_custom_all_reduce: bool = False, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers - self.disable_fast_allreduce = disable_fast_allreduce + self.disable_custom_all_reduce = disable_custom_all_reduce self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: self.worker_use_ray = True @@ -353,9 +353,9 @@ def _verify_args(self) -> None: raise NotImplementedError( "Pipeline parallelism is not supported yet.") if is_hip() or self.pipeline_parallel_size > 1: - self.disable_fast_allreduce = True + self.disable_custom_all_reduce = True logger.info( - "Fast allreduce automatically disabled. Not supported on HIP and pipeline parallel" + "custom allreduce automatically disabled. 
Not supported on HIP and pipeline parallel" ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index eefed87167953..4afe347f3b494 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,7 +35,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 - disable_fast_allreduce: bool = False + disable_custom_all_reduce: bool = False def __post_init__(self): if self.tokenizer is None: @@ -205,7 +205,7 @@ def add_cli_args( 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-fast-allreduce', action='store_true', - default=EngineArgs.disable_fast_allreduce, + default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') return parser @@ -235,7 +235,7 @@ def create_engine_configs( self.tensor_parallel_size, self.worker_use_ray, self.max_parallel_loading_workers, - self.disable_fast_allreduce) + self.disable_custom_all_reduce) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3fc8e7d4c3b74..a93d59eb243bd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -80,7 +80,7 @@ def __init__( f"download_dir={model_config.download_dir!r}, " f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " - f"disable_fast_allreduce={parallel_config.disable_fast_allreduce}, " + f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed})") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5e39d0b888059..e553adc48bcf0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -63,7 +63,7 @@ class LLM: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. - disable_fast_allreduce: See ParallelConfig + disable_custom_all_reduce: See ParallelConfig """ def __init__( @@ -82,7 +82,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_fast_allreduce: bool = False, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -102,7 +102,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, - disable_fast_allreduce=disable_fast_allreduce, + disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index eee95e193a7fa..0742d2d97fa53 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -5,7 +5,7 @@ get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, ) -from vllm.model_executor.parallel_utils import fast_allreduce as fast_ar +from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar def tensor_model_parallel_all_reduce(input_: torch.Tensor): @@ -16,17 +16,17 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor): # Bypass the function if we are using only 1 GPU. 
if get_tensor_model_parallel_world_size() == 1: return input_ - # fast allreduce only works with IPC pre-registered buffer. + # custom allreduce only works with IPC pre-registered buffer. # This is only handled when captured with cuda graph - if fast_ar.is_capturing(): - fa_handle = fast_ar.get_handle() + if custom_ar.is_capturing(): + ca_handle = custom_ar.get_handle() if torch.cuda.is_current_stream_capturing(): - if fa_handle.should_fast_ar(input_): - return fa_handle.all_reduce(input_) + if ca_handle.should_custom_ar(input_): + return ca_handle.all_reduce(input_) else: - if fa_handle.should_fast_ar(input_): + if ca_handle.should_custom_ar(input_): # if warm up, mimic the allocation pattern - # since fast allreduce is out-of-place + # since custom allreduce is out-of-place return torch.empty_like(input_) torch.distributed.all_reduce(input_, diff --git a/vllm/model_executor/parallel_utils/fast_allreduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py similarity index 84% rename from vllm/model_executor/parallel_utils/fast_allreduce.py rename to vllm/model_executor/parallel_utils/custom_all_reduce.py index 3691135c8b210..a6835bf6e0238 100644 --- a/vllm/model_executor/parallel_utils/fast_allreduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -4,25 +4,25 @@ import torch.distributed as dist from typing import Optional -from vllm._C import fast_ar +from vllm._C import custom_ar from vllm.logger import init_logger from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) logger = init_logger(__name__) -_FA_HANDLE = None +_ca_handle = None _IS_CAPTURING = False -def init_fast_ar() -> None: - global _FA_HANDLE - if _FA_HANDLE is not None: +def init_custom_ar() -> None: + global _ca_handle + if _ca_handle is not None: return rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() if world_size > 1 and _can_p2p(rank, world_size): - _FA_HANDLE = FastAllreduce(rank, world_size) + _ca_handle = FastAllreduce(rank, world_size) def begin_capture() -> None: @@ -36,17 +36,17 @@ def end_capture() -> None: def is_capturing() -> bool: - return _IS_CAPTURING and _FA_HANDLE is not None + return _IS_CAPTURING and _ca_handle is not None def get_handle() -> Optional["FastAllreduce"]: - return _FA_HANDLE + return _ca_handle @contextmanager def capture(enable: bool): if enable: - init_fast_ar() + init_custom_ar() try: begin_capture() yield @@ -87,12 +87,12 @@ def _can_p2p(rank, world_size): if p2p_status != pynvml.NVML_P2P_STATUS_OK: logger.info( f"P2P is not supported between device {i} and {rank}. " - "Fast allreduce will be disabled") + "custom allreduce will be disabled") return False except pynvml.NVMLError as error: logger.info( f"P2P detection failed with message \"{str(error)}\". 
" - "Fast allreduce will be disabled") + "custom allreduce will be disabled") return False pynvml.nvmlShutdown() return True @@ -103,7 +103,7 @@ class FastAllreduce: # max_size: max supported allreduce size def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: # buffers memory are owned by this Python class and passed to C++ - self.meta = torch.zeros(fast_ar.meta_size() + max_size, + self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, device="cuda") self.rank_data = torch.empty(16 * 1024 * 1024, @@ -113,7 +113,7 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = _is_full_nvlink(rank, world_size) - self._ptr = fast_ar.init_fast_ar(self.meta, self.rank_data, handles, + self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 @@ -138,15 +138,15 @@ def _gather_ipc_meta(self, shard_data): def register_buffer(self, inp: torch.Tensor): handles, offsets = self._get_ipc_meta(inp) - fast_ar.register_buffer(self._ptr, inp, handles, offsets) + custom_ar.register_buffer(self._ptr, inp, handles, offsets) def register_graph_buffers(self): - handle, offset = fast_ar.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) - fast_ar.register_graph_buffers(self._ptr, handles, offsets) + custom_ar.register_graph_buffers(self._ptr, handles, offsets) - def should_fast_ar(self, inp: torch.Tensor): + def should_custom_ar(self, inp: torch.Tensor): inp_size = inp.numel() * torch.finfo(inp.dtype).bits // 8 if self.fast_cond: return inp_size <= self.max_size @@ -157,12 +157,12 @@ def should_fast_ar(self, inp: torch.Tensor): def all_reduce(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - fast_ar.allreduce(self._ptr, inp, out) + custom_ar.allreduce(self._ptr, inp, out) return out def close(self): if self._ptr: - fast_ar.dispose(self._ptr) + custom_ar.dispose(self._ptr) self._ptr = 0 def __del__(self): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 990b547255d26..270765aca8ccb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,7 +12,7 @@ broadcast, broadcast_object_list) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.model_executor.parallel_utils import fast_allreduce +from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.utils import in_wsl logger = init_logger(__name__) @@ -606,8 +606,8 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. - with fast_allreduce.capture( - enable=not self.parallel_config.disable_fast_allreduce): + with custom_all_reduce.capture( + enable=not self.parallel_config.disable_custom_all_reduce): for batch_size in reversed(batch_size_capture_list): # Create dummy input_metadata. 
input_metadata = InputMetadata( From 6ae050e90f21ad06c2c55f94c8dadfcaf22a8ce5 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 08:13:56 +0000 Subject: [PATCH 25/53] add support for eager mode --- csrc/custom_all_reduce.cu | 56 +++++++++++++++---- csrc/ops.h | 4 +- csrc/pybind.cpp | 5 +- tests/distributed/test_custom_all_reduce.py | 20 ++++--- .../parallel_utils/communication_op.py | 42 +++++++++----- .../parallel_utils/custom_all_reduce.py | 29 +++++++--- vllm/worker/model_runner.py | 3 +- vllm/worker/worker.py | 4 +- 8 files changed, 115 insertions(+), 48 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 9e8387ef87c97..4141303a47358 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -10,9 +10,9 @@ using fptr_t = uint64_t; static_assert(sizeof(void *) == sizeof(fptr_t)); fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink) { + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -33,29 +33,39 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } -void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +// Make sure tensor t's data lies completely within +// ((char)t.data_ptr()) + t.numel() * t.element_size() +// This is slightly weaker than t.is_contiguous() because it allows transposes. +// Currently, we need this information because stride information is not passed +// into the kernels +bool _is_weak_contiguous(torch::Tensor &t) { + return t.is_contiguous() || + t.storage().nbytes() == t.numel() * t.element_size(); +} + +void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, + cudaStream_t stream) { auto fa = reinterpret_cast(_fa); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - switch (inp.scalar_type()) { + TORCH_CHECK(_is_weak_contiguous(inp)); + TORCH_CHECK(_is_weak_contiguous(out)); + switch (out.scalar_type()) { case at::ScalarType::Float: { fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), - inp.numel()); + out.numel()); break; } case at::ScalarType::Half: { fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), - inp.numel()); + out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), inp.numel()); + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -65,6 +75,30 @@ void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { } } +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = 
c10::cuda::getCurrentCUDAStream().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + TORCH_CHECK(_is_weak_contiguous(inp)); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, cudaMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + void dispose(fptr_t _fa) { auto fa = reinterpret_cast(_fa); delete fa; diff --git a/csrc/ops.h b/csrc/ops.h index a966dc47ebd35..c7fb2b770657c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -96,7 +96,9 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, const std::vector &handles, const std::vector &offsets, int rank, bool full_nvlink); -void allreduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out); void dispose(fptr_t _fa); int meta_size(); void register_buffer(fptr_t _fa, torch::Tensor &t, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 5d406c3f48933..853065c175c15 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -82,9 +82,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &get_device_attribute, "Gets the specified device attribute."); - pybind11::module custom_ar = m.def_submodule("custom_ar", "fast allreduce"); + pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); - custom_ar.def("allreduce", &allreduce, "allreduce"); + custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); + custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); custom_ar.def("dispose", &dispose, "dispose"); custom_ar.def("meta_size", &meta_size, "meta_size"); custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3b61035fb4f47..e3bc89c030b42 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -16,12 +16,13 @@ @ray.remote(num_gpus=1, max_calls=1) -def graph_registration(world_size, rank, distributed_init_port): +def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) + custom_ar.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_ar.capture(enable=True): + with custom_ar.capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -35,9 +36,9 @@ def graph_registration(world_size, rank, distributed_init_port): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): out1 = tensor_model_parallel_all_reduce(inp1) - out2 = tensor_model_parallel_all_reduce(inp2) # the input buffer is immediately modified to test synchronization dist.all_reduce(inp1) + out2 = tensor_model_parallel_all_reduce(inp2) dist.all_reduce(inp2) graph.replay() assert torch.allclose(out1, inp1) @@ -45,7 +46,7 @@ def graph_registration(world_size, rank, distributed_init_port): @ray.remote(num_gpus=1, max_calls=1) -def manual_registration(world_size, rank, distributed_init_port): +def eager_allreduce(world_size, rank, 
distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) sz = 1024 @@ -54,8 +55,13 @@ def manual_registration(world_size, rank, distributed_init_port): inp = torch.ones(sz, dtype=torch.float32, device=torch.cuda.current_device()) - fa.register_buffer(inp) - out = fa.all_reduce(inp) + out = fa.all_reduce_unreg(inp) + assert torch.allclose(out, inp * world_size) + + inp = torch.ones(sz * 4, + dtype=torch.bfloat16, + device=torch.cuda.current_device()) + out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) @@ -63,6 +69,6 @@ def manual_registration(world_size, rank, distributed_init_port): reason="Need at least 4 GPUs to run the test.") @pytest.mark.parametrize("tensor_parallel_size", [2, 4]) @pytest.mark.parametrize("test_target", - [manual_registration, graph_registration]) + [eager_allreduce, graph_allreduce]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 0742d2d97fa53..bb30057692c72 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,4 +1,5 @@ import torch +from typing import Optional from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, @@ -8,27 +9,38 @@ from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -def tensor_model_parallel_all_reduce(input_: torch.Tensor): - """All-reduce the input tensor across model parallel group. - - NOTE: This operation is applied in-place on the input tensor. - """ - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size() == 1: - return input_ - # custom allreduce only works with IPC pre-registered buffer. - # This is only handled when captured with cuda graph +def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: + ca_handle = custom_ar.get_handle() + # when custom allreduce is disabled, this will be None + if ca_handle is None: + return if custom_ar.is_capturing(): - ca_handle = custom_ar.get_handle() if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input_): - return ca_handle.all_reduce(input_) + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_reg(input) else: - if ca_handle.should_custom_ar(input_): + if ca_handle.should_custom_ar(input): # if warm up, mimic the allocation pattern # since custom allreduce is out-of-place - return torch.empty_like(input_) + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # fast allreduce incurs a cost of cudaMemcpy + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_unreg(input) + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group. + + NOTE: This operation may be applied in-place on the input tensor. + """ + # Bypass the function if we are using only 1 GPU. 
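A caller-side sketch of the contract noted in the docstring above: because the reduction may or may not happen in place, callers should always rebind to the returned tensor. This snippet is illustrative only (it assumes the tensor-parallel group is already initialized) and is not part of the patch.

import torch
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)

partial = torch.randn(1024, device="cuda")   # per-rank partial result
# Keep the return value; do not assume `partial` still holds usable data.
partial = tensor_model_parallel_all_reduce(partial)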
+ if get_tensor_model_parallel_world_size() == 1: + return input_ + out = custom_all_reduce(input_) + if out is not None: + return out torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) return input_ diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index a6835bf6e0238..caf7b5b83cbf7 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -44,16 +44,15 @@ def get_handle() -> Optional["FastAllreduce"]: @contextmanager -def capture(enable: bool): - if enable: - init_custom_ar() +def capture(): try: begin_capture() yield finally: end_capture() - if enable and get_handle() is not None: - get_handle().register_graph_buffers() + handle = get_handle() + if handle is not None: + handle.register_graph_buffers() # query if the set of gpus are fully connected by nvlink (1 hop) @@ -106,7 +105,10 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, device="cuda") - self.rank_data = torch.empty(16 * 1024 * 1024, + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device="cuda") + self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda") self.max_size = max_size @@ -116,6 +118,7 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 + self.register_buffer(self.buffer) def _get_ipc_meta(self, inp: torch.Tensor): data = inp.untyped_storage()._share_cuda_() @@ -147,17 +150,25 @@ def register_graph_buffers(self): custom_ar.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - inp_size = inp.numel() * torch.finfo(inp.dtype).bits // 8 + inp_size = inp.numel() * inp.element_size() if self.fast_cond: return inp_size <= self.max_size # 4 pcie gpus use 2 stage AR, and is only faster than NCCL # when size <= 512k return self.world_size <= 4 and inp_size <= 512 * 1024 - def all_reduce(self, inp: torch.Tensor, out: torch.Tensor = None): + # all reduce, assuming inp tensor is IPC registered with register_buffer, or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.allreduce(self._ptr, inp, out) + custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def close(self): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 270765aca8ccb..fc43a1849934e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -606,8 +606,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. - with custom_all_reduce.capture( - enable=not self.parallel_config.disable_custom_all_reduce): + with custom_all_reduce.capture(): for batch_size in reversed(batch_size_capture_list): # Create dummy input_metadata. 
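A minimal sketch of the graph-capture flow wired up above, mirroring the graph allreduce test in this patch. It assumes the distributed / tensor-parallel state is initialized and init_custom_ar() has been called on every rank; the size and dtype are illustrative.

import torch
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)

inp = torch.ones(4096, dtype=torch.float16, device="cuda")
with custom_all_reduce.capture():
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = tensor_model_parallel_all_reduce(inp)
# Graph-captured buffers are IPC-registered when the context manager exits,
# so replay only after leaving the capture() block.
graph.replay()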
input_metadata = InputMetadata( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c2a2ac148085b..da2c1ed97d1c5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -10,6 +10,7 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.communication_op import ( broadcast_object_list) +from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -73,7 +74,8 @@ def init_model(self) -> None: # Initialize the distributed environment. _init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method) - + if not self.parallel_config.disable_custom_all_reduce: + init_custom_ar() # Initialize the model. set_random_seed(self.model_config.seed) From 73ab0a83035956750793d25a66c95cca2469e8cf Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 08:14:34 +0000 Subject: [PATCH 26/53] format --- tests/distributed/test_custom_all_reduce.py | 5 ++--- vllm/model_executor/parallel_utils/custom_all_reduce.py | 9 ++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index e3bc89c030b42..db1180dd2e9d2 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -57,7 +57,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): device=torch.cuda.current_device()) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) - + inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=torch.cuda.current_device()) @@ -68,7 +68,6 @@ def eager_allreduce(world_size, rank, distributed_init_port): @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") @pytest.mark.parametrize("tensor_parallel_size", [2, 4]) -@pytest.mark.parametrize("test_target", - [eager_allreduce, graph_allreduce]) +@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index caf7b5b83cbf7..0de3f96cc28c2 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -105,9 +105,7 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, device="cuda") - self.buffer = torch.empty(max_size, - dtype=torch.uint8, - device="cuda") + self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda") @@ -115,8 +113,9 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = _is_full_nvlink(rank, world_size) - self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, handles, - offsets, rank, self.full_nvlink) + self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, + handles, offsets, rank, + self.full_nvlink) self.fast_cond = self.full_nvlink or world_size <= 2 self.register_buffer(self.buffer) From 
bedf60ed41d2551e0e2df9732410ea2e2254ae1a Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 08:25:12 +0000 Subject: [PATCH 27/53] add comment --- vllm/model_executor/parallel_utils/communication_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 4c11cd5044b8f..bd79d72923d8b 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -2,7 +2,6 @@ from typing import Any, Dict, List, Optional, Union import torch -from typing import Optional from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, @@ -28,7 +27,9 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: return torch.empty_like(input) else: # note: outside of cuda graph context, - # fast allreduce incurs a cost of cudaMemcpy + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels if ca_handle.should_custom_ar(input): return ca_handle.all_reduce_unreg(input) From 8626b8c304f9cce13dc5e57c5a03d0a960921879 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 09:14:34 +0000 Subject: [PATCH 28/53] move function --- .../parallel_utils/communication_op.py | 25 +------------------ .../parallel_utils/custom_all_reduce.py | 23 +++++++++++++++++ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index bd79d72923d8b..85ca6926746ca 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -8,30 +8,7 @@ get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, ) -from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar - - -def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: - ca_handle = custom_ar.get_handle() - # when custom allreduce is disabled, this will be None - if ca_handle is None: - return - if custom_ar.is_capturing(): - if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_reg(input) - else: - if ca_handle.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) - else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_unreg(input) +from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 0de3f96cc28c2..db5ccdaf997b3 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -53,6 +53,29 @@ def capture(): handle = get_handle() if handle is not None: handle.register_graph_buffers() + + +def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: + ca_handle = get_handle() + # when custom allreduce is disabled, 
this will be None + if ca_handle is None: + return + if is_capturing(): + if torch.cuda.is_current_stream_capturing(): + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_reg(input) + else: + if ca_handle.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_unreg(input) # query if the set of gpus are fully connected by nvlink (1 hop) From 7da072365d7d613284a6f15504dcdcf8de6100b2 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 09:15:11 +0000 Subject: [PATCH 29/53] format --- vllm/model_executor/parallel_utils/custom_all_reduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index db5ccdaf997b3..c93dd9e8a07d7 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -53,7 +53,7 @@ def capture(): handle = get_handle() if handle is not None: handle.register_graph_buffers() - + def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ca_handle = get_handle() From 2ab52d0bed97ad4ed5a0a3fbdb1cf220f3715489 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 10:14:52 +0000 Subject: [PATCH 30/53] add comments --- csrc/custom_all_reduce.cu | 24 ++++++++++++++----- .../parallel_utils/communication_op.py | 9 ++++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 4141303a47358..82127c332715d 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -33,14 +33,26 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } -// Make sure tensor t's data lies completely within -// ((char)t.data_ptr()) + t.numel() * t.element_size() -// This is slightly weaker than t.is_contiguous() because it allows transposes. -// Currently, we need this information because stride information is not passed -// into the kernels +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. 
A[:, 1:, 1:]: Not OK + */ bool _is_weak_contiguous(torch::Tensor &t) { return t.is_contiguous() || - t.storage().nbytes() == t.numel() * t.element_size(); + t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size(); } void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 85ca6926746ca..2c6514b18de7c 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -14,7 +14,14 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group. - NOTE: This operation may be applied in-place on the input tensor. + NOTE: This operation will be applied in-place on the input tensor if + disable_custom_all_reduce is set to True. Otherwise, this operation may or + may not be applied in place depending on whether custom all reduce is + invoked for a particular tensor, which further depends on the tensor size + and GPU topology. + + TLDR: always assume this function modifies its input, but use the return + value as the output. """ # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: From dcf27354d46fe6d728d4fbb8f90e7c79bde2758b Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Sat, 20 Jan 2024 10:40:13 +0000 Subject: [PATCH 31/53] fix name --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4afe347f3b494..e70e3b18c0d51 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -203,7 +203,7 @@ def add_cli_args( help='maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') - parser.add_argument('--disable-fast-allreduce', + parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, help='See ParallelConfig') From 21f2fccf24b439dba428a3460da55c1d072420a2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jan 2024 05:19:06 +0000 Subject: [PATCH 32/53] Minor fixes on comments --- vllm/config.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6a2a41784aa9b..fcd64798dd592 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -325,9 +325,8 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. - disable_custom_all_reduce: Disable the custom all-reduce kernel and fall back to NCCL. - Note that the custom kernel is only used with CUDA graph and never used in eager - mode. + disable_custom_all_reduce: Disable the custom all-reduce kernel and + fall back to NCCL. 
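For reference, a hedged sketch of how this option reaches users through the API surface changed earlier in this series; the model name is a placeholder.

from vllm import LLM

llm = LLM(model="facebook/opt-13b",        # placeholder model
          tensor_parallel_size=2,
          disable_custom_all_reduce=True)  # fall back to NCCL all-reduce
# The same switch is exposed on the command line as --disable-custom-all-reduce.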
""" def __init__( @@ -343,6 +342,7 @@ def __init__( self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce + self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: self.worker_use_ray = True @@ -352,11 +352,16 @@ def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") - if is_hip() or self.pipeline_parallel_size > 1: + if is_hip(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on AMD GPUs.") + elif self.pipeline_parallel_size > 1: self.disable_custom_all_reduce = True logger.info( - "custom allreduce automatically disabled. Not supported on HIP and pipeline parallel" - ) + "Disabled the custom all-reduce kernel because it is not " + "supported with pipeline parallelism.") class SchedulerConfig: From 50cc5f8d34bbd3c4f861e4e45a29aa42ad9e0caa Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jan 2024 22:04:25 +0000 Subject: [PATCH 33/53] Don't compile for ROCm backend --- csrc/pybind.cpp | 7 +++++-- setup.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 853065c175c15..fdd092f49a8b2 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -82,6 +82,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &get_device_attribute, "Gets the specified device attribute."); +#ifndef USE_ROCM + // Custom all-reduce kernels pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); @@ -90,7 +92,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { custom_ar.def("meta_size", &meta_size, "meta_size"); custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, - "get_graph_buffer_ipc_meta"); + "get_graph_buffer_ipc_meta"); custom_ar.def("register_graph_buffers", ®ister_graph_buffers, - "register_graph_buffers"); + "register_graph_buffers"); +#endif } diff --git a/setup.py b/setup.py index 74f049ab7d6b5..794941fadcaa9 100644 --- a/setup.py +++ b/setup.py @@ -252,12 +252,12 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", - "csrc/custom_all_reduce.cu", "csrc/pybind.cpp", ] if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + vllm_extension_sources.append("csrc/custom_all_reduce.cu") if not _is_neuron(): vllm_extension = CUDAExtension( @@ -267,7 +267,6 @@ def get_torch_arch_list() -> Set[str]: "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, - libraries=["cuda"], ) ext_modules.append(vllm_extension) From 0581f4e3a2abdc40f2278149ee07053055ce57ce Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jan 2024 23:46:29 +0000 Subject: [PATCH 34/53] Minor fix for ROCm backend --- .../parallel_utils/custom_all_reduce.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index c93dd9e8a07d7..d7aa578584a12 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -1,14 +1,21 @@ 
from contextlib import contextmanager -import pynvml +from typing import Optional + import torch import torch.distributed as dist -from typing import Optional -from vllm._C import custom_ar from vllm.logger import init_logger from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) +try: + from vllm._C import custom_ar + import pynvml +except ImportError: + # For AMD GPUs + custom_ar = None + pynvml = None + logger = init_logger(__name__) _ca_handle = None From 704416a0a58267d51e4feb16eb01bfea1108b5b2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jan 2024 23:50:51 +0000 Subject: [PATCH 35/53] Use context manager for NVML --- .../parallel_utils/custom_all_reduce.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index d7aa578584a12..2719e4ccc2316 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -85,9 +85,18 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: return ca_handle.all_reduce_unreg(input) +@contextmanager +def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + + # query if the set of gpus are fully connected by nvlink (1 hop) +@_nvml() def _is_full_nvlink(rank, world_size): - pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(rank) for i in range(world_size): if i != rank: @@ -100,12 +109,11 @@ def _is_full_nvlink(rank, world_size): f"NVLink detection failed with message \"{str(error)}\". " "This is normal if your machine has no NVLink equipped") return False - pynvml.nvmlShutdown() return True +@_nvml() def _can_p2p(rank, world_size): - pynvml.nvmlInit() handle1 = pynvml.nvmlDeviceGetHandleByIndex(rank) for i in range(world_size): if i != rank: @@ -123,7 +131,6 @@ def _can_p2p(rank, world_size): f"P2P detection failed with message \"{str(error)}\". 
" "custom allreduce will be disabled") return False - pynvml.nvmlShutdown() return True From 602930a6a24dff0e7c72c13cad546fac3454b1db Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jan 2024 23:54:01 +0000 Subject: [PATCH 36/53] Minor fix for long comment --- vllm/model_executor/parallel_utils/custom_all_reduce.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 2719e4ccc2316..4937e7d73ffa8 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -193,7 +193,8 @@ def should_custom_ar(self, inp: torch.Tensor): # when size <= 512k return self.world_size <= 4 and inp_size <= 512 * 1024 - # all reduce, assuming inp tensor is IPC registered with register_buffer, or, in the context of cuda graphs, register_graph_buffers + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) From 3bfd8fae0beaea922da91ee54cdf33683be79071 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jan 2024 00:43:41 +0000 Subject: [PATCH 37/53] Minor --- csrc/custom_all_reduce.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 0602065ecdd71..43dc52de708c0 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -324,6 +324,7 @@ __global__ void __launch_bounds__(512, 1) ((P *)result)[idx] = tmp; } } + class FastAllreduce { public: int rank_; From d8f92bc371bd43a6dd4caa09ee4c7bd32534d8c3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jan 2024 07:08:46 +0000 Subject: [PATCH 38/53] Add library=cuda --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 794941fadcaa9..a034b479c3542 100644 --- a/setup.py +++ b/setup.py @@ -267,6 +267,7 @@ def get_torch_arch_list() -> Set[str]: "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, + libraries=["cuda"] if _is_cuda() else [], ) ext_modules.append(vllm_extension) From b4711a1fb8af471ec2510573a0840a64713708da Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jan 2024 08:37:38 +0000 Subject: [PATCH 39/53] Skip ops for ROCm backend --- csrc/ops.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/ops.h b/csrc/ops.h index c7fb2b770657c..8e9b6351d0dc9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -91,6 +91,7 @@ void gptq_shuffle( torch::Tensor q_perm); +#ifndef USE_ROCM using fptr_t = uint64_t; fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, const std::vector &handles, @@ -107,3 +108,4 @@ void register_buffer(fptr_t _fa, torch::Tensor &t, std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector &handles, const std::vector> &offsets); +#endif From 84ee01994bea4e2b8516cf0eca87dc76aaeac742 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jan 2024 08:37:51 +0000 Subject: [PATCH 40/53] Minor --- csrc/custom_all_reduce.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82127c332715d..5bee86ec44edd 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -51,8 +51,8 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, */ bool 
_is_weak_contiguous(torch::Tensor &t) { return t.is_contiguous() || - t.storage().nbytes() - t.storage_offset() * t.element_size() == - t.numel() * t.element_size(); + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); } void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, From 8fbb2aa955ba5f3a2cb4e5e359b67a2b1adb0c54 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jan 2024 08:58:47 +0000 Subject: [PATCH 41/53] Fix can_p2p --- .../parallel_utils/custom_all_reduce.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 4937e7d73ffa8..31fc3f571c0f0 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -112,25 +112,12 @@ def _is_full_nvlink(rank, world_size): return True -@_nvml() -def _can_p2p(rank, world_size): - handle1 = pynvml.nvmlDeviceGetHandleByIndex(rank) +def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): - if i != rank: - handle2 = pynvml.nvmlDeviceGetHandleByIndex(rank) - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle1, handle2, pynvml.NVML_P2P_CAPS_INDEX_READ) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: - logger.info( - f"P2P is not supported between device {i} and {rank}. " - "custom allreduce will be disabled") - return False - except pynvml.NVMLError as error: - logger.info( - f"P2P detection failed with message \"{str(error)}\". " - "custom allreduce will be disabled") - return False + if i == rank: + continue + if not torch.cuda.can_device_access_peer(rank, i): + return False return True From c5b421259501d552318ffceb3351f55976f03eb4 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 23 Jan 2024 19:08:53 +0800 Subject: [PATCH 42/53] Apply suggestions from code review Co-authored-by: Woosuk Kwon --- tests/distributed/test_comm_ops.py | 2 +- vllm/model_executor/parallel_utils/custom_all_reduce.py | 2 +- vllm/test_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index b923a2e050c0f..d04d78a1d3589 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -88,4 +88,4 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, broadcast_tensor_dict_test_worker ]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) \ No newline at end of file + multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 31fc3f571c0f0..ad6f6655dfa26 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -28,7 +28,7 @@ def init_custom_ar() -> None: return rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() - if world_size > 1 and _can_p2p(rank, world_size): + if world_size in [2, 4, 6, 8] and _can_p2p(rank, world_size): _ca_handle = FastAllreduce(rank, world_size) diff --git a/vllm/test_utils.py b/vllm/test_utils.py index ad13205f95cb0..eb18f5033b00c 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -29,4 +29,4 @@ def 
multi_process_tensor_parallel(tensor_parallel_size, test_target): distributed_init_port)) ray.get(refs) - ray.shutdown() \ No newline at end of file + ray.shutdown() From 60e013aef2fc101009002b12ee7a3bbfbc053143 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 23 Jan 2024 11:30:59 +0000 Subject: [PATCH 43/53] add notes --- csrc/custom_all_reduce.cuh | 11 +++++++++-- .../parallel_utils/custom_all_reduce.py | 10 ++++++++++ vllm/worker/model_runner.py | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 43dc52de708c0..5c93f29810b5f 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -467,14 +467,21 @@ class FastAllreduce { graph_unreg_buffers_.clear(); } - // note: 512, 36 is good for most cases + /** + * This is the result after careful grid search. Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. + */ template void allreduce(cudaStream_t stream, T *input, T *output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) throw std::runtime_error( - "custom allreduce currently requires input length to be multiple of " + + "custom allreduce currently requires input length to be multiple " + "of " + std::to_string(d)); RankData *ptrs; diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index ad6f6655dfa26..35be4e35558fd 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -126,10 +126,20 @@ class FastAllreduce: # max_size: max supported allreduce size def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, device="cuda") + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. 
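A quick back-of-the-envelope check of the sizing stated in the comments above; the numbers only restate the comment's assumptions.

bytes_per_entry = 8 * 8                    # one 8-byte pointer per rank, at most 8 ranks
rank_data_bytes = 8 * 1024 * 1024          # the 8 MiB rank_data buffer
print(rank_data_bytes // bytes_per_entry)  # 131072 registerable tuples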
self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5ba08722c7509..442b196bac8db 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,9 +10,9 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) +from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.utils import in_wsl logger = init_logger(__name__) From 10a906e9325e6f238a703135491248be7b868c27 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 23 Jan 2024 11:38:14 +0000 Subject: [PATCH 44/53] add size check --- vllm/model_executor/parallel_utils/custom_all_reduce.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 35be4e35558fd..35f795cfeb13e 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -184,6 +184,9 @@ def register_graph_buffers(self): def should_custom_ar(self, inp: torch.Tensor): inp_size = inp.numel() * inp.element_size() + # custom allreduce currently input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False if self.fast_cond: return inp_size <= self.max_size # 4 pcie gpus use 2 stage AR, and is only faster than NCCL From b896fbd270df1f23107ca3febfad1bd66e1d3db5 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 23 Jan 2024 19:45:25 +0800 Subject: [PATCH 45/53] Update csrc/custom_all_reduce.cuh Co-authored-by: Woosuk Kwon --- csrc/custom_all_reduce.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 5c93f29810b5f..9465dd7720593 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -1,3 +1,5 @@ +#pragma once + #include #include #include From aae30fb94d163a1095e8613ab3464fe247762eb4 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Tue, 23 Jan 2024 19:46:05 +0800 Subject: [PATCH 46/53] Apply suggestions from code review Co-authored-by: Woosuk Kwon --- csrc/custom_all_reduce.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 9465dd7720593..c0af9b2a7ed27 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -141,7 +141,7 @@ __host__ __device__ constexpr uint64_t compute_flag(int ngpus) { } template -__device__ __forceinline__ void start_sync(const RankSignals &sg, +DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta, int rank) { constexpr auto FLAG = compute_flag(ngpus); if (blockIdx.x == 0) { @@ -161,7 +161,7 @@ __device__ __forceinline__ void start_sync(const RankSignals &sg, } template -__device__ __forceinline__ void end_sync(const RankSignals &sg, +DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta, int rank) { constexpr auto FLAG = compute_flag(ngpus); __syncthreads(); From 627a49f15ff4763af67a22beef5fb9c275e3c70a Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 23 Jan 2024 11:47:57 +0000 Subject: [PATCH 47/53] grammar --- vllm/model_executor/parallel_utils/custom_all_reduce.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 35f795cfeb13e..77cd76352198e 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -184,7 +184,7 @@ def register_graph_buffers(self): def should_custom_ar(self, inp: torch.Tensor): inp_size = inp.numel() * inp.element_size() - # custom allreduce currently input byte size to be multiples of 16 + # custom allreduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: return False if self.fast_cond: From c7e3704d48fed07ec85f2e94de87fddabce4dab6 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Tue, 23 Jan 2024 12:44:40 +0000 Subject: [PATCH 48/53] move test to c++ --- csrc/custom_all_reduce.cu | 14 ++++++++++++-- .../parallel_utils/custom_all_reduce.py | 13 +++---------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 5bee86ec44edd..0810e5a120900 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,10 +55,21 @@ bool _is_weak_contiguous(torch::Tensor &t) { t.numel() * t.element_size()); } +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, + bool full_nvlink) { + auto inp_size = inp.numel() * inp.element_size(); + // custom allreduce requires input byte size to be multiples of 16 + if (inp_size % 16 != 0) return false; + if (!_is_weak_contiguous(inp)) return false; + if (world_size == 2 || full_nvlink) return inp_size <= max_size; + // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size + // <= 512k + return world_size <= 4 && inp_size <= 512 * 1024; +} + void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); - TORCH_CHECK(_is_weak_contiguous(inp)); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { @@ -105,7 +116,6 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, TORCH_CHECK_EQ(inp.numel(), out.numel()); TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), "registered buffer is too small to contain the input"); - TORCH_CHECK(_is_weak_contiguous(inp)); AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream)); _all_reduce(_fa, reg_buffer, out, stream); diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 77cd76352198e..cdc8617b64b3e 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -139,7 +139,7 @@ def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. + # needs less than 10000 of registered tuples. 
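# Pure-Python rendering of the eligibility check that this patch moves into
# csrc/custom_all_reduce.cu (an illustrative sketch of the same logic, not the
# code vLLM actually runs; the weak-contiguity test on the tensor is omitted):
def _should_custom_ar_sketch(inp_bytes: int, max_size: int, world_size: int,
                             full_nvlink: bool) -> bool:
    if inp_bytes % 16 != 0:        # kernels use 16-byte packed loads/stores
        return False
    if world_size == 2 or full_nvlink:
        return inp_bytes <= max_size
    # 4 PCIe GPUs use the 2-stage allreduce, only faster than NCCL <= 512KB
    return world_size <= 4 and inp_bytes <= 512 * 1024

assert _should_custom_ar_sketch(8192, 8192 * 1024, 2, False)
assert not _should_custom_ar_sketch(8200, 8192 * 1024, 2, False)        # not a multiple of 16
assert not _should_custom_ar_sketch(1024 * 1024, 8192 * 1024, 4, False) # PCIe, > 512KB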
self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda") @@ -183,15 +183,8 @@ def register_graph_buffers(self): custom_ar.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - inp_size = inp.numel() * inp.element_size() - # custom allreduce requires input byte size to be multiples of 16 - if inp_size % 16 != 0: - return False - if self.fast_cond: - return inp_size <= self.max_size - # 4 pcie gpus use 2 stage AR, and is only faster than NCCL - # when size <= 512k - return self.world_size <= 4 and inp_size <= 512 * 1024 + return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers From 036bb685581dbedcd6cf7fc845fe87268c7d1191 Mon Sep 17 00:00:00 2001 From: hanzhizhou Date: Wed, 24 Jan 2024 00:03:33 +0000 Subject: [PATCH 49/53] add warnings and do few renames --- csrc/custom_all_reduce.cu | 12 +++---- csrc/custom_all_reduce.cuh | 28 ++++++++++------- csrc/custom_all_reduce_test.cu | 12 +++---- csrc/ops.h | 2 ++ csrc/pybind.cpp | 1 + .../parallel_utils/custom_all_reduce.py | 31 +++++++++++++------ 6 files changed, 53 insertions(+), 33 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 0810e5a120900..88e4af9d4a99f 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -28,7 +28,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, for (int i = 0; i < world_size; i++) { std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } - return (fptr_t) new vllm::FastAllreduce( + return (fptr_t) new vllm::CustomAllreduce( reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -69,7 +69,7 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { @@ -122,7 +122,7 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } @@ -131,18 +131,18 @@ int meta_size() { return sizeof(vllm::Metadata); } void register_buffer(fptr_t _fa, torch::Tensor &t, const std::vector &handles, const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } void register_graph_buffers(fptr_t _fa, const std::vector &handles, const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index c0af9b2a7ed27..6e71bb9a9c6e8 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -141,8 +141,8 @@ __host__ __device__ constexpr uint64_t compute_flag(int ngpus) { } template -DINLINE void start_sync(const RankSignals &sg, - volatile Metadata *meta, int rank) { +DINLINE void start_sync(const RankSignals &sg, 
volatile Metadata *meta, + int rank) { constexpr auto FLAG = compute_flag(ngpus); if (blockIdx.x == 0) { if (threadIdx.x < ngpus) @@ -161,8 +161,8 @@ DINLINE void start_sync(const RankSignals &sg, } template -DINLINE void end_sync(const RankSignals &sg, - volatile Metadata *meta, int rank) { +DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta, + int rank) { constexpr auto FLAG = compute_flag(ngpus); __syncthreads(); __shared__ int num; @@ -327,7 +327,7 @@ __global__ void __launch_bounds__(512, 1) } } -class FastAllreduce { +class CustomAllreduce { public: int rank_; int world_size_; @@ -352,10 +352,10 @@ class FastAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - FastAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, - bool full_nvlink = true) + CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t *handles, + const std::vector &offsets, int rank, + bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), @@ -540,12 +540,16 @@ class FastAllreduce { #undef KL } - ~FastAllreduce() { + ~CustomAllreduce() { for (auto ptr : ipc_handles_) { CUDACHECK(cudaIpcCloseMemHandle(ptr)); } } }; -template void FastAllreduce::allreduce(cudaStream_t, half *, half *, int, - int, int); +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void CustomAllreduce::allreduce(cudaStream_t, half *, half *, + int, int, int); +*/ } // namespace vllm diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index ce4338ddd1e27..6b094e2fdc9ba 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -129,8 +129,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); - vllm::FastAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, - myRank); + vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, + offsets, myRank); auto *self_data = reinterpret_cast(reinterpret_cast(buffer) + sizeof(vllm::Metadata) + data_size * sizeof(T)); @@ -153,8 +153,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, curandState_t *states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); - gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, - myRank, nRanks, data_size); + gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, + nRanks, data_size); CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream)); cudaEvent_t start, stop; @@ -229,8 +229,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, for (unsigned long j = 0; j < data_size; j++) { auto diff = abs(nccl_result[j] - my_result[j]); if (diff >= 1e-2) { - printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", myRank, - j, nccl_result[j], my_result[j], ground_truth[j]); + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); break; } } diff --git a/csrc/ops.h b/csrc/ops.h index 
8e9b6351d0dc9..fb96330c8a21e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -97,6 +97,8 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, const std::vector &handles, const std::vector &offsets, int rank, bool full_nvlink); +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, + bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, torch::Tensor &out); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index fdd092f49a8b2..22be0b49f6328 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -86,6 +86,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Custom all-reduce kernels pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); + custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); custom_ar.def("dispose", &dispose, "dispose"); diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index cdc8617b64b3e..5b88649cc2129 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -18,18 +18,31 @@ logger = init_logger(__name__) -_ca_handle = None +_CA_HANDLE = None _IS_CAPTURING = False +_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] def init_custom_ar() -> None: - global _ca_handle - if _ca_handle is not None: + global _CA_HANDLE + if _CA_HANDLE is not None: return rank = get_tensor_model_parallel_rank() world_size = get_tensor_model_parallel_world_size() - if world_size in [2, 4, 6, 8] and _can_p2p(rank, world_size): - _ca_handle = FastAllreduce(rank, world_size) + if world_size not in _SUPPORTED_WORLD_SIZES: + logger.warn( + "Custom allreduce is disabled due to an unsupported world size: " + "%d. Supported world sizes: %s. To slience this warning, specify" + "disable_custom_all_reduce=True explicitly.", world_size, + str(_SUPPORTED_WORLD_SIZES)) + return + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability. 
To slience this warning, specify" + "disable_custom_all_reduce=True explicitly.") + return + _CA_HANDLE = CustomAllreduce(rank, world_size) def begin_capture() -> None: @@ -43,11 +56,11 @@ def end_capture() -> None: def is_capturing() -> bool: - return _IS_CAPTURING and _ca_handle is not None + return _IS_CAPTURING and _CA_HANDLE is not None -def get_handle() -> Optional["FastAllreduce"]: - return _ca_handle +def get_handle() -> Optional["CustomAllreduce"]: + return _CA_HANDLE @contextmanager @@ -121,7 +134,7 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True -class FastAllreduce: +class CustomAllreduce: # max_size: max supported allreduce size def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: From e1e802ea525da578ee73d25ee05fd327b05a6953 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jan 2024 07:27:36 +0000 Subject: [PATCH 50/53] Fix custom all reduce tests --- tests/distributed/test_custom_all_reduce.py | 40 +++++++++++++-------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index db1180dd2e9d2..ed4965593c2f0 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -1,24 +1,31 @@ import random +import os import pytest +import ray import torch import torch.distributed as dist -import ray from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_reduce -from vllm.test_utils import init_test_distributed_environment, multi_process_tensor_parallel +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.test_utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) -test_sizes = [random.randint(1024, 2048 * 1024) for i in range(8)] +test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] for i, v in enumerate(test_sizes): test_sizes[i] -= v % 8 @ray.remote(num_gpus=1, max_calls=1) def graph_allreduce(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, world_size, rank, distributed_init_port) + custom_ar.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: @@ -36,7 +43,8 @@ def graph_allreduce(world_size, rank, distributed_init_port): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): out1 = tensor_model_parallel_all_reduce(inp1) - # the input buffer is immediately modified to test synchronization + # the input buffer is immediately modified to test + # synchronization dist.all_reduce(inp1) out2 = tensor_model_parallel_all_reduce(inp2) dist.all_reduce(inp2) @@ -47,27 +55,31 @@ def graph_allreduce(world_size, rank, distributed_init_port): @ray.remote(num_gpus=1, max_calls=1) def eager_allreduce(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) init_test_distributed_environment(1, world_size, rank, distributed_init_port) + sz = 1024 custom_ar.init_custom_ar() fa = custom_ar.get_handle() - inp = torch.ones(sz, - dtype=torch.float32, - device=torch.cuda.current_device()) + inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, 
inp * world_size) - inp = torch.ones(sz * 4, - dtype=torch.bfloat16, - device=torch.cuda.current_device()) + inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2, 4]) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("tensor_parallel_size", [2]) @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): multi_process_tensor_parallel(tensor_parallel_size, test_target) + + +if __name__ == "__main__": + multi_process_tensor_parallel(2, graph_allreduce) From bbfc263950c5fbab98ae8ba4efdd1891bd7be76d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jan 2024 07:30:33 +0000 Subject: [PATCH 51/53] Move test_utils to tests/distributed/utils --- tests/distributed/test_custom_all_reduce.py | 4 ++-- .../distributed/utils.py | 20 ++++++++++++------- vllm/worker/worker.py | 6 +++--- 3 files changed, 18 insertions(+), 12 deletions(-) rename vllm/test_utils.py => tests/distributed/utils.py (61%) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index ed4965593c2f0..ed483812cd72e 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -9,8 +9,8 @@ from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from tests.distributed.utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] diff --git a/vllm/test_utils.py b/tests/distributed/utils.py similarity index 61% rename from vllm/test_utils.py rename to tests/distributed/utils.py index eb18f5033b00c..4f74c05038e70 100644 --- a/vllm/test_utils.py +++ b/tests/distributed/utils.py @@ -2,21 +2,27 @@ from vllm.config import ParallelConfig from vllm.utils import get_open_port -from vllm.worker.worker import _init_distributed_environment +from vllm.worker.worker import init_distributed_environment -def init_test_distributed_environment(pipeline_parallel_size: int, - tensor_parallel_size: int, rank: int, - distributed_init_port: str): +def init_test_distributed_environment( + pipeline_parallel_size: int, + tensor_parallel_size: int, + rank: int, + distributed_init_port: str, +) -> None: parallel_config = ParallelConfig(pipeline_parallel_size, tensor_parallel_size, worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - _init_distributed_environment(parallel_config, rank, - distributed_init_method) + init_distributed_environment(parallel_config, rank, + distributed_init_method) -def multi_process_tensor_parallel(tensor_parallel_size, test_target): +def multi_process_tensor_parallel( + tensor_parallel_size: int, + test_target, +) -> None: # Using ray helps debugging the error when it failed # as compared to multiprocessing. 
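# Arithmetic behind the `inp * world_size` assertions in the test targets above
# (a CPU-only sketch, no process group or GPU needed): all-reducing identical
# tensors of ones across world_size ranks is an elementwise sum over the ranks.
import torch
world_size = 2                            # the assumed 2-GPU test configuration
inp = torch.ones(1024, dtype=torch.float32)
summed = sum(inp.clone() for _ in range(world_size))
assert torch.allclose(summed, inp * world_size)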
ray.init() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4e47a674e4be6..f1dad64b2b27a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -79,8 +79,8 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) if not self.parallel_config.disable_custom_all_reduce: init_custom_ar() # Initialize the model. @@ -221,7 +221,7 @@ def list_loras(self) -> Set[int]: return self.model_runner.list_loras() -def _init_distributed_environment( +def init_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, From 6f613476695e2e34fe6dcf5393d9c8e8bcb343c2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jan 2024 09:28:08 +0000 Subject: [PATCH 52/53] Minor --- setup.py | 4 ++-- tests/distributed/test_comm_ops.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 99d98aedcc2c6..2f6242690a263 100644 --- a/setup.py +++ b/setup.py @@ -51,8 +51,8 @@ def _is_cuda() -> bool: "Cannot find ROCM_HOME. ROCm must be available to build the package." ) NVCC_FLAGS += ["-DUSE_ROCM"] - NVCC_FLAGS += [f"-U__HIP_NO_HALF_CONVERSIONS__"] - NVCC_FLAGS += [f"-U__HIP_NO_HALF_OPERATORS__"] + NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"] + NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"] if _is_cuda() and CUDA_HOME is None: raise RuntimeError( diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d04d78a1d3589..00d2fb4838e99 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,7 +11,8 @@ tensor_model_parallel_all_gather, broadcast_tensor_dict, ) -from vllm.test_utils import init_test_distributed_environment, multi_process_tensor_parallel +from tests.distributed.utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) From c09772c298360b6abdca51738ff76e2919388c0d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jan 2024 18:25:42 +0000 Subject: [PATCH 53/53] Roll back to test_utils --- tests/distributed/test_comm_ops.py | 4 ++-- tests/distributed/test_custom_all_reduce.py | 4 ++-- tests/distributed/utils.py => vllm/test_utils.py | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename tests/distributed/utils.py => vllm/test_utils.py (100%) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 00d2fb4838e99..9474cb21599d4 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,8 +11,8 @@ tensor_model_parallel_all_gather, broadcast_tensor_dict, ) -from tests.distributed.utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from vllm.test_utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index ed483812cd72e..ed4965593c2f0 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -9,8 +9,8 @@ from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar from vllm.model_executor.parallel_utils.communication_op import ( 
tensor_model_parallel_all_reduce) -from tests.distributed.utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from vllm.test_utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] diff --git a/tests/distributed/utils.py b/vllm/test_utils.py similarity index 100% rename from tests/distributed/utils.py rename to vllm/test_utils.py
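The Python surface this series ends up with can be exercised roughly as follows.
This is a minimal sketch based only on the calls shown in the patches above; it
assumes the tensor-parallel process group is already initialized on each rank, a
supported world size (2, 4, 6 or 8) with working GPU P2P, and an input whose
byte size is a multiple of 16.

import torch
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar

custom_ar.init_custom_ar()            # no-op (with a warning) when unsupported
fa = custom_ar.get_handle()
if fa is not None:
    x = torch.ones(4096, dtype=torch.float16, device="cuda")  # 8KB, multiple of 16
    if fa.should_custom_ar(x):
        # eager path: the input is staged into the pre-registered IPC buffer
        # and reduced with the custom kernel
        y = fa.all_reduce_unreg(x)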