Commit 37e67c7

ilmarkov authored and committed
[Distributed] Add custom allreduce support for ROCM (vllm-project#14125)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 84aed20 commit 37e67c7

File tree

13 files changed (+373, -160 lines)


CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
+  "csrc/custom_all_reduce.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -283,7 +284,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/quantization/aqlm/gemm_kernels.cu"
   "csrc/quantization/awq/gemm_kernels.cu"
-  "csrc/custom_all_reduce.cu"
   "csrc/permute_cols.cu"
   "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
   "csrc/quantization/fp4/nvfp4_quant_entry.cu"

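Note on the build change above: "csrc/custom_all_reduce.cu" moves out of the CUDA-only source block and into the shared VLLM_EXT_SRC list, so the custom allreduce extension is now compiled for ROCm builds as well. The file stays single-source: vLLM's ROCm build hipifies the common CUDA runtime calls, and anything backend-specific is guarded by USE_ROCM, as in this minimal standalone sketch (illustrative only, not code from the commit):

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#else
  #include <cuda_runtime.h>
#endif
#include <cstdio>

int main() {
  void* buf = nullptr;
#ifdef USE_ROCM
  // ROCm path: MI200-class GPUs want uncached allocations so peer GPUs observe
  // signal writes promptly (the same flag the new allocation helper uses).
  hipExtMallocWithFlags(&buf, 1 << 20, hipDeviceMallocUncached);
#else
  // CUDA path: a plain device allocation is sufficient.
  cudaMalloc(&buf, 1 << 20);
#endif
  std::printf("device buffer at %p\n", buf);
#ifdef USE_ROCM
  hipFree(buf);
#else
  cudaFree(buf);
#endif
  return 0;
}
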
csrc/custom_all_reduce.cu

Lines changed: 47 additions & 2 deletions

@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
 
 fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,
-                      bool full_nvlink) {
+                      bool fully_connected) {
   int world_size = fake_ipc_ptrs.size();
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
@@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
   }
   return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
                                             rank_data.numel(), rank, world_size,
-                                            full_nvlink);
+                                            fully_connected);
 }
 
 /**
@@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
   bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }
+
+std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size) {
+  auto device_index = c10::cuda::current_device();
+  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
+  void* buffer;
+  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Allocate buffer
+#if defined(USE_ROCM)
+  // data buffers need to be "uncached" for signal on MI200
+  AT_CUDA_CHECK(
+      hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
+#else
+  AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
+#endif
+  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
+  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Create IPC memhandle for the allocated buffer.
+  // Will use it in open_mem_handle.
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handle =
+      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
+  AT_CUDA_CHECK(
+      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
+
+  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
+}
+
+fptr_t open_mem_handle(torch::Tensor& mem_handle) {
+  void* ipc_ptr;
+  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
+      cudaIpcMemLazyEnablePeerAccess));
+  return reinterpret_cast<fptr_t>(ipc_ptr);
+}
+
+void free_shared_buffer(fptr_t buffer) {
+  AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
+}
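
For orientation, the three new ops above are the building blocks a caller uses to wire up the IPC buffers before constructing the allreduce context, alongside the renamed init_custom_ar flag (full_nvlink -> fully_connected, presumably so the same flag can also describe AMD xGMI topologies). The sketch below is not from the commit; in vLLM the orchestration lives in the Python layer, so the handle exchange is abstracted here as a caller-supplied callback:

#include <torch/torch.h>

#include <cstdint>
#include <functional>
#include <tuple>
#include <vector>

using fptr_t = int64_t;  // consistent with static_assert(sizeof(void*) == sizeof(fptr_t))

// Ops defined in csrc/custom_all_reduce.cu, re-declared for this sketch.
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(int64_t size);
fptr_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(fptr_t buffer);
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected);

// all_gather_handles is a hypothetical stand-in for the out-of-band handle
// exchange (e.g. over torch.distributed); it returns one handle tensor per rank.
fptr_t setup_custom_allreduce(
    int64_t rank, int64_t world_size, int64_t buffer_size,
    torch::Tensor& rank_data, bool fully_connected,
    const std::function<std::vector<torch::Tensor>(const torch::Tensor&)>&
        all_gather_handles) {
  // 1. Each rank allocates its own buffer (uncached on ROCm/MI200) and gets an
  //    IPC handle describing it.
  auto [local_ptr, local_handle] = allocate_shared_buffer_and_handle(buffer_size);

  // 2. Exchange handles, then map every peer's buffer into this process; the
  //    local rank keeps its own pointer.
  auto handles = all_gather_handles(local_handle);
  std::vector<fptr_t> ipc_ptrs(world_size);
  for (int64_t r = 0; r < world_size; ++r) {
    ipc_ptrs[r] = (r == rank) ? local_ptr : open_mem_handle(handles[r]);
  }

  // 3. Build the CustomAllreduce context from the mapped pointers.
  return init_custom_ar(ipc_ptrs, rank_data, rank, fully_connected);
  // On teardown, the local allocation is released with free_shared_buffer(local_ptr).
}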
