Commit 7cdb800

Flashinfer cutlass moe backend for TP/DP + EP.
1 parent 657f2f3 commit 7cdb800

15 files changed: +904 −145 lines changed

benchmarks/benchmark_throughput.py

Lines changed: 3 additions & 0 deletions

@@ -28,6 +28,7 @@
     VisionArenaDataset,
 )
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args,
@@ -110,6 +111,8 @@ def run_vllm(
             ),
         )
         end = time.perf_counter()
+
+    cleanup_dist_env_and_memory()
     return end - start, outputs

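For orientation (not part of the diff): the benchmark change simply releases the distributed state and GPU memory once timing is done. A minimal standalone sketch of the same pattern, assuming a locally available model; the model name and prompt are illustrative only:

# Sketch: time a generation pass, then tear down the distributed environment
# with the same helper the benchmark now calls.
import time

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
start = time.perf_counter()
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
elapsed = time.perf_counter() - start

del llm
cleanup_dist_env_and_memory()  # release process groups and cached GPU memory
print(f"{len(outputs)} outputs in {elapsed:.2f}s")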
vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 16 additions & 2 deletions

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional
+from typing import List, Optional, Union
 from weakref import WeakValueDictionary
 
 import torch
@@ -138,9 +138,23 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                               input_size[dim + 1:])
         return output_tensor
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
+        assert False, "not implemented"
+
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
+        assert False, "not implemented"
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
-                       dim: int = -1) -> torch.Tensor:
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None) -> torch.Tensor:
+        assert sizes is None, "Varying size reduce scatter not supported with base device communicator"
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:

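For reference (not part of the commit): the new all_gatherv hook gathers along dim 0 and lets each rank contribute a different number of rows via sizes; the base class only fixes the signature and leaves the implementation to device-specific subclasses. A single-process sketch of the intended semantics, with made-up shapes:

# Single-process reference for the all_gatherv contract: concatenate the
# per-rank chunks along dim 0, where sizes[r] is the row count of rank r.
from typing import List

import torch


def all_gatherv_semantics(per_rank_chunks: List[torch.Tensor],
                          sizes: List[int]) -> torch.Tensor:
    assert len(per_rank_chunks) == len(sizes)
    for chunk, size in zip(per_rank_chunks, sizes):
        assert chunk.shape[0] == size
    return torch.cat(per_rank_chunks, dim=0)


# Three "ranks" contributing 2, 1 and 3 rows of hidden size 4.
chunks = [torch.randn(s, 4) for s in (2, 1, 3)]
assert all_gatherv_semantics(chunks, [2, 1, 3]).shape == (6, 4)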
vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 60 additions & 5 deletions

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import List, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -117,7 +117,10 @@ def all_reduce(self, input_):
         torch.distributed.all_reduce(out, group=self.device_group)
         return out
 
-    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None):
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
@@ -129,15 +132,20 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
         # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
         input_tensor = input_.movedim(0, dim).contiguous()
 
-        assert input_tensor.shape[0] % world_size == 0
-        chunk_size = input_tensor.shape[0] // world_size
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
         output_shape = (chunk_size, ) + input_tensor.shape[1:]
 
         output = torch.empty(output_shape,
                              dtype=input_tensor.dtype,
                              device=input_tensor.device)
 
-        pynccl_comm.reduce_scatter(output, input_)
+        pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
 
         # Reshape before returning
         return output.movedim(0, dim).contiguous()
@@ -180,6 +188,53 @@ def destroy(self):
             self.all2all_manager.destroy()
             self.all2all_manager = None
 
+    """
+    Allgather with support for list of tensors and varying sizes per rank.
+    Example:
+      Instead of:
+        ... = get_ep_group().dispatch(...)
+      Use this:
+        ... = get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], dim=0, sizes=get_forward_context().dp_metadata.num_tokens_across_dp_cpu)
+    """
+
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
+        assert dim == 0, "only dim 0 all-gather is supported"
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None and not pynccl_comm.disabled
+
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[List[int]] = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
+                # 'sizes' is not needed if all inputs in the same group have the same shape
+                if all(s == sizes[0] for s in sizes):
+                    sizes = None
+            else:
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        pynccl_comm.group_start()
+        output_list = []
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        pynccl_comm.group_end()
+        return output_list
+
     def dispatch(
             self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

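The docstring above already shows the intended call site: gathering dispatched MoE tensors over the DP group with per-rank token counts. As a complement, here is a single-process sketch (not vLLM code) of what the sizes-aware reduce_scatter computes: row blocks are summed elementwise across ranks and each rank keeps only its own block:

# Reference semantics for reduce_scatter(..., sizes=...): every rank holds
# sum(sizes) rows; after an elementwise SUM across ranks, rank r keeps the
# block of sizes[r] rows that belongs to it.
from typing import List

import torch


def reduce_scatterv_semantics(per_rank_inputs: List[torch.Tensor],
                              sizes: List[int], rank: int) -> torch.Tensor:
    reduced = torch.stack(per_rank_inputs).sum(dim=0)
    return torch.split(reduced, sizes, dim=0)[rank]


# World size 2, blocks of 3 and 1 rows, hidden size 4.
inputs = [torch.ones(4, 4), 2 * torch.ones(4, 4)]
out = reduce_scatterv_semantics(inputs, [3, 1], rank=0)
assert out.shape == (3, 4) and torch.all(out == 3)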
vllm/distributed/device_communicators/pynccl.py

Lines changed: 60 additions & 14 deletions

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import List, Optional, Union
 
+import numpy as np
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
@@ -135,7 +136,8 @@ def all_reduce(self,
     def all_gather(self,
                    output_tensor: torch.Tensor,
                    input_tensor: torch.Tensor,
-                   stream=None):
+                   stream=None,
+                   sizes: Optional[List[int]] = None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -146,17 +148,38 @@ def all_gather(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclAllGather(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
-            cudaStream_t(stream.cuda_stream))
+        if sizes is not None:
+            assert output_tensor.shape[0] == sum(sizes)
+            numel_base = int(np.prod(output_tensor.shape[1:]))
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                dst_slice = output_tensor[split_offset:split_offset +
+                                          split_size]
+                self.nccl.ncclBroadcast(
+                    buffer_type(input_tensor.data_ptr()),
+                    buffer_type(dst_slice.data_ptr()),
+                    split_size * numel_base,
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclAllGather(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
+                cudaStream_t(stream.cuda_stream))
 
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
                        op: ReduceOp = ReduceOp.SUM,
-                       stream=None):
+                       stream=None,
+                       sizes: Optional[List[int]] = None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
@@ -167,12 +190,29 @@ def reduce_scatter(self,
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op), self.comm,
-            cudaStream_t(stream.cuda_stream))
+
+        if sizes is not None:
+            numel_base = int(np.prod(input_tensor.shape[1:]))
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset:split_offset + split_size, :]
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    split_size * numel_base,
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                    cudaStream_t(stream.cuda_stream))
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                cudaStream_t(stream.cuda_stream))
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
@@ -216,3 +256,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()

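To see why the sizes path above loops over ncclBroadcast / ncclReduce inside a group: a variable-size all-gather can be decomposed into one broadcast per root rank, each landing in that rank's slice of the output. A CPU-only, single-process sketch of that decomposition (the real code issues the grouped NCCL calls instead of the assignment below):

# Decomposition used by the variable-size all_gather: for each root rank,
# "broadcast" that rank's rows into the matching slice of the output buffer.
from typing import List

import torch


def allgatherv_as_broadcasts(per_rank_inputs: List[torch.Tensor],
                             sizes: List[int]) -> torch.Tensor:
    out = torch.empty((sum(sizes), ) + tuple(per_rank_inputs[0].shape[1:]),
                      dtype=per_rank_inputs[0].dtype)
    offset = 0
    for root, split_size in enumerate(sizes):
        # stand-in for ncclBroadcast(root=root) into out[offset:offset+size]
        out[offset:offset + split_size] = per_rank_inputs[root]
        offset += split_size
    return out


chunks = [torch.full((s, 2), float(r)) for r, s in enumerate([2, 3, 1])]
result = allgatherv_as_broadcasts(chunks, [2, 3, 1])
assert result.shape == (6, 2) and torch.all(result[2:5] == 1.0)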
vllm/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 32 additions & 0 deletions

@@ -154,6 +154,16 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
         # ncclResult_t ncclAllGather(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +217,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     # class attribute to store the mapping from the path to the library
@@ -300,6 +314,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                                                      datatype, op, comm,
                                                      stream))
 
+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int, comm: ncclComm_t,
                           stream: cudaStream_t) -> None:
@@ -342,6 +368,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",

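The wrapper changes follow the existing binding pattern: declare the C prototype once in the function table, then expose a typed Python method that goes through NCCL_CHECK. A tiny, GPU-free illustration of that ctypes pattern (not vLLM code; it binds libm's fabs instead of an NCCL symbol and assumes a POSIX build where the symbol is visible in the current process):

# Illustration of the prototype-table binding style used by NCCLLibrary,
# applied to a harmless libm function so it runs without NCCL or GPUs.
import ctypes
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Function:  # analogous to the wrapper's Function records
    name: str
    restype: Any
    argtypes: List[Any]


exported_functions = [Function("fabs", ctypes.c_double, [ctypes.c_double])]

lib = ctypes.CDLL(None)  # symbols already loaded into this process (POSIX)
funcs = {}
for func in exported_functions:
    f = getattr(lib, func.name)
    f.restype = func.restype
    f.argtypes = func.argtypes
    funcs[func.name] = f

assert funcs["fabs"](-3.0) == 3.0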
vllm/distributed/parallel_state.py

Lines changed: 18 additions & 7 deletions

@@ -30,7 +30,7 @@
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 from unittest.mock import patch
 
 import torch
@@ -381,9 +381,16 @@ def _all_gather_out_place(self, input_: torch.Tensor,
                               dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, List[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[List[int]] = None):
+        return self.device_communicator.all_gatherv(input_, dim, sizes)
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
-                       dim: int = -1) -> torch.Tensor:
+                       dim: int = -1,
+                       sizes: Optional[List[int]] = None) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
@@ -392,16 +399,20 @@ def reduce_scatter(self,
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
 
         if self.use_custom_op_call:
+            assert sizes is None, "Varying size reduce scatter not supported with vllm custom op"
             return torch.ops.vllm.reduce_scatter(input_,
                                                  dim,
                                                  world_size,
                                                  group_name=self.unique_name)
         else:
-            return self._reduce_scatter_out_place(input_, dim)
-
-    def _reduce_scatter_out_place(self, input_: torch.Tensor,
-                                  dim: int) -> torch.Tensor:
-        return self.device_communicator.reduce_scatter(input_, dim)
+            return self._reduce_scatter_out_place(input_, dim, sizes)
+
+    def _reduce_scatter_out_place(
+            self,
+            input_: torch.Tensor,
+            dim: int,
+            sizes: Optional[List[int]] = None) -> torch.Tensor:
+        return self.device_communicator.reduce_scatter(input_, dim, sizes)
 
     def gather(self,
                input_: torch.Tensor,

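GroupCoordinator now just forwards all_gatherv and the optional sizes of reduce_scatter to the device communicator, with the vLLM custom-op path still restricted to equal sizes. A CPU-only, runnable sketch (gloo backend, two processes, plain torch.distributed rather than vLLM API) of the group-level behaviour a caller gets from all_gatherv:

# Two CPU processes gather ragged per-rank tensors: rank 0 contributes 3 rows,
# rank 1 contributes 1 row, and both end up with the 4-row concatenation.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, sizes):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.full((sizes[rank], 2), float(rank))
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local)  # object-based, so sizes may differ
    out = torch.cat(gathered, dim=0)
    assert out.shape == (sum(sizes), 2)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2, [3, 1]), nprocs=2)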
vllm/envs.py

Lines changed: 5 additions & 0 deletions

@@ -121,6 +121,7 @@
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_USE_FLASHINFER_MOE: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -867,6 +868,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
+    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.

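The new flag follows the usual vLLM env-var convention: parsed with bool(int(...)), off by default, and surfaced as vllm.envs.VLLM_USE_FLASHINFER_MOE. A minimal sketch of how the value is interpreted, mirroring the lambda above:

# VLLM_USE_FLASHINFER_MOE is read as an integer string: "1" enables the
# FlashInfer CUTLASS fused-MoE path, "0" or unset leaves it disabled.
import os

os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"
enabled = bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0")))
assert enabled is True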