[core] set up data parallel communication #13591
Changes from all commits
New file: examples/offline_inference/data_parallel.py

```python
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import os

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
    os.environ["VLLM_DP_RANK"] = str(dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
    # set devices for each dp_rank
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(dp_rank * GPUs_per_dp_rank,
                              (dp_rank + 1) * GPUs_per_dp_rank))

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # with DP, each rank should process different prompts.
    # usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    end = start + prompts_per_rank
    prompts = prompts[start:end]
    if len(prompts) == 0:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=16 * (dp_rank + 1))

    # Create an LLM.
    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=True)
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")


if __name__ == "__main__":
    from multiprocessing import Process
    dp_size = 2
    GPUs_per_dp_rank = 2
    dp_master_ip = "127.0.0.1"
    dp_master_port = get_open_port()
    procs = []
    for i in range(dp_size):
        proc = Process(target=main,
                       args=(dp_size, i, dp_master_ip, dp_master_port,
                             GPUs_per_dp_rank))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
```

Review comment (on the placeholder prompt): I know this is just an example, but in practice I guess you'd want to set max_tokens to 1 for any placeholder prompts.
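The example splits its four prompts contiguously across DP ranks; for a larger dataset, a strided split is a common alternative. A minimal sketch, not part of the PR (the `dataset` list and the slicing scheme are illustrative):

```python
# Hypothetical variant of the sharding step above: each DP rank takes every
# dp_size-th prompt, which stays roughly balanced even when the dataset size
# is not divisible by dp_size.
dataset = [f"prompt {i}" for i in range(1000)]  # stand-in for a real dataset
rank_prompts = dataset[dp_rank::dp_size]        # dp_rank/dp_size as in main()
if not rank_prompts:
    rank_prompts = ["Placeholder"]              # same fallback as the example
```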
Changes to vllm/config.py (ParallelConfig)
```diff
@@ -16,6 +16,7 @@

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
+from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig

 import vllm.envs as envs
```
```diff
@@ -1290,6 +1291,11 @@ class ParallelConfig:

     pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
+    data_parallel_size: int = 1  # Number of data parallel groups.
+    data_parallel_rank: int = 0  # Rank of the data parallel group.
+    # IP of the data parallel master.
+    data_parallel_master_ip: str = "127.0.0.1"
+    data_parallel_master_port: int = 29500  # Port of the data parallel master.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
```
```diff
@@ -1323,10 +1329,55 @@ class ParallelConfig:
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"

     # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
+    # world_size_across_dp is TPxPPxDP, it is the size of the world
+    # including data parallelism.
+    world_size_across_dp: int = field(init=False)

     rank: int = 0

+    def get_next_dp_init_port(self) -> int:
+        """
+        We might need to initialize process groups in multiple
+        processes that are related to data parallelism,
+        e.g. both in the worker and in the engine, which
+        can live in different processes. To avoid port conflicts, we
+        increment the port number each time we need to initialize a
+        new process group related to data parallelism.
+        """
+        answer = self.data_parallel_master_port
+        self.data_parallel_master_port += 1
+        return answer
+
+    def stateless_init_dp_group(self) -> "ProcessGroup":
+        from vllm.distributed.utils import (
+            stateless_init_torch_distributed_process_group)
+
+        # use gloo since the engine process might not have cuda device
+        dp_group = stateless_init_torch_distributed_process_group(
+            self.data_parallel_master_ip,
+            self.get_next_dp_init_port(),
+            self.data_parallel_rank,
+            self.data_parallel_size,
+            backend="gloo")
+
+        return dp_group
+
+    @staticmethod
+    def has_unfinished_dp(dp_group: "ProcessGroup",
+                          has_unfinished: bool) -> bool:
+        tensor = torch.tensor([has_unfinished],
+                              dtype=torch.int32,
+                              device="cpu")
+        # dp rank 0: has_unfinished_seqs=True
+        # dp rank 1: has_unfinished_seqs=False
+        # aggregated: has_unfinished_seqs=True
+        # so this is an OR operation, i.e. MAX in integers
+        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
+        aggregated_has_unfinished = bool(tensor.item())
+        return aggregated_has_unfinished
+
     def compute_hash(self):
         """
         Provide a hash that uniquely identifies all the configs
```

Review discussion on the port increment (diff lines +1349 to +1350):

- What if the port is already being used by other services?
- Then it will error. We can document that more than one port will be used, starting from the specified port, and that assumption should usually be fine. Note that even if we only used the specified port, there is still a chance that some other service grabs it before we do; that is unavoidable when running multiple services on the same host. For cloud deployments, where each service runs in its own container, it should be fine.
- Alternatively, could we check whether the port is in use via a socket and keep searching for the next available port?
- That is not feasible, because non-zero ranks connect directly to the specified port and cannot tell whether they reached the master rank or some other service; they also need to wait for a while in case the master rank has not started yet.
- I added the code in 267cd82; at least vLLM's internal port usage will not conflict with the DP master ports.
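A minimal sketch of how an engine loop might use the two new helpers, `stateless_init_dp_group()` and `has_unfinished_dp()` (not from this PR; `parallel_config`, `engine_has_unfinished_requests()` and `step_engine_once()` are hypothetical placeholders):

```python
from vllm.config import ParallelConfig

# Hypothetical engine-side loop: each DP rank builds the gloo group once, then
# keeps stepping until *every* DP rank has drained its requests, so that all
# ranks leave the loop together.
dp_group = parallel_config.stateless_init_dp_group()  # parallel_config assumed in scope

while True:
    local_unfinished = engine_has_unfinished_requests()  # placeholder
    # OR across DP ranks, implemented as an int32 MAX all-reduce
    if not ParallelConfig.has_unfinished_dp(dp_group, local_unfinished):
        break
    step_engine_once()  # placeholder
```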
```diff
@@ -1344,6 +1395,12 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size

+        self.data_parallel_size = envs.VLLM_DP_SIZE
+        self.data_parallel_rank = envs.VLLM_DP_RANK
+        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
+        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
+
+        self.world_size_across_dp = self.world_size * self.data_parallel_size

         ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
```

Review discussion on reading the env vars in __post_init__:

- Note that I'm hitting issues like: […] This is true even if I change the master port with […]
- That's strange. I also met it once, but then it disappeared.
- It seems this disappeared when I removed torchrun in af53b4b.
Changes to vllm/distributed/parallel_state.py
```diff
@@ -750,6 +750,13 @@ def get_tp_group() -> GroupCoordinator:

 _PP: Optional[GroupCoordinator] = None

+_DP: Optional[GroupCoordinator] = None
+
+
+def get_dp_group() -> GroupCoordinator:
+    assert _DP is not None, ("data parallel group is not initialized")
+    return _DP
+

 def get_pp_group() -> GroupCoordinator:
     assert _PP is not None, (
```
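A small sketch of querying the new group from worker code once the groups have been initialized (not part of the diff; the attribute names come from `GroupCoordinator`):

```python
from vllm.distributed.parallel_state import get_dp_group

# Illustrative check after model-parallel initialization: get_dp_group()
# asserts that the DP group exists and exposes this rank's coordinates.
dp_group = get_dp_group()
print(f"DP rank {dp_group.rank_in_group} of {dp_group.world_size}")
```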
```diff
@@ -811,6 +818,21 @@ def init_distributed_environment(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
         distributed_init_method, backend)
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None and config.parallel_config.data_parallel_size > 1:
+        parallel_config = config.parallel_config
+        # adjust to take into account data parallelism
+        # offset the rank by the data parallel rank
+        rank = parallel_config.data_parallel_rank * world_size + rank
+        # adjust the world size to take into account data parallelism
+        world_size = parallel_config.world_size_across_dp
+        ip = parallel_config.data_parallel_master_ip
+        port = parallel_config.get_next_dp_init_port()
+        distributed_init_method = f"tcp://{ip}:{port}"  # noqa
+        logger.info(
+            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+            world_size, rank, distributed_init_method)
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
```
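A worked illustration of the adjustment above, using assumed sizes DP=2, TP=2, PP=1:

```python
# Assumed sizes for illustration: 2 DP ranks, each running a TP=2, PP=1 engine.
dp_size, tp_size, pp_size = 2, 2, 1
world_size = tp_size * pp_size               # per-engine world size: 2
world_size_across_dp = world_size * dp_size  # global world size: 4

for dp_rank in range(dp_size):
    for local_rank in range(world_size):
        global_rank = dp_rank * world_size + local_rank
        print(f"DP rank {dp_rank}, local rank {local_rank} "
              f"-> global rank {global_rank} of {world_size_across_dp}")
# DP rank 0 owns global ranks 0-1 and DP rank 1 owns global ranks 2-3; all
# four processes join the same TCP store at the DP master IP and port.
```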
```diff
@@ -870,20 +892,28 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    rank = torch.distributed.get_rank()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)

+    data_parallel_size = 1
+    from vllm.config import get_current_vllm_config
+    config = get_current_vllm_config()
+    if config is not None:
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: DP x PP x TP
+    # to get group_ranks for each dimension, transpose that dimension to the
+    # last dimension, then reshape to 2D, then unbind the last dimension
+    all_ranks = torch.arange(world_size).reshape(
+        data_parallel_size, pipeline_model_parallel_size,
+        tensor_model_parallel_size)  # noqa
+
     # Build the tensor model-parallel groups.
     num_tensor_model_parallel_groups: int = (world_size //
                                              tensor_model_parallel_size)
     global _TP
     assert _TP is None, ("tensor model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        ranks = list(
-            range(i * tensor_model_parallel_size,
-                  (i + 1) * tensor_model_parallel_size))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]

     # message queue broadcaster is only used in tensor model parallel group
     _TP = init_model_parallel_group(group_ranks,
```
```diff
@@ -893,20 +923,33 @@ def initialize_model_parallel(
                                     group_name="tp")

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
                                                pipeline_model_parallel_size)
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_pipeline_model_parallel_groups):
-        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, pipeline_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="pp")

+    global _DP
+    assert _DP is None, ("data parallel group is already initialized")
+    group_ranks = all_ranks.transpose(0,
+                                      2).reshape(-1,
+                                                 data_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _DP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="dp")
+
+    logger.info(
+        "rank %s in world size %s is assigned as "
+        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
+        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)


 def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
     """
```

Review comment (on the new logging): example of the rank assignment for DP=2 x TP=2: […]
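Reconstructing that rank assignment for DP=2, TP=2, PP=1 from the reshape/transpose logic above (this is a derivation for illustration, not the reviewer's original snippet):

```python
import torch

# DP x PP x TP layout, mirroring initialize_model_parallel() above.
dp, pp, tp = 2, 1, 2
all_ranks = torch.arange(dp * pp * tp).reshape(dp, pp, tp)

tp_groups = [x.tolist() for x in all_ranks.view(-1, tp).unbind(0)]
pp_groups = [x.tolist() for x in
             all_ranks.transpose(1, 2).reshape(-1, pp).unbind(0)]
dp_groups = [x.tolist() for x in
             all_ranks.transpose(0, 2).reshape(-1, dp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3]]     -> each engine's two GPUs form a TP group
print(pp_groups)  # [[0], [1], [2], [3]] -> PP groups are trivial with PP=1
print(dp_groups)  # [[0, 2], [1, 3]]     -> same TP rank across the two engines
```

So global rank 2, for example, would be logged as DP rank 1, PP rank 0, TP rank 0.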
```diff
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None

+    global _DP
+    if _DP:
+        _DP.destroy()
+    _DP = None
+

 def destroy_distributed_environment():
     global _WORLD
```