55 changes: 54 additions & 1 deletion tests/distributed/test_comm_ops.py
@@ -11,7 +11,8 @@

from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)

from ..utils import init_test_distributed_environment, multi_process_parallel

@@ -38,6 +39,33 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tp_size)
]

index = rank % tp_size
partition_size = num_elements // tp_size
all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
expected = all_reduce[index * partition_size:(index + 1) * partition_size]
t = all_tensors[index]
t = tensor_model_parallel_reduce_scatter(t)
torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
@@ -178,6 +206,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_parallel(tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_sequence_parallel(tp_size, test_target):
multi_process_parallel(tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@@ -199,3 +238,17 @@ def test_multi_process_pipeline_parallel(pp_size, test_target):
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_sequence_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target)
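
For reference, the values that reduce_scatter_test_worker asserts can be reproduced in a single process; the sketch below only mirrors the test's arithmetic and makes no distributed calls.

import torch

tp_size = 2
num_elements = 8
# Same synthetic inputs as the test: rank r holds arange(8) * (r + 1).
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_size)
]
# Reduce (sum) across ranks, then scatter: rank r keeps the r-th partition.
summed = torch.stack(all_tensors).sum(dim=0)
for rank, part in enumerate(summed.chunk(tp_size)):
    print(f"rank {rank} receives {part.tolist()}")
# rank 0 receives [0.0, 3.0, 6.0, 9.0]
# rank 1 receives [12.0, 15.0, 18.0, 21.0]
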
92 changes: 90 additions & 2 deletions tests/distributed/test_pipeline_parallel.py
@@ -39,6 +39,7 @@ def use_v0_only(monkeypatch):
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
sp_enabled: bool
eager_mode: bool
chunked_prefill: bool

@@ -81,22 +82,27 @@ def detailed(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=False,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=False,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
sp_enabled=False,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
sp_enabled=False,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
sp_enabled=False,
eager_mode=True,
chunked_prefill=False),
],
@@ -121,8 +127,9 @@ def fast(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=False,
eager_mode=True,
chunked_prefill=False),
chunked_prefill=False)
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
@@ -131,6 +138,42 @@ def fast(
load_format=load_format),
)

@staticmethod
def sp(
*,
tp_base: int = 2,
pp_base: int = 1,
task: TaskOption = "auto",
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
sp_enabled=True,
eager_mode=False,
chunked_prefill=True),

# The current SP implementation does not yet support combination with PP
# ParallelSetup(tp_size=2 * tp_base,
# pp_size=2 * pp_base,
# sp_enabled=True,
# eager_mode=True,
# chunked_prefill=False),
],
distributed_backends=["mp", "mp"],
vllm_major_versions=["0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)

def iter_params(self, model_id: str):
opts = self.test_options

@@ -271,10 +314,10 @@ def _compare_tp(
(
tp_size,
pp_size,
sp_enabled,
eager_mode,
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
@@ -360,6 +403,9 @@ def _compare_tp(
distributed_backend,
]

if sp_enabled:
pp_args.append("--enable-sequence-parallel")

# compare without pipeline parallelism
# NOTE: use mp backend for TP
# PP tests might involve multiple nodes, and ray might
@@ -469,3 +515,45 @@ def test_tp_multimodal_generation(
num_gpus_available,
method="generate",
is_multimodal=True)


SP_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.sp(),
}

SP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"meta-llama/Llama-3.2-1B-Instruct",
]


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in SP_TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_sp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)
4 changes: 4 additions & 0 deletions vllm/config.py
@@ -1342,6 +1342,8 @@ class ParallelConfig:
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
enable_sequence_parallel: bool = False # Enable sequence parallelism.

# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
@@ -2134,6 +2136,8 @@ def create_draft_parallel_config(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
enable_sequence_parallel=target_parallel_config.
enable_sequence_parallel,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
5 changes: 5 additions & 0 deletions vllm/distributed/communication_op.py
@@ -13,6 +13,11 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
return get_tp_group().all_reduce(input_)


def tensor_model_parallel_reduce_scatter(input_: torch.Tensor) -> torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
return get_tp_group().reduce_scatter(input_)


def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
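
As a rough illustration of how the new wrapper could be used (a hypothetical sketch, not code from this PR: finish_row_parallel and the sequence_parallel flag are made up here, and the input is assumed to have the sequence dimension first and divisible by the TP world size):

import torch

from vllm.distributed import (tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)


def finish_row_parallel(partial_out: torch.Tensor,
                        sequence_parallel: bool) -> torch.Tensor:
    if sequence_parallel:
        # Keep only this rank's slice of the sequence after the reduction;
        # downstream code is expected to all-gather before it needs the
        # full sequence again.
        return tensor_model_parallel_reduce_scatter(partial_out)
    # Plain tensor parallelism: every rank receives the fully reduced tensor.
    return tensor_model_parallel_all_reduce(partial_out)
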
15 changes: 15 additions & 0 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -35,6 +35,21 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_

def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
input_size = input_.size()
assert input_size[0] % self.world_size == 0, (
f"reduce scatter doesn't work when input size {input_size} is not "
f"divisible by world size {self.world_size}")
output_size = (input_size[0] // self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
dist.reduce_scatter_tensor(output_tensor,
input_,
group=self.device_group)
return output_tensor

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
7 changes: 7 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -312,6 +312,13 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return self.device_communicator.all_reduce(input_)

def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

return self.device_communicator.reduce_scatter(input_)

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
enable_sequence_parallel: bool = False
enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
@@ -434,6 +435,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--enable-sequence-parallel',
'-sp',
action='store_true',
default=False,
help='Enable sequence parallelism.')
parser.add_argument(
'--enable-expert-parallel',
action='store_true',
@@ -1242,6 +1248,7 @@ def create_engine_config(
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
enable_sequence_parallel=self.enable_sequence_parallel,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
4 changes: 4 additions & 0 deletions vllm/entrypoints/llm.py
@@ -74,6 +74,8 @@ class LLM:
environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
enable_sequence_parallel: Enable sequence parallelism on top of tensor
parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
@@ -164,6 +166,7 @@ def __init__(
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
enable_sequence_parallel: bool = False,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
@@ -219,6 +222,7 @@ def __init__(
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
enable_sequence_parallel=enable_sequence_parallel,
dtype=dtype,
quantization=quantization,
revision=revision,
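
With the new constructor argument, sequence parallelism can be turned on from the offline entrypoint, for example (assumes at least two GPUs and a build that includes this change):

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          tensor_parallel_size=2,
          enable_sequence_parallel=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
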