 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
+import numpy as np
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
@@ -135,7 +136,8 @@ def all_reduce(self,
     def all_gather(self,
                    output_tensor: torch.Tensor,
                    input_tensor: torch.Tensor,
-                   stream=None):
+                   stream=None,
+                   sizes: Optional[List[int]] = None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -146,17 +148,38 @@ def all_gather(self,
146148 f"but the input tensor is on { input_tensor .device } " )
147149 if stream is None :
148150 stream = current_stream ()
149- self .nccl .ncclAllGather (
150- buffer_type (input_tensor .data_ptr ()),
151- buffer_type (output_tensor .data_ptr ()), input_tensor .numel (),
152- ncclDataTypeEnum .from_torch (input_tensor .dtype ), self .comm ,
153- cudaStream_t (stream .cuda_stream ))
151+ if sizes is not None :
152+ assert output_tensor .shape [0 ] == sum (sizes )
153+ numel_base = int (np .prod (output_tensor .shape [1 :]))
154+ split_offset = 0
155+ self .nccl .ncclGroupStart ()
156+ for root , split_size in enumerate (sizes ):
157+ dst_slice = output_tensor [split_offset :split_offset +
158+ split_size ]
159+ self .nccl .ncclBroadcast (
160+ buffer_type (input_tensor .data_ptr ()),
161+ buffer_type (dst_slice .data_ptr ()),
162+ split_size * numel_base ,
163+ ncclDataTypeEnum .from_torch (input_tensor .dtype ),
164+ root ,
165+ self .comm ,
166+ cudaStream_t (stream .cuda_stream ),
167+ )
168+ split_offset += split_size
169+ self .nccl .ncclGroupEnd ()
170+ else :
171+ self .nccl .ncclAllGather (
172+ buffer_type (input_tensor .data_ptr ()),
173+ buffer_type (output_tensor .data_ptr ()), input_tensor .numel (),
174+ ncclDataTypeEnum .from_torch (input_tensor .dtype ), self .comm ,
175+ cudaStream_t (stream .cuda_stream ))
154176
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
                        op: ReduceOp = ReduceOp.SUM,
-                       stream=None):
+                       stream=None,
+                       sizes: Optional[List[int]] = None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
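The sizes branch added to all_gather above emulates a variable-size all-gather, which NCCL does not provide natively: each rank's contribution becomes one ncclBroadcast with that rank as root, and the per-root calls are fused inside ncclGroupStart/ncclGroupEnd so NCCL can launch them together. A minimal usage sketch, assuming this method lands on vLLM's PyNcclCommunicator and a hypothetical two-rank launch (the import path, constructor arguments, and demo script name are assumptions, not part of this diff):

# assumed launch: torchrun --nproc-per-node=2 allgather_v_demo.py (one GPU per rank)
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

dist.init_process_group(backend="gloo")      # CPU group, used only to bootstrap NCCL
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
comm = PyNcclCommunicator(dist.group.WORLD, device=device)

sizes = [3, 5]                               # rows contributed by rank 0 and rank 1
inp = torch.full((sizes[rank], 8), float(rank + 1), device=device)
out = torch.empty(sum(sizes), 8, device=device)
comm.all_gather(out, inp, sizes=sizes)       # out[:3] is 1.0 (rank 0), out[3:] is 2.0 (rank 1)
torch.cuda.synchronize()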
@@ -167,12 +190,29 @@ def reduce_scatter(self,
167190 f"but the input tensor is on { input_tensor .device } " )
168191 if stream is None :
169192 stream = current_stream ()
170- self .nccl .ncclReduceScatter (
171- buffer_type (input_tensor .data_ptr ()),
172- buffer_type (output_tensor .data_ptr ()), output_tensor .numel (),
173- ncclDataTypeEnum .from_torch (input_tensor .dtype ),
174- ncclRedOpTypeEnum .from_torch (op ), self .comm ,
175- cudaStream_t (stream .cuda_stream ))
193+
194+ if sizes is not None :
195+ numel_base = int (np .prod (input_tensor .shape [1 :]))
196+ split_offset = 0
197+ self .nccl .ncclGroupStart ()
198+ for root , split_size in enumerate (sizes ):
199+ chunk = input_tensor [split_offset :split_offset + split_size , :]
200+ self .nccl .ncclReduce (
201+ buffer_type (chunk .data_ptr ()),
202+ buffer_type (output_tensor .data_ptr ()),
203+ split_size * numel_base ,
204+ ncclDataTypeEnum .from_torch (input_tensor .dtype ),
205+ ncclRedOpTypeEnum .from_torch (op ), root , self .comm ,
206+ cudaStream_t (stream .cuda_stream ))
207+ split_offset += split_size
208+ self .nccl .ncclGroupEnd ()
209+ else :
210+ self .nccl .ncclReduceScatter (
211+ buffer_type (input_tensor .data_ptr ()),
212+ buffer_type (output_tensor .data_ptr ()), output_tensor .numel (),
213+ ncclDataTypeEnum .from_torch (input_tensor .dtype ),
214+ ncclRedOpTypeEnum .from_torch (op ), self .comm ,
215+ cudaStream_t (stream .cuda_stream ))
176216
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
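The matching sizes branch in reduce_scatter replaces the single ncclReduceScatter with one ncclReduce per root inside a NCCL group, so each rank ends up with the reduction of the row slice whose length equals its entry in sizes. Continuing the hypothetical two-rank sketch above (comm, rank, and device as before):

sizes = [3, 5]                               # rank 0 keeps 3 rows, rank 1 keeps 5
inp = torch.ones(sum(sizes), 8, device=device)
out = torch.empty(sizes[rank], 8, device=device)
comm.reduce_scatter(out, inp, sizes=sizes)   # every element of out becomes 2.0 (sum over 2 ranks)
torch.cuda.synchronize()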
@@ -216,3 +256,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
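The new group_start()/group_end() helpers simply expose ncclGroupStart/ncclGroupEnd, letting callers fuse several communicator calls into one NCCL group. A hypothetical sketch, reusing comm, rank, and device from above and assuming the class's existing send/recv methods; batching a send and a receive this way lets NCCL schedule both operations together instead of relying on call order:

peer = 1 - rank                              # the other rank in the two-rank sketch
send_buf = torch.full((4,), float(rank), device=device)
recv_buf = torch.empty(4, device=device)
comm.group_start()
comm.send(send_buf, peer)                    # enqueue P2P ops inside one NCCL group
comm.recv(recv_buf, peer)
comm.group_end()                             # NCCL issues both ops when the group closes
torch.cuda.synchronize()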