@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
 
 fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,
-                      bool full_nvlink) {
+                      bool fully_connected) {
   int world_size = fake_ipc_ptrs.size();
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
@@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
   }
   return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
                                             rank_data.numel(), rank, world_size,
-                                            full_nvlink);
+                                            fully_connected);
 }
 
 /**
@@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
   bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }
+
+std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size) {
+  auto device_index = c10::cuda::current_device();
+  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
+  void* buffer;
+  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Allocate buffer
+#if defined(USE_ROCM)
+  // data buffers need to be "uncached" for signal on MI200
+  AT_CUDA_CHECK(
+      hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
+#else
+  AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
+#endif
+  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
+  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Create IPC memhandle for the allocated buffer.
+  // Will use it in open_mem_handle.
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handle =
+      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
+  AT_CUDA_CHECK(
+      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
+
+  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
+}
+
+fptr_t open_mem_handle(torch::Tensor& mem_handle) {
+  void* ipc_ptr;
+  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
+      cudaIpcMemLazyEnablePeerAccess));
+  return reinterpret_cast<fptr_t>(ipc_ptr);
+}
+
+void free_shared_buffer(fptr_t buffer) {
+  AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
+}
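
Taken together, the three new helpers split IPC buffer setup into separate allocate / exchange / open steps. The sketch below shows one way a caller on each rank might compose them. It is only an illustration, not code from this change: it assumes fptr_t is the int64_t pointer-handle alias used earlier in this file, that the declarations above are visible, and that all_gather_handles is a hypothetical stand-in for whatever collective (e.g. torch.distributed all_gather on the Python side) actually moves the handle tensors between ranks.

#include <tuple>
#include <vector>
#include <torch/torch.h>

// Hypothetical stub: gathers every rank's IPC handle tensor (world_size entries).
std::vector<torch::Tensor> all_gather_handles(const torch::Tensor& local_handle);

fptr_t setup_rank_buffer(int64_t rank, int64_t world_size, int64_t size) {
  // 1. Each rank allocates its shared buffer and exports an IPC handle for it.
  auto [local_ptr, local_handle] = allocate_shared_buffer_and_handle(size);

  // 2. Exchange handles across ranks (assumed out-of-band collective, not shown).
  std::vector<torch::Tensor> handles = all_gather_handles(local_handle);

  // 3. Open every peer's handle; keep the locally allocated pointer for this rank.
  std::vector<fptr_t> ipc_ptrs(world_size);
  for (int64_t r = 0; r < world_size; ++r) {
    ipc_ptrs[r] = (r == rank) ? local_ptr : open_mem_handle(handles[r]);
  }

  // ipc_ptrs can now be passed to init_custom_ar(...); when tearing down,
  // release the local allocation with free_shared_buffer(local_ptr).
  return local_ptr;
}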