
Commit 0343099

Separate pynccl functions
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
1 parent 0dfbe60 commit 0343099

2 files changed (+82, -48 lines)

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 8 additions & 2 deletions
@@ -152,7 +152,10 @@ def reduce_scatterv(self,
                              dtype=input_tensor.dtype,
                              device=input_tensor.device)
 
-        pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
+        if sizes is not None:
+            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+        else:
+            pynccl_comm.reduce_scatter(output, input_)
 
         # Reshape before returning
         return output.movedim(0, dim).contiguous()

@@ -222,7 +225,10 @@ def _all_gather_single(input_: torch.Tensor,
             output_tensor = torch.empty(output_size,
                                         dtype=input_.dtype,
                                         device=input_.device)
-            pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+            if sizes is not None:
+                pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
+            else:
+                pynccl_comm.all_gather(output_tensor, input_)
             return output_tensor
 
         if isinstance(input_, torch.Tensor):
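For reference, the caller-side contract exercised above: `sizes` lists each rank's chunk length along dim 0, the "v" entry points are only used when `sizes` is provided, and the gathered output has sum(sizes) rows. Below is a minimal CPU-only sketch of that dispatch, with a hypothetical stub standing in for `pynccl_comm` (the stub class, tensor shapes, and helper name are illustrative, not part of the change):

import torch

# Hypothetical stand-in for pynccl_comm so the dispatch can run on CPU.
# Only the output layout matters: all_gatherv places rank r's sizes[r] rows
# at offset sum(sizes[:r]) in the output, all_gather assumes equal chunks.
class FakeComm:
    def all_gather(self, output_tensor, input_):
        world_size = output_tensor.shape[0] // input_.shape[0]
        torch.cat([input_] * world_size, out=output_tensor)

    def all_gatherv(self, output_tensor, input_, sizes):
        offset = 0
        for rank, size in enumerate(sizes):
            output_tensor[offset:offset + size] = float(rank)  # rank r's rows
            offset += size

def gather(comm, input_, world_size, sizes=None):
    rows = sum(sizes) if sizes is not None else world_size * input_.shape[0]
    output = torch.empty(rows, input_.shape[1], dtype=input_.dtype)
    # Same dispatch shape as the patched _all_gather_single above.
    if sizes is not None:
        comm.all_gatherv(output, input_, sizes=sizes)
    else:
        comm.all_gather(output, input_)
    return output

print(gather(FakeComm(), torch.ones(2, 4), world_size=4).shape)                      # torch.Size([8, 4])
print(gather(FakeComm(), torch.ones(2, 4), world_size=4, sizes=[3, 1, 2, 2]).shape)  # torch.Size([8, 4])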

vllm/distributed/device_communicators/pynccl.py

Lines changed: 74 additions & 46 deletions
@@ -135,8 +135,7 @@ def all_reduce(self,
     def all_gather(self,
                    output_tensor: torch.Tensor,
                    input_tensor: torch.Tensor,
-                   stream=None,
-                   sizes: Optional[list[int]] = None):
+                   stream=None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -147,37 +146,51 @@ def all_gather(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        if sizes is not None:
-            assert output_tensor.shape[0] == sum(sizes)
-            split_offset = 0
-            self.nccl.ncclGroupStart()
-            for root, split_size in enumerate(sizes):
-                dst_slice = output_tensor[split_offset:split_offset +
-                                          split_size]
-                self.nccl.ncclBroadcast(
-                    buffer_type(input_tensor.data_ptr()),
-                    buffer_type(dst_slice.data_ptr()),
-                    dst_slice.numel(),
-                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                    root,
-                    self.comm,
-                    cudaStream_t(stream.cuda_stream),
-                )
-                split_offset += split_size
-            self.nccl.ncclGroupEnd()
-        else:
-            self.nccl.ncclAllGather(
+        self.nccl.ncclAllGather(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
+            cudaStream_t(stream.cuda_stream))
+
+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
                 buffer_type(input_tensor.data_ptr()),
-                buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
-                ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
-                cudaStream_t(stream.cuda_stream))
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
 
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
                        op: ReduceOp = ReduceOp.SUM,
-                       stream=None,
-                       sizes: Optional[list[int]] = None):
+                       stream=None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -188,29 +201,44 @@ def reduce_scatter(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
+        self.nccl.ncclReduceScatter(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op), self.comm,
+            cudaStream_t(stream.cuda_stream))
 
-        if sizes is not None:
-            split_offset = 0
-            self.nccl.ncclGroupStart()
-            for root, split_size in enumerate(sizes):
-                chunk = input_tensor[split_offset:split_offset + split_size,
-                                     ...]
+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
 
-                self.nccl.ncclReduce(
-                    buffer_type(chunk.data_ptr()),
-                    buffer_type(output_tensor.data_ptr()), chunk.numel(),
-                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                    ncclRedOpTypeEnum.from_torch(op), root, self.comm,
-                    cudaStream_t(stream.cuda_stream))
-                split_offset += split_size
-            self.nccl.ncclGroupEnd()
-        else:
-            self.nccl.ncclReduceScatter(
-                buffer_type(input_tensor.data_ptr()),
-                buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
                 ncclDataTypeEnum.from_torch(input_tensor.dtype),
-                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
                 cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
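`reduce_scatterv` is assembled the same way, from one `ncclReduce` per destination rank inside a group: rank `root` receives the element-wise reduction of rows `[offset, offset + sizes[root])` taken from every rank's input. A plain-Python sketch of that semantics for a SUM reduction (names and data are illustrative, not the NCCL call path):

# Rank `root` ends up with the element-wise sum of the slice
# [offset, offset + sizes[root]) taken from every rank's input.
def emulated_reduce_scatterv(per_rank_inputs, sizes):
    per_rank_outputs = []
    split_offset = 0
    for root, split_size in enumerate(sizes):
        chunk_sum = [
            sum(inp[split_offset + i] for inp in per_rank_inputs)
            for i in range(split_size)
        ]
        per_rank_outputs.append(chunk_sum)  # delivered only to `root`
        split_offset += split_size
    return per_rank_outputs

inputs = [[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2]]   # 2 ranks, 6 rows each
print(emulated_reduce_scatterv(inputs, [4, 2]))     # [[3, 3, 3, 3], [3, 3]]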
