Skip to content

Commit 510085d

Browse files
committed
Enable symm mem in DP setup
Signed-off-by: ilmarkov <markovilya197@gmail.com>
1 parent 2675bca commit 510085d

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,18 @@ def __init__(self,
2424
unique_name: str = ""):
2525
super().__init__(cpu_group, device, device_group, unique_name)
2626
if "tp" not in unique_name:
27-
# only tp uses custom allreduce
27+
# custom allreduce or torch symm mem can be used only by tp
2828
use_custom_allreduce = False
29+
use_torch_symm_mem = False
2930
else:
3031
from vllm.distributed.parallel_state import (
3132
_ENABLE_CUSTOM_ALL_REDUCE)
3233
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
34+
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
3335

3436
# ep does not use pynccl
3537
use_pynccl = "ep" not in unique_name
3638

37-
use_torch_symm_mem = ("ep" not in unique_name
38-
and envs.VLLM_ALLREDUCE_USE_SYMM_MEM)
39-
4039
self.use_pynccl = use_pynccl
4140
self.use_custom_allreduce = use_custom_allreduce
4241
self.use_torch_symm_mem = use_torch_symm_mem

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import os
43
from typing import Optional, Union
54

65
import torch
@@ -81,11 +80,6 @@ def __init__(
8180
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
8281
self.device_capability][self.world_size]
8382

84-
# allow overlapping devices in case of data parallel
85-
# if torch symm mem cannot be initialized, the multicast ptr will be 0
86-
if self.world_size < dist.get_world_size():
87-
os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"
88-
8983
self.buffer = torch_symm_mem.empty(
9084
self.max_size // self.dtype.itemsize,
9185
device=self.device,

0 commit comments

Comments
 (0)