Merged

70 commits
9f6c969
add env vars
youkaichao Feb 20, 2025
756425f
add tests
youkaichao Feb 20, 2025
eac1046
add examples
youkaichao Feb 20, 2025
6f2de3c
adjust init
youkaichao Feb 20, 2025
3d5f971
fix ip
youkaichao Feb 20, 2025
ccb2f75
move init device into init worker so that vllm config is set
youkaichao Feb 20, 2025
9e10661
add logs
youkaichao Feb 20, 2025
fea1ab2
init groups
youkaichao Feb 20, 2025
ab63901
support multiple groups with dp
youkaichao Feb 20, 2025
4652353
add field
youkaichao Feb 20, 2025
c26df65
add utils
youkaichao Feb 20, 2025
ebbcb18
sync on has_unfinished_requests
youkaichao Feb 20, 2025
0c26f38
cancel env vars
youkaichao Feb 20, 2025
afbaca4
cancel env vars
youkaichao Feb 20, 2025
0f93cb3
simplify code
youkaichao Feb 20, 2025
f8ffe7e
change step
youkaichao Feb 20, 2025
8330dcd
v1 support
youkaichao Feb 20, 2025
3ab3465
improve examples
youkaichao Feb 20, 2025
80ae5b6
simplify examples
youkaichao Feb 20, 2025
1bf33ea
unify code
youkaichao Feb 21, 2025
fba6287
unify code
youkaichao Feb 21, 2025
81468ad
fix v0?
youkaichao Feb 21, 2025
b284b36
sync num_tokens_across_dp
youkaichao Feb 21, 2025
be8c281
fix
youkaichao Feb 21, 2025
af53b4b
remove torchrun
youkaichao Feb 21, 2025
32c78e5
fix
youkaichao Feb 21, 2025
18ed136
fix
youkaichao Feb 21, 2025
f79743b
fix?
youkaichao Feb 21, 2025
2114e7d
relax?
youkaichao Feb 21, 2025
e1034cd
relax?
youkaichao Feb 21, 2025
bb7d639
change to python
youkaichao Feb 21, 2025
c41839c
support v1?
youkaichao Feb 21, 2025
9c67a65
move init_device
youkaichao Feb 21, 2025
b137c89
revert init device
youkaichao Feb 21, 2025
103aa2e
init device
youkaichao Feb 21, 2025
10387cf
move engine core
youkaichao Feb 21, 2025
6c73748
fix loop
youkaichao Feb 21, 2025
4c332b7
sync across one batch
youkaichao Feb 21, 2025
90b770e
fix?
youkaichao Feb 21, 2025
2e7caea
revert v0
youkaichao Feb 21, 2025
3809f1b
use v1 test
youkaichao Feb 21, 2025
d6e4eba
revert utils
youkaichao Feb 21, 2025
866395d
fix?
youkaichao Feb 21, 2025
b0b5c05
Merge branch 'main' into manual_dp
youkaichao Feb 21, 2025
01d4242
revert core
youkaichao Feb 21, 2025
37fbae6
use utility
youkaichao Feb 21, 2025
a0b5ab6
fix init
youkaichao Feb 21, 2025
25fdb12
fix?
youkaichao Feb 21, 2025
a3d44d1
fix ca error
youkaichao Feb 21, 2025
5e3b75a
fix communicator
youkaichao Feb 21, 2025
9440691
add v0?
youkaichao Feb 21, 2025
4ac1f34
revert v0
youkaichao Feb 21, 2025
741da58
use eager first
youkaichao Feb 21, 2025
eef9c06
use python3
youkaichao Feb 21, 2025
d7c6212
remove torchrun
youkaichao Feb 21, 2025
7b14349
specify use v1
youkaichao Feb 21, 2025
df0239f
rename need_to_sync_across_dp
youkaichao Feb 21, 2025
216bbb9
encapsulate functions
youkaichao Feb 21, 2025
79404cb
simplify functions
youkaichao Feb 21, 2025
6c04388
simplify functions
youkaichao Feb 21, 2025
7289f6c
Merge branch 'main' into manual_dp
youkaichao Feb 22, 2025
6e5a472
update tests to avoid 0 prompts
youkaichao Feb 22, 2025
c3704a0
update tests to avoid 0 prompts
youkaichao Feb 22, 2025
9586d0c
f string
youkaichao Feb 22, 2025
445b474
list multiply
youkaichao Feb 22, 2025
3d92428
list multiply
youkaichao Feb 22, 2025
0c6b1db
list multiply
youkaichao Feb 22, 2025
779cb33
short line
youkaichao Feb 22, 2025
29e6e60
clean up DP
youkaichao Feb 22, 2025
267cd82
fix port conflict
youkaichao Feb 22, 2025
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -134,7 +134,9 @@ steps:
- tests/compile/test_basic_correctness
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
commands:
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
76 changes: 76 additions & 0 deletions examples/offline_inference/data_parallel.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import os

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# set devices for each dp_rank
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
GPUs_per_dp_rank))

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    end = start + prompts_per_rank
prompts = prompts[start:end]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
prompts = ["Placeholder"]
Member: I know this is just an example, but in practice I guess you'd want to set max_tokens to 1 for any placeholder prompts.
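
A minimal sketch of that suggestion (hypothetical, not part of this diff): cap generation at a single token whenever the rank is only running the placeholder prompt.

# Hypothetical tweak, not in the PR: placeholder-only ranks get a
# one-token budget so the dummy request finishes almost immediately.
is_placeholder = len(prompts) == 0
if is_placeholder:
    prompts = ["Placeholder"]
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=1 if is_placeholder else 16 * (dp_rank + 1))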

print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")

# Create a sampling params object.
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=16 * (dp_rank + 1))

# Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"DP rank {dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")


if __name__ == "__main__":
from multiprocessing import Process
dp_size = 2
GPUs_per_dp_rank = 2
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
procs = []
for i in range(dp_size):
proc = Process(target=main,
args=(dp_size, i, dp_master_ip, dp_master_port,
GPUs_per_dp_rank))
proc.start()
procs.append(proc)
for proc in procs:
proc.join()
57 changes: 57 additions & 0 deletions vllm/config.py
@@ -16,6 +16,7 @@

import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig

import vllm.envs as envs
@@ -1290,6 +1291,11 @@ class ParallelConfig:

pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.

# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
@@ -1323,10 +1329,55 @@ class ParallelConfig:
worker_cls: str = "auto"
sd_worker_cls: str = "auto"

# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
world_size_across_dp: int = field(init=False)

rank: int = 0

def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple
        processes related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a
new process group related to data parallelism.
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
Comment on lines +1349 to +1350

Collaborator: What if the port is already being used by other services?

Member Author: Then it will error. We can document that we will use more than one port, starting from the specified port, and that assumption should usually be fine. Note: even if we only used the specified port, there is still a chance that some other service grabs it before we start to use it. That is unavoidable when running multiple services on the same host, but for cloud deployments, where each service runs in a separate container, it should be fine.

Collaborator: Alternatively, we could just check with a socket whether the port is in use, and keep searching for the next available port?

Member Author: That is not feasible, because non-zero ranks connect directly to the specified port; they cannot tell whether the listener is the master rank or some other service, and they also need to wait in case the master rank has not started yet.

Member Author: I added the code in 267cd82, so at least vLLM's internal port usage will not conflict with the DP master ports.

return answer

def stateless_init_dp_group(self) -> "ProcessGroup":
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group)

# use gloo since the engine process might not have cuda device
dp_group = stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip,
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo")

return dp_group

@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished],
dtype=torch.int32,
device="cpu")
# dp rank 0: has_unfinished_seqs=True
# dp rank 1: has_unfinished_seqs=False
# aggregated: has_unfinished_seqs=True
# so this is an OR operation, i.e. MAX in integers
torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
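
For context, a hedged sketch of how an engine loop might drive this helper (engine.has_unfinished_requests() and engine.step() are assumed names for illustration, not necessarily this PR's exact call sites): every DP rank keeps stepping until no rank reports unfinished requests, so all ranks stay in lockstep for collective ops.

# Illustrative only: drain requests in lockstep across DP ranks.
dp_group = parallel_config.stateless_init_dp_group()
while True:
    local_unfinished = engine.has_unfinished_requests()  # assumed accessor
    if not ParallelConfig.has_unfinished_dp(dp_group, local_unfinished):
        break  # every DP rank has finished
    # ranks with no local work still step (e.g. on a placeholder batch)
    engine.step()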

def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
@@ -1344,6 +1395,12 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
Member: Note that I'm hitting issues like:

RuntimeError: The server socket has failed to listen on any local network address. port: 29500, useIpv6: 0, code: -98, name: EADDRINUSE, message: address already in use

This happens even if I change the master port with torchrun --master-port .... Currently hacking around it by changing this to self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + 1

Member Author: That's strange. I also hit it once, but then it disappeared.

Member Author: It seems this disappeared when I removed torchrun in af53b4b.

self.world_size_across_dp = self.world_size * self.data_parallel_size

ray_only_devices = ["tpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
4 changes: 2 additions & 2 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -16,8 +16,8 @@ def __init__(self,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
if "tp" not in unique_name:
# only tp uses custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (
11 changes: 7 additions & 4 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -87,6 +87,7 @@ def __init__(self,
return

rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
@@ -201,8 +202,10 @@ def create_shared_buffer(

@staticmethod
def free_shared_buffer(pointers: List[int],
group: Optional[ProcessGroup] = None) -> None:
rank = dist.get_rank(group=group)
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
if rank is None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))

@@ -298,8 +301,8 @@ def close(self):
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)

def __del__(self):
self.close()
76 changes: 62 additions & 14 deletions vllm/distributed/parallel_state.py
@@ -750,6 +750,13 @@ def get_tp_group() -> GroupCoordinator:

_PP: Optional[GroupCoordinator] = None

_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
assert _DP is not None, ("data parallel group is not initialized")
return _DP


def get_pp_group() -> GroupCoordinator:
assert _PP is not None, (
@@ -811,6 +818,21 @@ def init_distributed_environment(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None and config.parallel_config.data_parallel_size > 1:
parallel_config = config.parallel_config
# adjust to take into account data parallelism
# offset the rank by the data parallel rank
rank = parallel_config.data_parallel_rank * world_size + rank
# adjust the world size to take into account data parallelism
world_size = parallel_config.world_size_across_dp
ip = parallel_config.data_parallel_master_ip
port = parallel_config.get_next_dp_init_port()
distributed_init_method = f"tcp://{ip}:{port}" # noqa
logger.info(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
world_size, rank, distributed_init_method)
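
For concreteness, a small worked example of the rank and world-size adjustment above (values assumed for illustration: DP=2, TP=2, PP=1, so each vLLM instance has world_size=2 and world_size_across_dp=4):

# Illustrative arithmetic only (assumed DP=2, TP=2, PP=1).
world_size = 2               # TP x PP inside one vLLM instance
data_parallel_rank = 1       # this instance is the second DP replica
local_rank = 0               # first worker inside this instance
global_rank = data_parallel_rank * world_size + local_rank
assert global_rank == 2      # matches "rank 2 ... DP rank 1, TP rank 0" below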
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
@@ -870,20 +892,28 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)

data_parallel_size = 1
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size

# the layout order is: DP x PP x TP
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
data_parallel_size, pipeline_model_parallel_size,
tensor_model_parallel_size) # noqa

# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]

# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(group_ranks,
Expand All @@ -893,20 +923,33 @@ def initialize_model_parallel(
group_name="tp")

# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size)
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="pp")

global _DP
assert _DP is None, ("data parallel group is already initialized")
group_ranks = all_ranks.transpose(0,
2).reshape(-1,
data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="dp")

logger.info(
Member Author: example of the rank assignment for DP=2 x TP=2:

rank 0 in world size 4 is assigned as DP rank 0, PP rank 0, TP rank 0
rank 1 in world size 4 is assigned as DP rank 0, PP rank 0, TP rank 1
rank 2 in world size 4 is assigned as DP rank 1, PP rank 0, TP rank 0
rank 3 in world size 4 is assigned as DP rank 1, PP rank 0, TP rank 1

"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s", rank, world_size,
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
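
As a sanity check of the DP x PP x TP layout described in the comment above, here is a small standalone sketch (illustrative only) that reproduces those group assignments with the same reshape/transpose/unbind trick:

import torch

# DP=2, PP=1, TP=2 -> world size 4; layout order is DP x PP x TP as above.
dp, pp, tp = 2, 1, 2
all_ranks = torch.arange(dp * pp * tp).reshape(dp, pp, tp)

tp_groups = [x.tolist() for x in all_ranks.view(-1, tp).unbind(0)]
pp_groups = [x.tolist() for x in all_ranks.transpose(1, 2).reshape(-1, pp).unbind(0)]
dp_groups = [x.tolist() for x in all_ranks.transpose(0, 2).reshape(-1, dp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3]]    -> ranks 0/1 and 2/3 each share a TP group
print(pp_groups)  # [[0], [1], [2], [3]] -> with PP=1, every rank is its own PP group
print(dp_groups)  # [[0, 2], [1, 3]]    -> DP peers are the ranks with equal TP rank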


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None

global _DP
if _DP:
_DP.destroy()
_DP = None


def destroy_distributed_environment():
global _WORLD