diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index aa622f7d423..0792fa76505 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -269,6 +269,7 @@ jobs: pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_basic.py pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_accuracy.py + pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_mtp.py - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) shell: bash -l {0} diff --git a/tests/e2e/multicard/long_sequence/test_mtp.py b/tests/e2e/multicard/long_sequence/test_mtp.py new file mode 100644 index 00000000000..f42bdf6a93d --- /dev/null +++ b/tests/e2e/multicard/long_sequence/test_mtp.py @@ -0,0 +1,165 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# + +import os + +import pytest + +from tests.e2e.conftest import VllmRunner +from vllm_ascend.utils import vllm_version_is + +os.environ["HCCL_BUFFSIZE"] = "512" + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_pcp_dcp_mtp1_eager(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = "wemaster/deepseek_mtp_main_random_bf16" + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128, + speculative_config={ + "num_speculative_tokens": 1, + "method": "deepseek_mtp", + }, + enforce_eager=True, + ) as runner: + runner.generate_greedy(prompts, 32) + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_pcp_dcp_mtp3_eager(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = "wemaster/deepseek_mtp_main_random_bf16" + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128, + speculative_config={ + "num_speculative_tokens": 3, + "method": "deepseek_mtp", + }, + enforce_eager=True, + ) as runner: + runner.generate_greedy(prompts, 32) + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_pcp_dcp_mtp3_piecewise_graph(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = 
"wemaster/deepseek_mtp_main_random_bf16" + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128, + speculative_config={ + "num_speculative_tokens": 3, + "method": "deepseek_mtp", + }, + compilation_config={ + "cudagraph_mode": "PIECEWISE", + "cudagraph_capture_sizes": [4, 8, 16], + }, + ) as runner: + runner.generate_greedy(prompts, 32) + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_pcp_dcp_mtp3_full_graph(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = "wemaster/deepseek_mtp_main_random_bf16" + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128, + speculative_config={ + "num_speculative_tokens": 3, + "method": "deepseek_mtp", + }, + compilation_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [4, 8, 16], + }, + ) as runner: + runner.generate_greedy(prompts, 32) + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_dcp_mtp3_full_graph(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = "wemaster/deepseek_mtp_main_random_bf16" + with VllmRunner( + model, + max_model_len=1024, + tensor_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128, + speculative_config={ + "num_speculative_tokens": 3, + "method": "deepseek_mtp", + }, + compilation_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [4, 8, 16], + }, + ) as runner: + runner.generate_greedy(prompts, 32) diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 3047b3d9625..9e9eb295282 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -11,6 +11,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_config import init_ascend_config @@ -215,10 +216,23 @@ def test_generate_token_ids(self, mock_cpu_gpu_buffer): mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32) mock_deps.runner.spec_decode_common_attn_metadata = MagicMock() mock_deps.runner.pcp_size = 2 - mock_deps.runner.input_ids_pcp_full = torch.arange(32, - dtype=torch.int32) - mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor( - [0, 8, 16, 24, 32]) + mock_deps.runner.dcp_size = 1 + mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer( + 32, + dtype=torch.int32, + pin_memory=False, + device='cpu', + ) + mock_deps.runner.input_ids_pcp_full.cpu = \ + torch.arange(32, dtype=torch.int32) + mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer( + 5, + dtype=torch.int32, + pin_memory=False, + device='cpu', + ) + mock_deps.runner.query_start_loc_pcp_full.cpu = \ + torch.tensor([0, 8, 16, 24, 32]) mock_deps.positions 
= torch.arange(16, dtype=torch.int32) mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16) mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1], @@ -232,6 +246,7 @@ def test_generate_token_ids(self, mock_cpu_gpu_buffer): proposer.speculative_config = MagicMock( disable_padded_drafter_batch=False) proposer.pcp_size = mock_deps.runner.pcp_size + proposer.dcp_size = mock_deps.runner.dcp_size proposer.prepare_next_token_ids_padded = MagicMock( return_value=(torch.tensor([101, 200, 302]), 3)) proposer.prepare_inputs_padded = MagicMock( diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index facb8920191..8ff26a6f9fe 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -50,6 +50,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, mock_runner.input_batch = MagicMock() mock_runner.input_batch.num_reqs = num_reqs + mock_runner.speculative_config = None num_computed_tokens = [] num_prompt_tokens = [] @@ -169,23 +170,24 @@ def test_pcp_allgather_restore_idx_slicing(): @pytest.mark.parametrize( - "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens", + "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \ + "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens", [ # Case 1: prefill only - ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]), + ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]), - # Case 2: mix prefill and decode - ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, [8, 4, 4]), + # Case 2: mix prefill and decode (with spec decode) + ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]), # Case 3: request which need to be padded - ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]), + ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]), # Case 4: single request - ([10], 1, [0], [10], 4, 0, [4]), + ([10], 1, [0], [10], 4, 0, 1, [4]), ]) def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, - expected_pcp_tokens): + decode_threshold, expected_pcp_tokens): mock_runner = MagicMock(spec=NPUModelRunner) mock_runner.pcp_size = pcp_size mock_runner.pcp_rank = pcp_rank @@ -201,6 +203,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, mock_runner.num_pcp_pads = [0] * num_reqs mock_runner.arange_np = np.arange(10000) + mock_runner.decode_threshold = decode_threshold mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( mock_runner, NPUModelRunner) @@ -243,6 +246,7 @@ def test_update_tokens_for_pcp_with_padding(): mock_runner.num_pcp_pads = [0, 0, 0] mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) + mock_runner.decode_threshold = 1 mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( mock_runner, NPUModelRunner) @@ -279,6 +283,7 @@ def test_update_tokens_for_pcp_unpad_mask(): mock_runner.num_pcp_pads = [0, 0] mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) + mock_runner.decode_threshold = 1 mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( mock_runner, NPUModelRunner) @@ -369,6 +374,9 @@ def pcp_mtp_mock_runner(): mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer( mock_runner, max_num_tokens, dtype=torch.int32) + mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer( + mock_runner, max_num_reqs, dtype=torch.int32) + 
mock_runner.decode_threshold = 1 mock_runner.arange_np = np.arange(max_model_len) mock_runner.input_batch = MagicMock() diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index c8e753ad083..2ddc41db1d7 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -27,6 +27,7 @@ split_decodes_and_prefills, wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, + get_mtp_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( @@ -92,6 +93,10 @@ def build( num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if num_actual_tokens_pcp_padded is None: num_actual_tokens_pcp_padded = num_actual_tokens + # In dcp only spec decode graph padding case, + # num_actual_tokens_pcp_padded may be less than num_actual_tokens + num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded, + num_actual_tokens) num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp assert num_computed_tokens_of_pcp_dcp is not None @@ -113,15 +118,6 @@ def build( common_attn_metadata.block_table_tensor[:graph_pad_size]) else: block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - # NOTE: Currently, MTP-fullgraph is incompatibility pcp - if self.pcp_size > 1: - num_decodes_flatten = num_decodes * self.decode_threshold - block_table = common_attn_metadata.block_table_tensor[: - num_decodes_flatten - + - num_prefills] - - # NOTE: Currently, MTP-fullgraph is incompatibility pcp slot_mapping = common_attn_metadata.slot_mapping[: num_actual_tokens_pcp_padded] input_positions = common_attn_metadata.positions[: @@ -144,6 +140,13 @@ def build( seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] num_computed_tokens_cpu = (seq_lens - query_lens) + # For pcp + spec decode, we flatten seq_lens and block_table + # to avoid irregular spec_attn_mask shape + num_decodes_flatten = query_lens[:num_decodes].sum().item() + block_table = common_attn_metadata.block_table_tensor[: + num_decodes_flatten + + num_prefills] + prefill_metadata = None chunked_context_metadata = None if num_prefills > 0: @@ -201,7 +204,7 @@ def build( dtype=torch.int32) local_context_lens_allranks = torch.tensor( - num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs] + num_computed_tokens_of_pcp_dcp[num_decodes_flatten:] ).reshape(-1, self.dcp_size * self.pcp_size) # Note(qcs): The max local context lengths # padded to `cp_local_block_size`. @@ -280,9 +283,8 @@ def build( cos=cos, pcp_metadata=pcp_metadata, ) - if self.pcp_size > 1: - prefill_metadata.block_table = block_table[ - num_decodes_flatten:, ...] + prefill_metadata.block_table = \ + block_table[num_decodes_flatten:, ...] decode_metadata = None if num_decodes > 0: @@ -293,13 +295,7 @@ def build( max_seq_lens = seq_lens[:num_decodes].max().item() seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] - if self.pcp_size > 1: - # For pcp + spec decode, we flatten seq_lens and block_table - # to avoid irregular spec_attn_mask shape - block_table = block_table[:num_decodes_flatten, ...] - else: - block_table = block_table[:num_decodes, ...] - # NOTE: Currently, MTP-fullgraph is incompatibility pcp + block_table = block_table[:num_decodes_flatten, ...] # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. 
if graph_pad_size > num_decodes and \ self.speculative_config.disable_padded_drafter_batch: @@ -308,8 +304,7 @@ def build( # [bs, pcp_size, dcp_size] num_computed_tokens_of_cp_dcp_array = np.array( - num_computed_tokens_of_pcp_dcp)[:num_decodes * - self.decode_threshold] + num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten] cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank] @@ -1057,8 +1052,11 @@ def _forward_decode_pcp_dcp( "return_lse": True, "calc_type": "calc_type_ring", } - graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() if forward_context.capturing: stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 222002568bc..81c92b3797d 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -67,6 +67,12 @@ class AscendPrefillContextParallelMetadata: pcp_prefill_mask: torch.Tensor = None + # original query_lens before pcp split + query_lens_pcp_full_cpu: torch.Tensor = None + + # original max_query_len before pcp split + max_query_len_pcp_full: int = 0 + @dataclass class AscendCommonAttentionMetadata: @@ -189,6 +195,8 @@ def split_decodes_and_prefills( """ Assuming a reordered batch, finds the boundary between prefill and decode requests. + While pcp > 1, query_lens is split across pcp ranks, so we pass in the + original query_lens and max_query_len to distinguish prefills and decodes. Args: common_attn_metadata: AscendCommonAttentionMetadata object containing the @@ -201,7 +209,13 @@ def split_decodes_and_prefills( num_decode_tokens: The number of tokens in the decode requests. num_prefill_tokens: The number of tokens in the prefill requests. """ - max_query_len = common_attn_metadata.max_query_len + long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata + query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \ + if long_seq_metadata else None + max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \ + if long_seq_metadata else 0 + max_query_len = common_attn_metadata.max_query_len \ + if max_query_len_pcp_full == 0 else max_query_len_pcp_full num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu @@ -209,7 +223,8 @@ def split_decodes_and_prefills( if max_query_len <= decode_threshold: return num_reqs, 0, num_tokens, 0 - query_lens = query_start_loc[1:] - query_start_loc[:-1] + query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \ + if query_lens_pcp_full is None else query_lens_pcp_full is_prefill = query_lens > decode_threshold if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index c3c52cb2963..9a8862b6599 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -440,7 +440,10 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): - graph_params = get_graph_params() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() # FIXME: Behold! 
We are using a temporary hack here to update the args # for each layer's attention op in the graph. with torch.npu.stream(update_stream): diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 14a61a7970e..f17cea1008c 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -32,6 +32,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + update_mla_attn_dcp_pcp_params, update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, @@ -98,6 +99,7 @@ def __init__( self.pcp_size = self.runner.pcp_size self.dcp_size = self.runner.dcp_size self.pcp_rank = self.runner.pcp_rank + self.dcp_rank = self.runner.dcp_rank self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None self.draft_indexer_metadata_builder: Optional[ @@ -267,6 +269,13 @@ def dummy_run(self, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, ) + if self.pcp_size * self.dcp_size > 1: + # update long_seq related params and flatten block_table + common_attn_metadata.prefill_context_parallel_metadata = \ + self.runner.long_seq_metadata + common_attn_metadata.block_table_tensor = \ + self.runner.input_batch.block_table[0].get_device_tensor()[ + :num_reqs * self.decode_threshold] builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata_mtp = builder.build_for_graph_capture( @@ -310,9 +319,15 @@ def dummy_run(self, if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ not forward_context.capturing: if self.vllm_config.model_config.use_mla and not self.use_sparse: - update_mla_attn_params( - self.update_stream, forward_context, num_tokens, - self.vllm_config.speculative_config) + if self.pcp_size * self.dcp_size > 1: + update_mla_attn_dcp_pcp_params( + self.update_stream, forward_context, + num_tokens) + else: + update_mla_attn_params( + self.update_stream, forward_context, + num_tokens, + self.vllm_config.speculative_config) if self.enable_shared_expert_dp: positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( positions, True) @@ -364,11 +379,11 @@ def generate_token_ids(self, valid_sampled_tokens_count) req_scheduled_tokens = scheduler_output.num_scheduled_tokens - if self.pcp_size > 1: + if self.pcp_size * self.dcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full - query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu + input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] @@ -396,12 +411,11 @@ def generate_token_ids(self, target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.pcp_size > 1: - common_attn_metadata.query_start_loc_cpu = \ + common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \ query_start_loc_pcp_full_cpu[:num_reqs + 1] - common_attn_metadata.query_start_loc = \ + common_attn_metadata.query_start_loc[:num_reqs + 1] = \ 
query_start_loc_pcp_full[:num_reqs + 1] if self.speculative_config.disable_padded_drafter_batch: - # NOTE: Currently, MTP-fullgraph is incompatibility with pcp token_indices_to_sample = None common_attn_metadata, token_indices =\ self._prepare_inputs( @@ -630,15 +644,18 @@ def _propose( self.input_ids[last_token_indices] = next_token_ids # update pcp related params - if self.pcp_size > 1: + if self.pcp_size * self.dcp_size > 1: assert long_seq_metadata is not None common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata + ori_last_token_indices = last_token_indices.clone() + query_lens_d = self.runner.query_lens[:num_decode_reqs] + if self.pcp_size > 1: # 1. preprocess decode/prefill input_ids & target_hidden_states # decode input_ids: keep unchanged # decode target_hidden_states: remove padding # prefill input_ids: add padding and pcp split # prefill target_hidden_states: pcp split - num_tokens_d = num_decode_reqs * self.decode_threshold + num_tokens_d = query_lens_d.sum().item() num_tokens_d_padded = num_tokens_d * self.pcp_size input_ids_d = self.input_ids[:num_tokens_d] input_ids_p = self.input_ids[num_tokens_d:num_tokens] @@ -646,12 +663,17 @@ def _propose( target_hidden_states[:num_tokens_d_padded] if num_tokens_d: # remove padding (from pcp all-gather) in decode part - target_hidden_states_d = target_hidden_states_d_padded.reshape( - [ - num_decode_reqs, self.decode_threshold * self.pcp_size, - -1 - ])[:, :self.decode_threshold, :].reshape( - [num_tokens_d, -1]) + mask_start_loc = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1] + ]) + mask_len = query_lens_d + mask = [] + for req_id in range(num_decode_reqs): + mask += list( + range(mask_start_loc[req_id], + mask_start_loc[req_id] + mask_len[req_id])) + target_hidden_states_d = target_hidden_states_d_padded[mask] else: target_hidden_states_d = target_hidden_states_d_padded target_hidden_states_p = target_hidden_states[num_tokens_d_padded:] @@ -670,25 +692,26 @@ def _propose( torch.cat([input_ids_d, input_ids_p], dim=0)) target_hidden_states = torch.cat( [target_hidden_states_d, target_hidden_states_p], dim=0) - # 2. update attn_metadata params that may be influenced by pcp - common_attn_metadata.num_actual_tokens = num_tokens - common_attn_metadata.max_query_len = max(self.decode_threshold, - max_query_len_p) - common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p - common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p - query_start_loc_p = cu_num_tokens_p[1:] + \ - common_attn_metadata.query_start_loc[num_decode_reqs].item() - common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \ - query_start_loc_p - common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \ - query_start_loc_p - # 3. update sample_indices according to main model + # 2. update sample_indices according to main model if num_decode_reqs: last_token_indices[:num_decode_reqs] = \ self.runner.logits_indices[last_token_indices[:num_decode_reqs]] if num_prefill_reqs: last_token_indices[-num_prefill_reqs:] = \ self.runner.logits_indices[-num_prefill_reqs:] + # 3. 
update attn_metadata params that may be influenced by pcp + common_attn_metadata.num_actual_tokens = num_tokens + common_attn_metadata.max_query_len = max( + self.decode_threshold, max_query_len_p) + common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p + common_attn_metadata.seq_lens_cpu[ + -num_prefill_reqs:] = seq_lens_p + query_start_loc_p = cu_num_tokens_p[1:] + \ + common_attn_metadata.query_start_loc[num_decode_reqs].item() + common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \ + query_start_loc_p + common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \ + query_start_loc_p assert self.runner is not None @@ -796,10 +819,15 @@ def _propose( forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: if self.vllm_config.model_config.use_mla and not self.use_sparse: - update_mla_attn_params( - self.update_stream, forward_context, - num_input_tokens, - self.vllm_config.speculative_config) + if self.pcp_size * self.dcp_size > 1: + update_mla_attn_dcp_pcp_params( + self.update_stream, forward_context, + num_input_tokens) + else: + update_mla_attn_params( + self.update_stream, forward_context, + num_input_tokens, + self.vllm_config.speculative_config) if self.enable_shared_expert_dp: hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( @@ -814,7 +842,9 @@ def _propose( last_token_indices, (0, max_num_reqs_across_dp - num_indices)) - if self.pcp_size > 1: + if self.pcp_size > 1 and step == 0: + # remove graph padding before all_gather + hidden_states = hidden_states[:num_tokens] hidden_states = get_pcp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( hidden_states, 0, self.runner. @@ -855,6 +885,51 @@ def _propose( last_token_indices = self.arange[:batch_size] if getattr(attn_metadata_i, "num_decode_tokens", 0): attn_metadata_i.num_decode_tokens = batch_size + if self.pcp_size * self.dcp_size > 1: + positions = target_positions[ori_last_token_indices] + # For pcp/dcp, tokens are split across different cp ranks, + # so we can not simply update slot_mapping by += 1. + # Instead, we pre-allocate mtp slot_mapping in model_runner + # (_generate_pcp_mtp_input), and use updated slot_indices + # to get corresponding slot_mapping in each step. 
+ num_reject_tokens = torch.tensor( + self.runner.cu_num_tokens_pcp_full, + dtype=torch.int32).to( + self.device) - ori_last_token_indices - 1 + num_accept_tokens = \ + query_lens_d.to(self.device) - num_reject_tokens + ori_seq_len = attn_metadata_i.seq_lens + mtp_slot_mapping = self.runner.mtp_slot_pad + + # slot_mapping index base offset: + # scheduled tokens + pre-allocated mtp tokens + accepted tokens + slot_idx_base = ( + torch.cat([ + torch.tensor( + [0], dtype=torch.int32, device=self.device), + (torch.cumsum(query_lens_d, dim=0)[:-1] * + self.pcp_size).to(self.device) + ]) + + torch.arange(num_decode_reqs, device=self.device) * + (self.num_speculative_tokens - 1) * self.pcp_size + + (num_accept_tokens - 1) * self.pcp_size) + slot_indices_list = [] + for req_id in range(num_decode_reqs): + slot_indices_list.append( + torch.arange(slot_idx_base[req_id], + slot_idx_base[req_id] + self.pcp_size, + device=self.device)) + slot_indices = torch.cat(slot_indices_list, dim=0) + + # fold block_table (restore it to original size before flattened) + block_indices = torch.cat([ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(query_lens_d, dim=0)[:-1] + ]) + attn_metadata_i.decode.block_table[:batch_size] = \ + attn_metadata_i.decode.block_table[block_indices] + attn_metadata_i.decode.block_table = \ + attn_metadata_i.decode.block_table[:batch_size] input_ids = draft_token_ids_list[-1].int() positions += 1 @@ -901,13 +976,40 @@ def _propose( # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. slot_mapping += 1 + if self.pcp_size > 1: + exceeds_max_model_len = exceeds_max_model_len.repeat_interleave( + slot_mapping.size(0) // exceeds_max_model_len.size(0)) slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:hidden_states.shape[0]] = hidden_states - attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + if self.pcp_size * self.dcp_size > 1: + # update local seq_len and batch_seq_mask + num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens( + ori_seq_len + step + 1, + self.pcp_size, + self.dcp_size, + self.runner.parallel_config.cp_kv_cache_interleave_size, + ) + cp_seq_len = \ + num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] + batch_seq_mask = (cp_seq_len == 0) + builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( + batch_seq_mask, non_blocking=True) + batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask. 
+ shape[0]] + cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) + attn_metadata_i.decode.cp_seq_len = cp_seq_len + attn_metadata_i.decode.batch_seq_mask = batch_seq_mask + # update slot_mapping + slot_indices += self.pcp_size + slot_mapping = mtp_slot_mapping[slot_indices] + attn_metadata_i.slot_mapping[:batch_size * + self.pcp_size] = slot_mapping + else: + attn_metadata_i.slot_mapping[:batch_size] = slot_mapping if self.speculative_config.disable_padded_drafter_batch: self.positions[batch_size:num_input_tokens] = 0 self.input_ids[batch_size:num_input_tokens] = 0 diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 5b450f6c5bb..9fba9b5292f 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -75,7 +75,7 @@ def __init__(self, logical_table_size = max_num_blocks_per_req duplicate_size = 1 - if self.pcp_world_size > 1: + if self.pcp_world_size * self.dcp_world_size > 1: duplicate_size += num_speculative_tokens self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5e1cc0fc952..570d8a3ef20 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -280,7 +280,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int32, device=self.device) self.num_actual_tokens_pcp_padded = 0 - if self.speculative_config and self.pcp_size > 1: + if self.speculative_config and self.pcp_size * self.dcp_size > 1: self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, dtype=torch.int32) self.query_start_loc_pcp_full = self._make_buffer( @@ -289,8 +289,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int64, device="cpu", pin_memory=True) - self.decode_token_per_req += self.speculative_config.num_speculative_tokens self.positions_pcp_full_np = self.positions_pcp_full.numpy() + self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) self.decode_threshold = 1 + ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0) @@ -575,6 +576,7 @@ def _prepare_inputs( if self.pcp_size > 1: if not self.vllm_config.model_config.use_mla: self.generate_kv_idx(scheduler_output) + tokens_before_update = tokens.copy() tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( tokens) num_scheduled_tokens = np.array(tokens, dtype=np.int32) @@ -591,7 +593,8 @@ def _prepare_inputs( num_valid_tokens = np.array([ num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip(tokens, req_ids) + for num_tokens, i in zip((tokens_before_update if self. + pcp_size > 1 else tokens), req_ids) ], dtype=np.int32) @@ -909,7 +912,8 @@ def _prepare_inputs( >= self.input_batch.num_prompt_tokens[req_idx]) else -1) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs]) + num_draft_tokens, cu_num_tokens, + self.num_pcp_pads[:num_reqs].numpy()) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). 
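[Editor's note] For context on the decode-side flattening that the mla_cp.py builder and the `_prepare_inputs` hunks rely on: each decode request's block-table row is repeated once per scheduled decode token (at most `decode_threshold = 1 + num_speculative_tokens`), so block_table/seq_lens rows line up one-to-one with decode tokens and the spec attention mask stays rectangular. The sketch below is illustrative only; `flatten_decode_block_table` is a made-up helper, not code from this patch.

    # Illustrative sketch only (helper name is hypothetical): repeat each
    # decode request's block-table row once per scheduled decode token so
    # rows align 1:1 with decode tokens; prefill rows stay as-is.
    import torch

    def flatten_decode_block_table(block_table: torch.Tensor,
                                   query_lens: torch.Tensor,
                                   num_decode_reqs: int) -> torch.Tensor:
        # block_table: (num_reqs, max_blocks); query_lens: tokens per request
        decode_rows = block_table[:num_decode_reqs].repeat_interleave(
            query_lens[:num_decode_reqs], dim=0)
        prefill_rows = block_table[num_decode_reqs:]
        return torch.cat([decode_rows, prefill_rows], dim=0)

    # 2 decode reqs with 2 tokens each (MTP1) plus 1 prefill req:
    # rows [d0, d1, p0] become [d0, d0, d1, d1, p0]
    bt = torch.tensor([[0], [1], [2]])
    flat = flatten_decode_block_table(bt, torch.tensor([2, 2, 8]), 2)
    assert flat[:, 0].tolist() == [0, 0, 1, 1, 2]
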
@@ -931,10 +935,11 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.speculative_config and self.pcp_size > 1: + if self.speculative_config and self.pcp_size * self.dcp_size > 1: self._generate_pcp_mtp_input( num_reqs, scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens) + scheduler_output.num_scheduled_tokens, with_prefill, + req_indices, positions_np, cu_num_tokens) long_seq_metadata = self._generate_pcp_metadata( total_num_scheduled_tokens) @@ -1040,7 +1045,7 @@ def _prepare_inputs( prefill_context_parallel_metadata=long_seq_metadata, ) - if self.speculative_config and self.pcp_size > 1: + if self.speculative_config and self.pcp_size * self.dcp_size > 1: # For pcp + spec decode, we flatten block_table # to avoid irregular spec_attn_mask shape, e.g., # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, @@ -1048,12 +1053,13 @@ def _prepare_inputs( # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \ - self.query_start_loc_pcp_full.cpu[:num_reqs] + ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs] + ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold + num_decode_reqs_flatten = \ + ori_query_lens_cpu[:num_decode_reqs].sum().item() blk_table_tensor[ num_decode_reqs_flatten:num_decode_reqs_flatten + num_prefill_reqs].copy_( @@ -1061,9 +1067,15 @@ def _prepare_inputs( num_prefill_reqs].clone()) blk_table_tensor[:num_decode_reqs_flatten].copy_( blk_table_tensor[:num_decode_reqs].repeat_interleave( - self.decode_threshold, dim=0)) + ori_query_lens[:num_decode_reqs], dim=0)) common_attn_metadata.block_table_tensor = \ blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] + long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu + if 'pad_size' in locals() and pad_size > 0: + ori_query_lens_cpu[-pad_size:] = \ + torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) + long_seq_metadata.max_query_len_pcp_full = \ + ori_query_lens_cpu.max().item() if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: @@ -1861,7 +1873,7 @@ def _build_dummy_attn_metadata( decode_token_per_req=self.decode_token_per_req, prefill_context_parallel_metadata=long_seq_metadata, ) - if self.pcp_size > 1: + if self.pcp_size * self.dcp_size > 1: common_attn_metadata.block_table_tensor = \ block_table_tensor[:num_reqs * self.decode_threshold] attn_state = AscendAttentionState.DecodeOnly @@ -3029,9 +3041,7 @@ def _update_tokens_for_pcp(self, tokens): num_reqs = self.input_batch.num_reqs self.num_pcp_pads = self.num_pcp_pads[:num_reqs] tokens = np.array(tokens, dtype=np.int32) - num_decode_reqs = sum( - self.input_batch.num_computed_tokens_cpu[:num_reqs] >= - self.input_batch.num_prompt_tokens[:num_reqs]) + num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum() num_decode_tokens = sum(tokens[:num_decode_reqs]) num_padded_scheduled_tokens = np.ceil( tokens / @@ -3118,8 +3128,10 @@ def _get_cp_local_seq_lens( def _generate_pcp_metadata(self, total_num_scheduled_tokens): # In dummy run num_reqs == 0, update it from seq_lens num_reqs = self.input_batch.num_reqs or 
self.query_lens.size(0) - num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs] - >= self.input_batch.num_prompt_tokens[:num_reqs]) + query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \ + if self.pcp_size > 1 and self.speculative_config else self.query_lens + num_decodes = (query_lens <= self.decode_threshold).sum().item() + num_prefills = num_reqs - num_decodes num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded long_seq_metadata = None @@ -3137,16 +3149,41 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens): dtype=torch.int32, ) # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape + # to avoid irregular spec_attn_mask shape. + # Same as block_table, we flatten decode seq_lens to query_lens, + # and keep prefill seq_lens unchanged. for decode_idx in range(self.decode_threshold): num_computed_tokens_of_pcp_dcp[ self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ self._get_cp_local_seq_lens( - torch.tensor(context_lens), + torch.tensor(context_lens) - decode_idx, self.pcp_size, self.dcp_size, self.parallel_config.cp_kv_cache_interleave_size, ) + if self.decode_threshold > 1: + num_computed_tokens_of_pcp_dcp_list = [] + if num_decodes: + num_decodes_flatten = \ + self.query_lens[:num_decodes].sum().item() + if self.query_lens[:num_decodes].min().item( + ) == self.decode_threshold: + decode_flatten_idx = list(range(num_decodes_flatten)) + else: + decode_flatten_idx = [] + for req_id in range(num_decodes): + offset = (req_id + 1) * self.decode_threshold + decode_flatten_idx += \ + list(range(offset - self.query_lens[req_id], offset)) + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[decode_flatten_idx]) + if num_prefills: + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[ + (num_decodes + 1) * self.decode_threshold - + 1::self.decode_threshold]) + num_computed_tokens_of_pcp_dcp = torch.cat( + num_computed_tokens_of_pcp_dcp_list, dim=0) long_seq_metadata = AscendPrefillContextParallelMetadata( num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. @@ -3278,6 +3315,10 @@ def _generate_pcp_mtp_input( num_reqs: int, total_num_scheduled_tokens: int, num_scheduled_tokens: dict[str, int], + with_prefill: bool = True, + req_indices=None, + positions_np=None, + cu_num_tokens=None, ): """ While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, @@ -3288,6 +3329,8 @@ def _generate_pcp_mtp_input( num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) for i, req_id in enumerate(self.input_batch.req_ids): num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] + self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( + num_scheduled_tokens_pcp_full) req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_pcp_full) cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) @@ -3313,11 +3356,45 @@ def _generate_pcp_mtp_input( torch.from_numpy(token_indices_pcp_full), out=self.input_ids_pcp_full. 
cpu[:total_num_scheduled_tokens_pcp_full]) + self.query_lens_pcp_full.copy_to_gpu() self.query_start_loc_pcp_full.copy_to_gpu() self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_( self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full], non_blocking=True, ) + self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full + # For mtpx, pre-allocate mtp slot_mapping here + if self.decode_threshold > 2 and not with_prefill: + num_tokens_ori = sum(list(num_scheduled_tokens.values())) + num_tokens_mtp = \ + num_tokens_ori + num_reqs * (self.decode_threshold - 2) + num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size + req_indices_split = np.array_split(req_indices, + cu_num_tokens)[:num_reqs] + positions_split = np.array_split(positions_np, + cu_num_tokens)[:num_reqs] + for req_idx in range(num_reqs): + ori_req_indice = req_indices_split[req_idx] + ori_position = positions_split[req_idx] + req_indices_split[req_idx] = np.append( + ori_req_indice, + np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) + positions_split[req_idx] = np.append( + ori_position, + np.arange(ori_position[-1] + 1, + ori_position[-1] + self.decode_threshold - 1)) + req_indices_mtp = np.concatenate(req_indices_split) + positions_mtp = np.concatenate(positions_split) + self.input_batch.block_table.compute_slot_mapping( + req_indices_mtp, positions_mtp) + mtp_slot_ori = self.input_batch.block_table.block_tables[ + 0].slot_mapping.cpu[:num_tokens_mtp] + unpad_mask = np.repeat(False, num_tokens_mtp_pad) + unpad_mask[::self.pcp_size] = True + mtp_slot_pad = \ + torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32) + mtp_slot_pad[unpad_mask] = mtp_slot_ori + self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) @contextmanager
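
[Editor's note] The pre-allocated `mtp_slot_pad` built in `_generate_pcp_mtp_input` above, and consumed by the per-step `slot_indices += self.pcp_size` update in `_propose`, can be sketched as follows. This is an illustrative sketch under the assumptions visible in the patch; `build_padded_mtp_slots` is a hypothetical helper, not part of this change.

    # Illustrative sketch only: the padded slot table keeps one real slot per
    # pcp_size-wide chunk and fills the rest with -1, so a draft step can
    # advance its slot indices by pcp_size instead of doing "slot_mapping += 1"
    # (which would be wrong when tokens are split across cp ranks).
    import torch

    def build_padded_mtp_slots(mtp_slots: torch.Tensor,
                               pcp_size: int) -> torch.Tensor:
        # mtp_slots: pre-computed slot ids for this rank's scheduled + mtp tokens
        padded = torch.full((mtp_slots.numel() * pcp_size,),
                            -1, dtype=torch.int32)
        padded[::pcp_size] = mtp_slots
        return padded

    padded = build_padded_mtp_slots(
        torch.tensor([10, 11, 12, 13], dtype=torch.int32), pcp_size=2)
    # padded == [10, -1, 11, -1, 12, -1, 13, -1]; each draft step then gathers
    # padded[slot_indices] after slot_indices += pcp_size.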