
Commit 88bc4b0

support sequence parallel
Signed-off-by: cascade812 <cascade812@outlook.com>
1 parent 61c6a5a commit 88bc4b0

File tree

15 files changed: +306 −24 lines

tests/distributed/test_comm_ops.py

Lines changed: 54 additions & 1 deletion
@@ -11,7 +11,8 @@
 
 from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 
 from ..utils import init_test_distributed_environment, multi_process_parallel
 
@@ -38,6 +39,33 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
     torch.testing.assert_close(t, expected)
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def reduce_scatter_test_worker(tp_size: int, pp_size: int, rank: int,
+                               distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # and set its device to the correct GPU
+    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+
+    index = rank % tp_size
+    partition_size = num_elements // tp_size
+    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    expected = all_reduce[index * partition_size:(index + 1) * partition_size]
+    t = all_tensors[index]
+    t = tensor_model_parallel_reduce_scatter(t)
+    torch.testing.assert_close(t, expected)
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
@@ -178,6 +206,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
     multi_process_parallel(tp_size, 1, test_target)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel(tp_size, test_target):
+    multi_process_parallel(tp_size, 1, test_target)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("pp_size", [2])
@@ -199,3 +238,17 @@ def test_multi_process_pipeline_parallel(pp_size, test_target):
 def test_multi_process_tensor_parallel_pipeline_parallel(
         tp_size, pp_size, test_target):
     multi_process_parallel(tp_size, pp_size, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    send_recv_test_worker, send_recv_tensor_dict_test_worker,
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel_pipeline_parallel(
+        tp_size, pp_size, test_target):
+    multi_process_parallel(tp_size, pp_size, test_target)
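
Note: the expectation in reduce_scatter_test_worker is simply "all-reduce, then keep this rank's slice". A minimal single-process sketch of that arithmetic (tp_size=2 and num_elements=8, mirroring the tensors built above; not part of the commit):

import torch

tp_size, num_elements = 2, 8
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_size)
]
summed = torch.stack(all_tensors).sum(dim=0)  # what all-reduce would produce
chunks = summed.chunk(tp_size)                # what each rank keeps

# rank 0 keeps [0, 3, 6, 9]; rank 1 keeps [12, 15, 18, 21]
for rank, chunk in enumerate(chunks):
    print(rank, chunk.tolist())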

tests/distributed/test_pipeline_parallel.py

Lines changed: 90 additions & 2 deletions
@@ -39,6 +39,7 @@ def use_v0_only(monkeypatch):
 class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
+    sp_enabled: bool
     eager_mode: bool
     chunked_prefill: bool
 
@@ -81,22 +82,27 @@ def detailed(
             parallel_setups=[
                 ParallelSetup(tp_size=tp_base,
                               pp_size=pp_base,
+                              sp_enabled=False,
                               eager_mode=False,
                               chunked_prefill=False),
                 ParallelSetup(tp_size=tp_base,
                               pp_size=2 * pp_base,
+                              sp_enabled=False,
                               eager_mode=False,
                               chunked_prefill=True),
                 ParallelSetup(tp_size=tp_base,
                               pp_size=2 * pp_base,
+                              sp_enabled=False,
                               eager_mode=True,
                               chunked_prefill=False),
                 ParallelSetup(tp_size=2 * tp_base,
                               pp_size=pp_base,
+                              sp_enabled=False,
                               eager_mode=False,
                               chunked_prefill=True),
                 ParallelSetup(tp_size=2 * tp_base,
                               pp_size=pp_base,
+                              sp_enabled=False,
                               eager_mode=True,
                               chunked_prefill=False),
             ],
@@ -121,8 +127,9 @@ def fast(
             parallel_setups=[
                 ParallelSetup(tp_size=tp_base,
                               pp_size=pp_base,
+                              sp_enabled=False,
                               eager_mode=True,
-                              chunked_prefill=False),
+                              chunked_prefill=False)
             ],
             distributed_backends=["mp"],
             vllm_major_versions=["0"],
@@ -131,6 +138,42 @@ def fast(
                                        load_format=load_format),
         )
 
+    @staticmethod
+    def sp(
+        *,
+        tp_base: int = 2,
+        pp_base: int = 1,
+        task: TaskOption = "auto",
+        multi_node_only: bool = False,
+        load_format: Optional[str] = None,
+    ):
+        return PPTestSettings(
+            parallel_setups=[
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=False),
+                ParallelSetup(tp_size=2 * tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=True),
+
+                # current sp doesn't support combination with pp
+                # ParallelSetup(tp_size=2 * tp_base,
+                #               pp_size=2 * pp_base,
+                #               sp_enabled=True,
+                #               eager_mode=True,
+                #               chunked_prefill=False),
+            ],
+            distributed_backends=["mp", "mp"],
+            vllm_major_versions=["0", "1"],
+            task=task,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       load_format=load_format),
+        )
+
     def iter_params(self, model_id: str):
         opts = self.test_options
 
@@ -271,10 +314,10 @@ def _compare_tp(
     (
         tp_size,
         pp_size,
+        sp_enabled,
        eager_mode,
        chunked_prefill,
     ) = parallel_setup
-
     multi_node_only, load_format = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
@@ -360,6 +403,9 @@ def _compare_tp(
         distributed_backend,
     ]
 
+    if sp_enabled:
+        pp_args.append("--enable-sequence-parallel")
+
     # compare without pipeline parallelism
     # NOTE: use mp backend for TP
     # PP tests might involve multiple nodes, and ray might
@@ -469,3 +515,45 @@ def test_tp_multimodal_generation(
                 num_gpus_available,
                 method="generate",
                 is_multimodal=True)
+
+
+SP_TEXT_GENERATION_MODELS = {
+    # [Decoder-only]
+    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.sp(),
+}
+
+SP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "meta-llama/Llama-3.2-1B-Instruct",
+]
+
+
+@pytest.mark.parametrize(
+    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
+     "task", "test_options"),
+    [
+        params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
+        for params in settings.iter_params(model_id)
+        if model_id in SP_TEST_MODELS
+    ],
+)
+@fork_new_process_for_each_test
+def test_tp_sp_generation(
+    model_id: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    vllm_major_version: str,
+    task: TaskOption,
+    test_options: PPTestOptions,
+    num_gpus_available,
+):
+    _compare_tp(model_id,
+                parallel_setup,
+                distributed_backend,
+                vllm_major_version,
+                task,
+                test_options,
+                num_gpus_available,
+                method="generate",
+                is_multimodal=True)
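
Note: the plumbing above is straightforward: sp_enabled travels inside ParallelSetup and, when set, appends the new CLI flag to the argument list that _compare_tp builds. A simplified, hypothetical sketch of that assembly (the real helper passes many more options; not part of the commit):

from typing import NamedTuple

class ParallelSetup(NamedTuple):  # trimmed copy of the test's NamedTuple
    tp_size: int
    pp_size: int
    sp_enabled: bool
    eager_mode: bool
    chunked_prefill: bool

setup = ParallelSetup(tp_size=2, pp_size=1, sp_enabled=True,
                      eager_mode=False, chunked_prefill=False)

pp_args = [
    "--tensor-parallel-size", str(setup.tp_size),
    "--pipeline-parallel-size", str(setup.pp_size),
    "--distributed-executor-backend", "mp",
]
if setup.sp_enabled:
    pp_args.append("--enable-sequence-parallel")

print(pp_args)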

vllm/config.py

Lines changed: 4 additions & 0 deletions
@@ -1342,6 +1342,8 @@ class ParallelConfig:
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
     data_parallel_size: int = 1  # Number of data parallel groups.
     data_parallel_rank: int = 0  # Rank of the data parallel group.
+    enable_sequence_parallel: bool = False  # Enable sequence parallelism.
+
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
@@ -2134,6 +2136,8 @@ def create_draft_parallel_config(
         pipeline_parallel_size=target_parallel_config.
         pipeline_parallel_size,
         tensor_parallel_size=speculative_draft_tensor_parallel_size,
+        enable_sequence_parallel=target_parallel_config.
+        enable_sequence_parallel,
         distributed_executor_backend=target_parallel_config.
         distributed_executor_backend,
         max_parallel_loading_workers=target_parallel_config.
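
Note: the second hunk mirrors the target config's SP setting into the speculative-decoding draft config so both models run with the same layout. For illustration only, a self-contained sketch of that propagation pattern (a hypothetical stand-in, not the real ParallelConfig, which carries many more fields):

from dataclasses import dataclass

@dataclass
class MiniParallelConfig:                   # hypothetical stand-in
    tensor_parallel_size: int = 1
    enable_sequence_parallel: bool = False  # new field, defaults off

def create_draft_parallel_config(target: MiniParallelConfig,
                                 draft_tp: int) -> MiniParallelConfig:
    # Copy the target's SP setting, as the diff above does.
    return MiniParallelConfig(
        tensor_parallel_size=draft_tp,
        enable_sequence_parallel=target.enable_sequence_parallel)

target = MiniParallelConfig(tensor_parallel_size=4,
                            enable_sequence_parallel=True)
draft = create_draft_parallel_config(target, draft_tp=1)
assert draft.enable_sequence_parallel is True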

vllm/distributed/communication_op.py

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,11 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     return get_tp_group().all_reduce(input_)
 
 
+def tensor_model_parallel_reduce_scatter(input_: torch.Tensor) -> torch.Tensor:
+    """Reduce-scatter the input tensor across model parallel group."""
+    return get_tp_group().reduce_scatter(input_)
+
+
 def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                      dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 15 additions & 0 deletions
@@ -35,6 +35,21 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
         return input_
 
+    def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
+        input_size = input_.size()
+        assert input_size[0] % self.world_size == 0, (
+            f"reduce scatter doesn't work when input size {input_size} is not "
+            f"divisible by world size {self.world_size}")
+        output_size = (input_size[0] // self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        dist.reduce_scatter_tensor(output_tensor,
+                                   input_,
+                                   group=self.device_group)
+        return output_tensor
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         if dim < 0:
             # Convert negative dim to positive.
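
Note: the shape bookkeeping is the main subtlety here: the reduction runs across ranks and the scatter splits dim 0, so every rank ends up with a tensor whose first dimension is 1/world_size of the input's. A quick single-process check of that arithmetic (no process group needed; sizes are illustrative):

import torch

world_size = 4
input_ = torch.randn(16, 1024, dtype=torch.float16)

assert input_.size(0) % world_size == 0
output_size = (input_.size(0) // world_size, ) + input_.size()[1:]
output_tensor = torch.empty(output_size, dtype=input_.dtype)

# Each rank would receive a [4, 1024] shard holding the reduced values
# for its quarter of the first dimension.
print(tuple(output_tensor.shape))  # (4, 1024)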

vllm/distributed/parallel_state.py

Lines changed: 7 additions & 0 deletions
@@ -312,6 +312,13 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         return self.device_communicator.all_reduce(input_)
 
+    def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        return self.device_communicator.reduce_scatter(input_)
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
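
Note: the single-GPU bypass is sound because with one rank a reduce-scatter degenerates to the identity: the sum of a single contribution, split into a single chunk, is the input itself. A one-line sanity check (illustrative, not part of the commit):

import torch

x = torch.arange(8, dtype=torch.float32)
# world_size == 1: reduce over one rank, then keep the only chunk.
assert torch.equal(torch.stack([x]).sum(dim=0).chunk(1)[0], x)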

vllm/engine/arg_utils.py

Lines changed: 7 additions & 0 deletions
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_sequence_parallel: bool = False
     enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
@@ -434,6 +435,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument('--enable-sequence-parallel',
+                            '-sp',
+                            action='store_true',
+                            default=False,
+                            help='Enable sequence parallelism.')
         parser.add_argument(
             '--enable-expert-parallel',
             action='store_true',
@@ -1242,6 +1248,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_sequence_parallel=self.enable_sequence_parallel,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
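
Note: a small sketch of how the new flag surfaces through the parser, assuming FlexibleArgumentParser is importable from vllm.utils in this version of the codebase (illustrative, not part of the commit):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# `-sp` is the short form of `--enable-sequence-parallel`.
args = parser.parse_args(["--tensor-parallel-size", "2", "-sp"])

assert args.tensor_parallel_size == 2
assert args.enable_sequence_parallel is True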

vllm/entrypoints/llm.py

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,8 @@ class LLM:
             environments.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
+        enable_sequence_parallel: Enable sequence parallelism on top of tensor
+            parallelism.
         dtype: The data type for the model weights and activations. Currently,
             we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
             the `torch_dtype` attribute specified in the model config file.
@@ -164,6 +166,7 @@ def __init__(
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
+        enable_sequence_parallel: bool = False,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
@@ -219,6 +222,7 @@ def __init__(
             trust_remote_code=trust_remote_code,
             allowed_local_media_path=allowed_local_media_path,
             tensor_parallel_size=tensor_parallel_size,
+            enable_sequence_parallel=enable_sequence_parallel,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
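
Note: end to end, the knob is exposed on the offline LLM entrypoint. A minimal usage sketch (assumes 2 GPUs; the model name is taken from the tests in this commit; not part of the diff):

from vllm import LLM, SamplingParams

# Sequence parallelism builds on tensor parallelism, so pair it with tp > 1.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          tensor_parallel_size=2,
          enable_sequence_parallel=True)

outputs = llm.generate(["Sequence parallelism splits activations along"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)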
