|
13 | 13 | import vllm.envs as envs |
14 | 14 | from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig |
15 | 15 | from vllm.logger import init_logger |
| 16 | +from vllm.platforms import current_platform |
16 | 17 |
|
17 | 18 | if TYPE_CHECKING: |
18 | 19 | from vllm.attention.backends.abstract import AttentionMetadata |
@@ -75,14 +76,26 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, |
75 | 76 | Gather the num_tokens across all DP ranks and return results in a |
76 | 77 | CPU tensor of size dp_size. |
77 | 78 | """ |
| 79 | + from vllm.distributed.parallel_state import get_dp_group |
| 80 | + device = current_platform.device_type |
| 81 | + group = get_dp_group().device_group |
| 82 | + |
| 83 | + # Transferring this tensor from GPU to CPU will introduce a GPU sync |
| 84 | + # point that could adversely affect performance of vLLM with async |
| 85 | + # scheduling. This environment variable exists to quickly disable |
| 86 | + # this optimization if we run into this case. |
| 87 | + if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: |
| 88 | + logger.info_once( |
| 89 | + "Using CPU all-reduce to synchronize DP padding between ranks.") |
| 90 | + device = "cpu" |
| 91 | + group = get_dp_group().cpu_group |
78 | 92 | num_tokens_across_dp = [0] * dp_size |
79 | 93 | num_tokens_across_dp[dp_rank] = num_tokens |
80 | 94 | num_tokens_tensor = torch.tensor(num_tokens_across_dp, |
81 | | - device="cpu", |
| 95 | + device=device, |
82 | 96 | dtype=torch.int32) |
83 | | - from vllm.distributed.parallel_state import get_dp_group |
84 | | - dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) |
85 | | - return num_tokens_tensor |
| 97 | + dist.all_reduce(num_tokens_tensor, group=group) |
| 98 | + return num_tokens_tensor.cpu() |
86 | 99 |
|
87 | 100 | @staticmethod |
88 | 101 | def make( |
|
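
For context, here is a minimal, standalone sketch of the padding-synchronization logic this hunk implements: each DP rank writes its own token count into a zero-filled vector, and an `all_reduce` sums the vectors so every rank sees every other rank's count. The function name, the `group`/`device` parameters, and the plain `torch.distributed` setup are illustrative assumptions; the real code obtains the group from vLLM's `get_dp_group()` and switches to the CPU group when `VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION` is set.

```python
import torch
import torch.distributed as dist


def sync_num_tokens_across_dp(num_tokens: int,
                              dp_size: int,
                              dp_rank: int,
                              group: dist.ProcessGroup | None = None,
                              device: str = "cpu") -> torch.Tensor:
    """Illustrative sketch: gather per-rank token counts via all_reduce.

    Assumes torch.distributed is already initialized; `group` defaults to
    the global process group rather than vLLM's DP group.
    """
    # Zero everywhere except this rank's slot, so summing across ranks
    # reconstructs the full per-rank vector.
    counts = [0] * dp_size
    counts[dp_rank] = num_tokens
    tensor = torch.tensor(counts, device=device, dtype=torch.int32)
    # With device="cuda" and an NCCL group this stays on the GPU; with
    # device="cpu" and a gloo group it mirrors the env-var fallback above.
    dist.all_reduce(tensor, group=group)
    # Return a CPU tensor either way, matching the patched code; when the
    # tensor lives on the GPU, this .cpu() call is the sync point the
    # comment in the diff warns about.
    return tensor.cpu()
```

The trade-off the env var controls is visible here: the device-group path keeps the collective on NCCL but pays for a GPU-to-CPU copy at the end, while the CPU-group path avoids that sync point at the cost of a gloo all-reduce.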