Commit 3aa3b46

[V1][PP] Support pp with ray backend in V1 (#1800)
### What this PR does / why we need it?
Support pipeline parallel with the ray backend in the V1 Engine.

Fixes #1751

### Does this PR introduce _any_ user-facing change?
Users can specify ray as the distributed executor backend when running inference with pipeline parallelism.

### How was this patch tested?
CI passed with the newly added test.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@32142b3

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent 9a3bdf2 commit 3aa3b46
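
For illustration, a minimal sketch of the user-facing change described above, using the offline `vllm.LLM` entry point; the model name and parallel sizes are placeholders rather than values taken from this commit:

```python
from vllm import LLM, SamplingParams

# Hypothetical usage: pipeline parallelism with the Ray executor backend in V1.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",   # placeholder model
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",   # the backend this PR enables for PP
    enforce_eager=True,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```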

5 files changed, +32 -18 lines changed

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
@@ -6,11 +6,12 @@ pytest >= 6.0
 pytest-asyncio
 pytest-mock
 lm-eval
-ray
 types-jsonschema
 xgrammar
 zmq
 types-psutil
 pytest-cov
 regex
 sentence_transformers
+ray>=2.47.1
+protobuf==4.25.6

tests/e2e/multicard/test_pipeline_parallel.py

Lines changed: 5 additions & 1 deletion
@@ -24,6 +24,7 @@
 
 TENSOR_PARALLELS = [2]
 PIPELINE_PARALLELS = [2]
+DIST_EXECUTOR_BACKEND = ["mp", "ray"]
 
 prompts = [
     "Hello, my name is",
@@ -34,10 +35,13 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
-def test_models(model: str, tp_size: int, pp_size: int) -> None:
+@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
+def test_models(model: str, tp_size: int, pp_size: int,
+                distributed_executor_backend: str) -> None:
     with VllmRunner(model,
                     tensor_parallel_size=tp_size,
                     pipeline_parallel_size=pp_size,
+                    distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(prompts, 64)

tests/ut/attention/test_attention_v1.py

Lines changed: 7 additions & 13 deletions
@@ -400,19 +400,13 @@ def test_forward_head_size_192(self, mock_vanilla_prefill,
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()
 
-        def mock_tensor(data, device=None, **kwargs):
-            if device == "npu":
-                return metadata.attn_mask
-            return torch.tensor(data, **kwargs)
-
-        with patch("torch.tensor", side_effect=mock_tensor):
-            output = self.impl_192.forward(layer,
-                                           query,
-                                           key,
-                                           value,
-                                           kv_cache,
-                                           metadata,
-                                           trace_flag=False)
+        output = self.impl_192.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
 
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)

vllm_ascend/attention/attention_v1.py

Lines changed: 4 additions & 2 deletions
@@ -396,8 +396,10 @@ def forward(
         if self.head_size == 192:
             cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
             cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-            cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
-            cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
+            cu_seqlen_q = torch.tensor(cu_seqlen_q,
+                                       device=query.device)
+            cu_seqlen_k = torch.tensor(cu_seqlen_k,
+                                       device=query.device)
             cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
             cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
             max_seqlen_q = torch.max(attn_metadata.query_lens)
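
Using `query.device` instead of a hard-coded "npu" string makes this path device-agnostic, which is also why the unit test above no longer needs to patch `torch.tensor`. A small self-contained sketch of the pattern, with a hypothetical helper name that is not part of this commit:

```python
import torch

def build_cu_seqlens(query: torch.Tensor, lens) -> torch.Tensor:
    # Hypothetical helper: place cumulative sequence lengths on whatever
    # device the query tensor lives on (NPU in production, CPU in unit tests).
    cu = torch.tensor([0] + list(lens), device=query.device)
    return torch.cumsum(cu, dim=0)

print(build_cu_seqlens(torch.zeros(4, 8), [2, 2]))  # tensor([0, 2, 4])
```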

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 1 deletion
@@ -233,7 +233,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                         2048,
                                                         dtype=torch.bool),
-                                             diagonal=1).to("npu")
+                                             diagonal=1).to(self.device)
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
@@ -1120,6 +1120,19 @@ def _process_reqs(
         input_ids = self.input_ids[:padded_batch_size]
         positions = self.positions[:padded_batch_size]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
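
The hunk above is the core of the pipeline-parallel support: every rank except the first receives intermediate tensors from the previous stage, copies them into persistent pre-allocated buffers, and feeds truncated views of those buffers to the model. A simplified sketch of that handoff, using plain dicts and made-up sizes in place of vLLM's `IntermediateTensors` container:

```python
import torch

num_input_tokens, hidden_size = 6, 16          # illustrative sizes

# Persistent buffers sized for the maximum batch (self.intermediate_tensors).
persistent = {"hidden_states": torch.empty(32, hidden_size)}
# Tensors received from the previous pipeline stage for this step.
received = {"hidden_states": torch.randn(num_input_tokens, hidden_size)}

for k, v in received.items():
    persistent[k][:num_input_tokens].copy_(v[:num_input_tokens], non_blocking=True)

# What the model on this (non-first) rank actually consumes this step.
step_inputs = {k: v[:num_input_tokens] for k, v in persistent.items()}
print(step_inputs["hidden_states"].shape)      # torch.Size([6, 16])
```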
