@@ -135,8 +135,7 @@ def all_reduce(self,
     def all_gather(self,
                    output_tensor: torch.Tensor,
                    input_tensor: torch.Tensor,
-                   stream=None,
-                   sizes: Optional[list[int]] = None):
+                   stream=None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -147,37 +146,51 @@ def all_gather(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        if sizes is not None:
-            assert output_tensor.shape[0] == sum(sizes)
-            split_offset = 0
-            self.nccl.ncclGroupStart()
-            for root, split_size in enumerate(sizes):
-                dst_slice = output_tensor[split_offset:split_offset +
-                                          split_size]
-                self.nccl.ncclBroadcast(
-                    buffer_type(input_tensor.data_ptr()),
-                    buffer_type(dst_slice.data_ptr()),
-                    dst_slice.numel(),
-                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                    root,
-                    self.comm,
-                    cudaStream_t(stream.cuda_stream),
-                )
-                split_offset += split_size
-            self.nccl.ncclGroupEnd()
-        else:
-            self.nccl.ncclAllGather(
+        self.nccl.ncclAllGather(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
+            cudaStream_t(stream.cuda_stream))
+
+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
                 buffer_type(input_tensor.data_ptr()),
-                buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
-                ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
-                cudaStream_t(stream.cuda_stream))
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
 
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
                        op: ReduceOp = ReduceOp.SUM,
-                       stream=None,
-                       sizes: Optional[list[int]] = None):
+                       stream=None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
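
A minimal usage sketch of the resulting API split on the gather side (illustrative only, not part of the diff): `comm`, `rank`, and `world_size` are assumed stand-ins for the communicator instance and its process-group info. `all_gather` keeps the fixed-size path via `ncclAllGather`, while `all_gatherv` takes explicit per-rank `sizes` and emulates a variable-size gather with `ncclBroadcast` calls inside an NCCL group.

```python
import torch

# Assumed to exist: comm (communicator instance from this file),
# rank and world_size (ints for the current process group).

# Uniform case: every rank contributes the same number of rows.
inp = torch.randn(4, 8, device=comm.device)
out = torch.empty(4 * world_size, 8, device=comm.device)
comm.all_gather(out, inp)

# Ragged case: rank r contributes sizes[r] rows; output holds sum(sizes) rows.
sizes = list(range(1, world_size + 1))  # hypothetical per-rank row counts
inp_v = torch.randn(sizes[rank], 8, device=comm.device)
out_v = torch.empty(sum(sizes), 8, device=comm.device)
comm.all_gatherv(out_v, inp_v, sizes=sizes)
```
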
@@ -188,29 +201,44 @@ def reduce_scatter(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
+        self.nccl.ncclReduceScatter(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op), self.comm,
+            cudaStream_t(stream.cuda_stream))
 
-        if sizes is not None:
-            split_offset = 0
-            self.nccl.ncclGroupStart()
-            for root, split_size in enumerate(sizes):
-                chunk = input_tensor[split_offset:split_offset + split_size,
-                                     ...]
+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
 
-                self.nccl.ncclReduce(
-                    buffer_type(chunk.data_ptr()),
-                    buffer_type(output_tensor.data_ptr()), chunk.numel(),
-                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                    ncclRedOpTypeEnum.from_torch(op), root, self.comm,
-                    cudaStream_t(stream.cuda_stream))
-                split_offset += split_size
-            self.nccl.ncclGroupEnd()
-        else:
-            self.nccl.ncclReduceScatter(
-                buffer_type(input_tensor.data_ptr()),
-                buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
                 ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
                 cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
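
Similarly, a minimal sketch of the reduce-scatter side under the same assumptions (`comm`, `rank`, `world_size` are illustrative stand-ins): `reduce_scatter` maps directly to `ncclReduceScatter` over equal chunks, while `reduce_scatterv` reduces chunk `r` of every rank's input onto rank `r` via grouped `ncclReduce` calls, so the input holds `sum(sizes)` rows on every rank and the output holds `sizes[rank]` rows.

```python
import torch
from torch.distributed import ReduceOp

# Assumed to exist: comm (communicator instance from this file),
# rank and world_size (ints for the current process group).

# Uniform case: input is world_size equal chunks along dim 0.
inp = torch.randn(4 * world_size, 8, device=comm.device)
out = torch.empty(4, 8, device=comm.device)
comm.reduce_scatter(out, inp, op=ReduceOp.SUM)

# Ragged case: chunk r has sizes[r] rows; rank r receives the reduced chunk r.
sizes = list(range(1, world_size + 1))  # hypothetical per-rank row counts
inp_v = torch.randn(sum(sizes), 8, device=comm.device)
out_v = torch.empty(sizes[rank], 8, device=comm.device)
comm.reduce_scatterv(out_v, inp_v, sizes=sizes, op=ReduceOp.SUM)
```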