diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index fce523dab4..78ce101595 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -363,7 +363,10 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
-          pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
+          pytest -sv tests/e2e/multicard/test_data_parallel.py
+          pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
+            --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
+            --ignore=tests/e2e/multicard/test_data_parallel.py
 
       - name: Run vllm-project/vllm-ascend test on V0 engine
         if: ${{ github.event_name == 'schedule' }}
@@ -380,4 +383,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
-          pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
+          pytest -sv tests/e2e/multicard/test_data_parallel.py
+          pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
+            --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
+            --ignore=tests/e2e/multicard/test_data_parallel.py
diff --git a/examples/dp_offline/data_parallel.py b/examples/dp_offline/data_parallel.py
deleted file mode 100644
index b06c52d8c5..0000000000
--- a/examples/dp_offline/data_parallel.py
+++ /dev/null
@@ -1,85 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
-# SPDX-License-Identifier: Apache-2.0
-# usage:
-# python examples/offline_inference_data_parallel.py
-# we need to have a launcher to create multiple data parallel
-# ranks. And each rank will create a vLLM instance to process its own prompts.
-
-import gc
-import os
-
-
-def main():
-    dp_rank = int(os.environ['RANK'])
-    local_rank = int(os.environ['LOCAL_RANK'])
-    dp_size = int(os.environ['WORLD_SIZE'])
-    master_addr = os.environ['MASTER_ADDR']
-    master_port = os.environ['MASTER_PORT']
-    tp_size = 1
-    etp_size = 1
-
-    os.environ["VLLM_DP_RANK"] = str(dp_rank)
-    os.environ["VLLM_DP_SIZE"] = str(dp_size)
-    os.environ["VLLM_DP_MASTER_IP"] = master_addr
-    os.environ["VLLM_DP_MASTER_PORT"] = master_port
-    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(
-        str(i)
-        for i in range(local_rank * tp_size, (local_rank + 1) * tp_size))
-
-    import torch
-    from vllm import LLM, SamplingParams
-    from vllm.distributed.parallel_state import (
-        destroy_distributed_environment, destroy_model_parallel)
-
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ] * 4
-
-    promts_per_rank = len(prompts) // dp_size
-    start = dp_rank * promts_per_rank
-    end = start + promts_per_rank
-    prompts = prompts[start:end]
-    if len(prompts) == 0:
-        prompts = ["Placeholder"]
-    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
-    num_seqs = len(prompts)
-
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=4,
-                                     min_tokens=4)
-    # Create an LLM.
-    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
-              tensor_parallel_size=tp_size,
-              trust_remote_code=True,
-              max_model_len=4096,
-              max_num_seqs=num_seqs,
-              additional_config={
-                  'expert_tensor_parallel_size': etp_size,
-                  'torchair_graph_config': {
-                      'enabled': False,
-                  },
-              })
-
-    outputs = llm.generate(prompts, sampling_params)
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
-              f"Generated text: {generated_text!r}")
-
-    del llm
-    destroy_model_parallel()
-    destroy_distributed_environment()
-    gc.collect()
-    torch.npu.empty_cache()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/dp_offline/run_dp.sh b/examples/dp_offline/run_dp.sh
deleted file mode 100644
index 405df604a4..0000000000
--- a/examples/dp_offline/run_dp.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-export HCCL_IF_IP=${local_ip}
-export GLOO_SOCKET_IFNAME=${ifname}
-export TP_SOCKET_IFNAME=${ifname}
-export HCCL_SOCKET_IFNAME=${ifname}
-
-# dp_size = node_size * dp_per_node
-node_size=1
-node_rank=0
-dp_per_node=4
-master_addr=127.0.0.1
-master_port=12345
-
-rm -rf ./.torchair_cache/
-rm -rf ./dynamo_*
-rm -rf /root/ascend/log/debug/plog/*
-
-torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
-    --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
-    data_parallel.py
diff --git a/examples/offline_data_parallel.py b/examples/offline_data_parallel.py
new file mode 100644
index 0000000000..64084ac69d
--- /dev/null
+++ b/examples/offline_data_parallel.py
@@ -0,0 +1,241 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
+#
+"""
+Usage:
+Single node:
+    Dense models:
+        python examples/offline_data_parallel.py \
+                --model="Qwen/Qwen2.5-0.5B-Instruct" \
+                --dp-size=2 \
+                --tp-size=2
+    MOE models:
+        python examples/offline_data_parallel.py \
+                --model="ibm-research/PowerMoE-3b" \
+                --dp-size=2 \
+                --tp-size=2 \
+                --enable-expert-parallel
+
+Multi-node:
+    Node 0 (assume the node has ip of 10.99.48.128):
+        python examples/offline_data_parallel.py \
+                --model="ibm-research/PowerMoE-3b" \
+                --dp-size=2 \
+                --tp-size=2 \
+                --node-size=2 \
+                --node-rank=0 \
+                --enable-expert-parallel \
+                --master-addr=10.99.48.128 \
+                --master-port=13345
+    Node 1:
+        python examples/offline_data_parallel.py \
+                --model="ibm-research/PowerMoE-3b" \
+                --dp-size=2 \
+                --tp-size=2 \
+                --node-size=2 \
+                --node-rank=1 \
+                --enable-expert-parallel \
+                --master-addr=10.99.48.128 \
+                --master-port=13345
+"""
+
+import os
+from time import sleep
+
+from vllm import LLM, SamplingParams
+from vllm.utils import get_open_port
+
+
+def parse_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Data Parallel Inference")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="ibm-research/PowerMoE-3b",
+        help="Model name or path",
+    )
+    parser.add_argument("--dp-size",
+                        type=int,
+                        default=2,
+                        help="Data parallel size")
+    parser.add_argument("--tp-size",
+                        type=int,
+                        default=1,
+                        help="Tensor parallel size")
+    parser.add_argument("--node-size",
+                        type=int,
+                        default=1,
+                        help="Total number of nodes")
+    parser.add_argument("--node-rank",
+                        type=int,
+                        default=0,
+                        help="Rank of the current node")
+    parser.add_argument("--master-addr",
+                        type=str,
+                        default="",
+                        help="Master node IP address")
+    parser.add_argument("--master-port",
+                        type=int,
+                        default=0,
+                        help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action="store_true",
+                        help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action="store_true",
+                        help="Trust remote code.")
+    parser.add_argument("--enable-expert-parallel",
+                        action="store_true",
+                        help="Enable expert parallel, used in MOE models.")
+    return parser.parse_args()
+
+
+def main(
+    model,
+    dp_size,
+    local_dp_rank,
+    global_dp_rank,
+    dp_master_ip,
+    dp_master_port,
+    GPUs_per_dp_rank,
+    enable_expert_parallel,
+    enforce_eager,
+    trust_remote_code,
+):
+    # DP is only supported on the V1 engine.
+    os.environ["VLLM_USE_V1"] = "1"
+    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
+    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
+    os.environ["VLLM_DP_SIZE"] = str(dp_size)
+    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
+    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
+
+    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
+    # engine processes.
+
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ] * 100
+
+    # with DP, each rank should process different prompts.
+    # usually all the DP ranks process a full dataset,
+    # and each rank processes a different part of the dataset.
+    floor = len(prompts) // dp_size
+    remainder = len(prompts) % dp_size
+
+    # Distribute prompts into even groups.
+    def start(rank):
+        return rank * floor + min(rank, remainder)
+
+    prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)]
+    if len(prompts) == 0:
+        # if any rank has no prompts to process,
+        # we need to set a placeholder prompt
+        prompts = ["Placeholder"]
+    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
+
+    # Create a sampling params object.
+    # since we are doing data parallel, every rank can have different
+    # sampling params. here we set different max_tokens for different
+    # ranks for demonstration.
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=[16, 20][global_dp_rank % 2])
+
+    # Create an LLM.
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=enable_expert_parallel,
+        trust_remote_code=trust_remote_code,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        if i >= 5:
+            # print only 5 outputs
+            break
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
+              f"Generated text: {generated_text!r}")
+
+    # Give engines time to pause their processing loops before exiting.
+    sleep(1)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    dp_size = args.dp_size
+    tp_size = args.tp_size
+    node_size = args.node_size
+    node_rank = args.node_rank
+
+    if node_size == 1:
+        dp_master_ip = "127.0.0.1"
+        dp_master_port = get_open_port()
+    else:
+        dp_master_ip = args.master_addr
+        dp_master_port = args.master_port
+
+    assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
+    dp_per_node = dp_size // node_size
+
+    from multiprocessing import Process
+
+    procs = []
+    for local_dp_rank, global_dp_rank in enumerate(
+            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
+        proc = Process(
+            target=main,
+            args=(
+                args.model,
+                dp_size,
+                local_dp_rank,
+                global_dp_rank,
+                dp_master_ip,
+                dp_master_port,
+                tp_size,
+                args.enable_expert_parallel,
+                args.enforce_eager,
+                args.trust_remote_code,
+            ),
+        )
+        proc.start()
+        procs.append(proc)
+    exit_code = 0
+    for proc in procs:
+        proc.join(timeout=300)
+        if proc.exitcode is None:
+            print(
+                f"Killing process {proc.pid} that didn't stop within 5 minutes."
+            )
+            proc.kill()
+            exit_code = 1
+        elif proc.exitcode:
+            exit_code = proc.exitcode
+
+    exit(exit_code)
diff --git a/tests/e2e/multicard/test_data_parallel.py b/tests/e2e/multicard/test_data_parallel.py
new file mode 100644
index 0000000000..57f14ac6db
--- /dev/null
+++ b/tests/e2e/multicard/test_data_parallel.py
@@ -0,0 +1,72 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+End-to-end data parallel inference test via examples/offline_data_parallel.py.
+
+Run `pytest tests/e2e/multicard/test_data_parallel.py`.
+""" + +import os +import subprocess +import sys +from unittest.mock import patch + +import pytest + +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] + + +@pytest.mark.skipif(True, reason="TODO: fix dp timeout error in ci") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) +def test_data_parallel_inference(model, max_tokens): + script = "examples/offline_data_parallel.py" + + env = os.environ.copy() + + cmd = [ + sys.executable, + script, + "--model", + model, + "--dp-size", + "2", + "--tp-size", + "1", + "--node-size", + "1", + "--node-rank", + "0", + "--trust-remote-code", + "--enforce-eager", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600) + output = proc.stdout.decode() + + print(output) + + assert "DP rank 0 needs to process" in output + assert "DP rank 1 needs to process" in output + assert "Generated text:" in output + assert proc.returncode == 0 diff --git a/tests/multicard/test_data_parallel.py b/tests/multicard/test_data_parallel.py deleted file mode 100644 index 6c0a20de97..0000000000 --- a/tests/multicard/test_data_parallel.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Compare the outputs of vLLM with and without aclgraph. - -Run `pytest tests/multicard/test_data_parallel.py`. 
-""" - -import os - -import pytest - -from tests.conftest import VllmRunner -from tests.model_utils import check_outputs_equal - -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] - - -@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", - reason="Data parallel only support on v1") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [32]) -def test_data_parallel_correctness( - model: str, - max_tokens: int, -) -> None: - example_prompts = [ - "Hello, my name is", "The president of the United States is", - "The capital of France is", "The future of AI is" - ] - - with VllmRunner(model_name=model, - max_model_len=1024, - max_num_seqs=16, - data_parallel_size=2, - distributed_executor_backend="mp") as vllm_model: - vllm_dp_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - with VllmRunner( - model_name=model, - max_model_len=1024, - max_num_seqs=16, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=vllm_outputs, - outputs_1_lst=vllm_dp_outputs, - name_0="vllm_outputs", - name_1="vllm_dp_outputs", - ) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 0aac026a44..5451508c81 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -119,6 +119,10 @@ class AscendMetadata: query_start_loc: torch.Tensor query_lens: torch.Tensor seq_lens: torch.Tensor + + # max value of number of tokens across dp group + max_num_tokens_across_dp: int = 0 + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None # (num_tokens,). The indices of the token slots that input tokens will be @@ -155,6 +159,7 @@ def build(self, num_actual_tokens, max_query_len, common_prefix_len, + max_num_tokens_across_dp: int = 0, with_prefill_across_dp: bool = False): block_table = self.runner.input_batch.block_table[0].get_device_tensor( @@ -192,6 +197,7 @@ def build(self, slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, + max_num_tokens_across_dp=max_num_tokens_across_dp, with_prefill_across_dp=with_prefill_across_dp) return attn_metadata