Commit 49bfc53

Update num_tokens_across_dp to use nccl instead of gloo (#24105)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
1 parent a0b2670 commit 49bfc53

2 files changed (+25, -4 lines)

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -92,6 +92,7 @@
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
+    VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -745,6 +746,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
         "VLLM_DISABLED_KERNELS"].split(","),

+    # Swaps the all reduce backend that we use to coordinate the DP padding
+    # information from NCCL to gloo.
+    "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION":
+    lambda:
+    (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
+     ("true", "1")),
+
     # If set, use the V1 code path.
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
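For reference, the new flag is parsed the same way as the other boolean vLLM environment variables: only the strings "true" and "1" (case-insensitive) enable it, and any other value leaves it off. A minimal sketch of that behavior is below; _env_flag is an illustrative stand-in for the lambda added above, not a helper that exists in the repository.

import os

def _env_flag(name: str, default: str = "False") -> bool:
    # Same parsing as the new envs.py entry: only "true" or "1"
    # (case-insensitive) turns the flag on; everything else is False.
    return os.getenv(name, default).lower() in ("true", "1")

os.environ["VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION"] = "1"
assert _env_flag("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION")

os.environ["VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION"] = "yes"  # not a recognized value
assert not _env_flag("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION")

In a deployment, the variable would typically be exported before launching the server so that the lambda above picks it up.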

vllm/forward_context.py

Lines changed: 17 additions & 4 deletions
@@ -13,6 +13,7 @@
 import vllm.envs as envs
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -75,14 +76,26 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         Gather the num_tokens across all DP ranks and return results in a
         CPU tensor of size dp_size.
         """
+        from vllm.distributed.parallel_state import get_dp_group
+        device = current_platform.device_type
+        group = get_dp_group().device_group
+
+        # Transferring this tensor from GPU to CPU will introduce a GPU sync
+        # point that could adversely affect performance of vLLM with async
+        # scheduling. This environment variable exists to quickly disable
+        # this optimization if we run into this case.
+        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
+            logger.info_once(
+                "Using CPU all reduce to synchronize DP padding between ranks.")
+            device = "cpu"
+            group = get_dp_group().cpu_group
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = num_tokens
         num_tokens_tensor = torch.tensor(num_tokens_across_dp,
-                                         device="cpu",
+                                         device=device,
                                          dtype=torch.int32)
-        from vllm.distributed.parallel_state import get_dp_group
-        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        return num_tokens_tensor
+        dist.all_reduce(num_tokens_tensor, group=group)
+        return num_tokens_tensor.cpu()

     @staticmethod
     def make(
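The new path is a gather-by-all-reduce on the DP device group: each rank writes its own num_tokens into a zero-initialized vector of length dp_size, and a sum all-reduce leaves every rank holding the full per-rank list, which is then copied back to the CPU. Setting VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION restores the previous behavior of building the tensor on the CPU and reducing over the Gloo group. The standalone sketch below reproduces the same pattern outside vLLM with a two-process Gloo group on CPU; the worker function, port, and world size are illustrative assumptions, not part of this commit.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    # Illustrative rendezvous settings; any free port works.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29544")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    num_tokens = 100 + rank  # stand-in for this rank's batch size
    counts = [0] * world_size
    counts[rank] = num_tokens
    tensor = torch.tensor(counts, dtype=torch.int32)

    # Summing the per-rank one-hot vectors yields [100, 101] on every rank.
    dist.all_reduce(tensor)
    print(f"rank {rank}: {tensor.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

Swapping "gloo" for "nccl" and placing the tensor on the rank's GPU gives the device-group variant that this commit makes the default.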
