
Commit f1a809e

youkaichao authored and Akshat-Tripathi committed
[core] set up data parallel communication (vllm-project#13591)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent 439a0ea commit f1a809e

File tree

17 files changed: +416 −28 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -134,7 +134,9 @@ steps:
   - tests/compile/test_basic_correctness
   - examples/offline_inference/rlhf.py
   - examples/offline_inference/rlhf_colocate.py
+  - tests/examples/offline_inference/data_parallel.py
   commands:
+  - VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
examples/offline_inference/data_parallel.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
+# we need to have a launcher to create multiple data parallel
+# ranks. And each rank will create a vLLM instance to process its own prompts.
+import os
+
+from vllm import LLM, SamplingParams
+from vllm.utils import get_open_port
+
+
+def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
+    os.environ["VLLM_DP_RANK"] = str(dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
+    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
+    # set devices for each dp_rank
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+        str(i) for i in range(dp_rank * GPUs_per_dp_rank,
+                              (dp_rank + 1) * GPUs_per_dp_rank))
+
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # with DP, each rank should process different prompts.
+    # usually all the DP ranks process a full dataset,
+    # and each rank processes a different part of the dataset.
+    prompts_per_rank = len(prompts) // dp_size
+    start = dp_rank * prompts_per_rank
+    end = start + prompts_per_rank
+    prompts = prompts[start:end]
+    if len(prompts) == 0:
+        # if any rank has no prompts to process,
+        # we need to set a placeholder prompt
+        prompts = ["Placeholder"]
+    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
+
+    # Create a sampling params object.
+    # since we are doing data parallel, every rank can have different
+    # sampling params. here we set different max_tokens for different
+    # ranks for demonstration.
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=16 * (dp_rank + 1))
+
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True)
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(
+            f"DP rank {dp_rank}, Prompt: {prompt!r}, "
+            f"Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+    from multiprocessing import Process
+    dp_size = 2
+    GPUs_per_dp_rank = 2
+    dp_master_ip = "127.0.0.1"
+    dp_master_port = get_open_port()
+    procs = []
+    for i in range(dp_size):
+        proc = Process(target=main,
+                       args=(dp_size, i, dp_master_ip, dp_master_port,
+                             GPUs_per_dp_rank))
+        proc.start()
+        procs.append(proc)
+    for proc in procs:
+        proc.join()
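Note: the CUDA_VISIBLE_DEVICES carving in main() above gives each DP rank its own contiguous slice of GPUs, and the launcher in the __main__ block simply spawns one process per rank. A minimal standalone sketch of that device assignment with the defaults used here (dp_size=2, GPUs_per_dp_rank=2):

dp_size, gpus_per_dp_rank = 2, 2  # same defaults as the __main__ block above
for dp_rank in range(dp_size):
    devices = ",".join(
        str(i) for i in range(dp_rank * gpus_per_dp_rank,
                              (dp_rank + 1) * gpus_per_dp_rank))
    print(f"DP rank {dp_rank} -> CUDA_VISIBLE_DEVICES={devices}")
# DP rank 0 -> CUDA_VISIBLE_DEVICES=0,1
# DP rank 1 -> CUDA_VISIBLE_DEVICES=2,3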

vllm/config.py

Lines changed: 57 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
+from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
 
 import vllm.envs as envs
@@ -1296,6 +1297,11 @@ class ParallelConfig:
 
     pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
+    data_parallel_size: int = 1  # Number of data parallel groups.
+    data_parallel_rank: int = 0  # Rank of the data parallel group.
+    # IP of the data parallel master.
+    data_parallel_master_ip: str = "127.0.0.1"
+    data_parallel_master_port: int = 29500  # Port of the data parallel master.
 
     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
@@ -1329,10 +1335,55 @@ class ParallelConfig:
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"
 
+    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
+    # world_size_across_dp is TPxPPxDP, it is the size of the world
+    # including data parallelism.
+    world_size_across_dp: int = field(init=False)
 
     rank: int = 0
 
+    def get_next_dp_init_port(self) -> int:
+        """
+        We might need to initialize process groups in multiple
+        processes that are related to data parallelism,
+        e.g. both in the worker and in the engine, which
+        can live in different processes. To avoid port conflicts, we
+        increment the port number each time we need to initialize a
+        new process group related to data parallelism.
+        """
+        answer = self.data_parallel_master_port
+        self.data_parallel_master_port += 1
+        return answer
+
+    def stateless_init_dp_group(self) -> "ProcessGroup":
+        from vllm.distributed.utils import (
+            stateless_init_torch_distributed_process_group)
+
+        # use gloo since the engine process might not have cuda device
+        dp_group = stateless_init_torch_distributed_process_group(
+            self.data_parallel_master_ip,
+            self.get_next_dp_init_port(),
+            self.data_parallel_rank,
+            self.data_parallel_size,
+            backend="gloo")
+
+        return dp_group
+
+    @staticmethod
+    def has_unfinished_dp(dp_group: "ProcessGroup",
+                          has_unfinished: bool) -> bool:
+        tensor = torch.tensor([has_unfinished],
+                              dtype=torch.int32,
+                              device="cpu")
+        # dp rank 0: has_unfinished_seqs=True
+        # dp rank 1: has_unfinished_seqs=False
+        # aggregated: has_unfinished_seqs=True
+        # so this is an OR operation, i.e. MAX in integers
+        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
+        aggregated_has_unfinished = bool(tensor.item())
+        return aggregated_has_unfinished
+
     def compute_hash(self):
         """
         Provide a hash that uniquely identifies all the configs
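The two helpers above are intended for engine-level coordination: each DP rank keeps stepping until no rank has unfinished requests, so ranks that drain early still participate in the collectives. A hedged sketch of that pattern (the loop and the local work counter are illustrative stand-ins, not vLLM's actual engine loop):

from vllm.config import ParallelConfig


def run_dp_engine_loop(parallel_config: ParallelConfig, local_steps: int):
    # One gloo group across DP ranks; works in processes without a CUDA device.
    dp_group = parallel_config.stateless_init_dp_group()
    step = 0
    while True:
        local_unfinished = step < local_steps  # stand-in for "scheduler has requests"
        # OR across ranks (implemented as MAX over int32): stop only when every rank is done.
        if not ParallelConfig.has_unfinished_dp(dp_group, local_unfinished):
            break
        step += 1  # a real engine would run one scheduling/execution step here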
@@ -1350,6 +1401,12 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size
 
+        self.data_parallel_size = envs.VLLM_DP_SIZE
+        self.data_parallel_rank = envs.VLLM_DP_RANK
+        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+        self.world_size_across_dp = self.world_size * self.data_parallel_size
+
         ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
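__post_init__ pulls the DP settings from vllm.envs, which is how the per-rank environment variables set by the launcher (see examples/offline_inference/data_parallel.py above) reach the config. A rough sanity check, assuming ParallelConfig can be constructed standalone in your environment:

import os

os.environ["VLLM_DP_RANK"] = "1"
os.environ["VLLM_DP_SIZE"] = "2"
os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
os.environ["VLLM_DP_MASTER_PORT"] = "29500"

from vllm.config import ParallelConfig

pc = ParallelConfig()  # TP=1, PP=1 defaults, so world_size_across_dp == 2
print(pc.data_parallel_rank, pc.data_parallel_size, pc.world_size_across_dp)
# expected: 1 2 2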

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ def __init__(self,
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
-        if "pp" in unique_name:
-            # pipeline parallel does not need custom allreduce
+        if "tp" not in unique_name:
+            # only tp uses custom allreduce
             use_custom_allreduce = False
         else:
             from vllm.distributed.parallel_state import (

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 7 additions & 4 deletions
@@ -87,6 +87,7 @@ def __init__(self,
             return
 
         rank = dist.get_rank(group=self.group)
+        self.rank = rank
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
             # No need to initialize custom allreduce for single GPU case.
@@ -201,8 +202,10 @@ def create_shared_buffer(
 
     @staticmethod
     def free_shared_buffer(pointers: List[int],
-                           group: Optional[ProcessGroup] = None) -> None:
-        rank = dist.get_rank(group=group)
+                           group: Optional[ProcessGroup] = None,
+                           rank: Optional[int] = None) -> None:
+        if rank is None:
+            rank = dist.get_rank(group=group)
         lib = CudaRTLibrary()
         lib.cudaFree(ctypes.c_void_p(pointers[rank]))
 
@@ -298,8 +301,8 @@ def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
             self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
-            self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
 
     def __del__(self):
         self.close()

vllm/distributed/parallel_state.py

Lines changed: 62 additions & 14 deletions
@@ -750,6 +750,13 @@ def get_tp_group() -> GroupCoordinator:
 
 _PP: Optional[GroupCoordinator] = None
 
+_DP: Optional[GroupCoordinator] = None
+
+
+def get_dp_group() -> GroupCoordinator:
+    assert _DP is not None, ("data parallel group is not initialized")
+    return _DP
+
 
 def get_pp_group() -> GroupCoordinator:
     assert _PP is not None, (
@@ -811,6 +818,21 @@ def init_distributed_environment(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
         distributed_init_method, backend)
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None and config.parallel_config.data_parallel_size > 1:
+        parallel_config = config.parallel_config
+        # adjust to take into account data parallelism
+        # offset the rank by the data parallel rank
+        rank = parallel_config.data_parallel_rank * world_size + rank
+        # adjust the world size to take into account data parallelism
+        world_size = parallel_config.world_size_across_dp
+        ip = parallel_config.data_parallel_master_ip
+        port = parallel_config.get_next_dp_init_port()
+        distributed_init_method = f"tcp://{ip}:{port}"  # noqa
+        logger.info(
+            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+            world_size, rank, distributed_init_method)
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
@@ -870,20 +892,28 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
 
+    data_parallel_size = 1
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None:
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: DP x PP x TP
+    # to get group_ranks for each dimension, transpose that dimension to the
+    # last dimension, then reshape to 2D, then unbind the last dimension
+    all_ranks = torch.arange(world_size).reshape(
+        data_parallel_size, pipeline_model_parallel_size,
+        tensor_model_parallel_size)  # noqa
+
     # Build the tensor model-parallel groups.
-    num_tensor_model_parallel_groups: int = (world_size //
-                                             tensor_model_parallel_size)
     global _TP
     assert _TP is None, ("tensor model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        ranks = list(
-            range(i * tensor_model_parallel_size,
-                  (i + 1) * tensor_model_parallel_size))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
 
     # message queue broadcaster is only used in tensor model parallel group
     _TP = init_model_parallel_group(group_ranks,
@@ -893,20 +923,33 @@ def initialize_model_parallel(
                                     group_name="tp")
 
     # Build the pipeline model-parallel groups.
-    num_pipeline_model_parallel_groups: int = (world_size //
-                                               pipeline_model_parallel_size)
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_pipeline_model_parallel_groups):
-        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, pipeline_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="pp")
 
+    global _DP
+    assert _DP is None, ("data parallel group is already initialized")
+    group_ranks = all_ranks.transpose(0, 2).reshape(
+        -1, data_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _DP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="dp")
+
+    logger.info(
+        "rank %s in world size %s is assigned as "
+        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
+        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
+
 
 def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None
 
+    global _DP
+    if _DP:
+        _DP.destroy()
+    _DP = None
+
 
 def destroy_distributed_environment():
     global _WORLD
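Once initialize_model_parallel() has run inside a worker, the new group is looked up the same way as the TP/PP groups. A minimal sketch (rank_in_group is the attribute used in the logging call above; world_size is assumed to be the usual GroupCoordinator attribute):

from vllm.distributed.parallel_state import get_dp_group

dp = get_dp_group()  # asserts if the data parallel group was never initialized
print(f"this worker is DP rank {dp.rank_in_group} of {dp.world_size}")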
