diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 1994b722aa..4dbdefb631 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -1,11 +1,15 @@ from __future__ import annotations +import os + import pytest from vllm import SamplingParams from vllm.config import CompilationConfig, CUDAGraphMode from tests.e2e.conftest import VllmRunner +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + @pytest.fixture def sampling_config(): @@ -17,12 +21,12 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" -def mtp_correctness( - sampling_config: SamplingParams, - model_name: str, - num_speculative_tokens: int, - graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, -): +def mtp_correctness(sampling_config: SamplingParams, + model_name: str, + num_speculative_tokens: int, + graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, + enforce_eager=False, + disable_padded_drafter_batch=True): example_prompts = [ "Hello, my name is", "The president of the United States is", @@ -37,7 +41,7 @@ def mtp_correctness( tensor_parallel_size=1, gpu_memory_utilization=0.7, max_model_len=256, - enforce_eager=False) as ref_llm: + enforce_eager=enforce_eager) as ref_llm: ref_outputs = ref_llm.generate(example_prompts, sampling_config) graph_mode_str = "PIECEWISE" @@ -54,8 +58,9 @@ def mtp_correctness( speculative_config={ "method": "deepseek_mtp", "num_speculative_tokens": num_speculative_tokens, + "disable_padded_drafter_batch": disable_padded_drafter_batch, }, - enforce_eager=False, + enforce_eager=enforce_eager, max_model_len=2000, compilation_config=CompilationConfig( cudagraph_mode=graph_mode_str), @@ -82,6 +87,20 @@ def mtp_correctness( del spec_llm +def test_mtp1_correctness_eager( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1, enforce_eager=True) + + +def test_mtp2_correctness_eager( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 2, enforce_eager=True) + + @pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed") def test_mtp1_correctness_piecewise_graph( sampling_config: SamplingParams, @@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph( model_name: str, ): mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) + + +def test_mtp1_correctness_eager_with_pad( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, + model_name, + 1, + enforce_eager=True, + disable_padded_drafter_batch=False) + + +def test_mtp2_correctness_eager_with_pad( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, + model_name, + 2, + enforce_eager=True, + disable_padded_drafter_batch=False) + + +@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed") +def test_mtp1_correctness_piecewise_graph_with_pad( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, + model_name, + 1, + disable_padded_drafter_batch=False) + + +@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed") +def test_mtp2_correctness_piecewise_graph_with_pad( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, + model_name, + 2, + disable_padded_drafter_batch=False) diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 64076c2bda..3e17944c5c 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -19,14 +19,21 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.ngram_proposer import NgramProposer +from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer -def get_spec_decode_method(method, vllm_config, device, runner): +def get_spec_decode_method(method, + vllm_config, + device, + runner, + is_torchair_graph=False): if method == "ngram": return NgramProposer(vllm_config, device, runner) elif method in ["eagle", "eagle3"]: return EagleProposer(vllm_config, device, runner) elif method == 'deepseek_mtp': + if is_torchair_graph: + return TorchairMtpProposer(vllm_config, device, runner) return MtpProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 443bf6745b..17274a51fb 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,37 +1,41 @@ -import types +from typing import Optional +import numpy as np import torch import torch.nn as nn -import torchair -from torchair import patch_for_hcom from vllm.config import (CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import \ process_weights_after_loading from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType -from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ - TorchairDeepSeekMTP -from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, - TorchairCommonAttentionMetadata) from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, vllm_version_is) if vllm_version_is("0.11.0"): from vllm.model_executor.model_loader.utils import set_default_torch_dtype + from vllm.utils import is_pin_memory_available else: + from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.torch_utils import set_default_torch_dtype +logger = init_logger(__name__) + PADDING_SLOT_ID = -1 @@ -45,34 +49,77 @@ def __init__( ): self.name = SpecDcodeType.MTP self.vllm_config = vllm_config - self.device = device - self.runner = runner - self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.speculative_config = vllm_config.speculative_config + assert self.speculative_config is not None + self.draft_model_config = self.speculative_config.draft_model_config + self.method = self.speculative_config.method - # persistent buffers for graph - self.input_ids = torch.zeros(self.runner.max_num_tokens, + self.runner = runner + self.device = device + self.dtype = vllm_config.model_config.dtype + self.max_model_len = vllm_config.model_config.max_model_len + self.block_size = vllm_config.cache_config.block_size + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.token_arange_np = np.arange(self.max_num_tokens) + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() + + self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None + self.draft_indexer_metadata_builder: Optional[ + AttentionMetadataBuilder] = None + self.attn_layer_names: list[str] = [] + self.indexer_layer_names: list[str] = [] + + self.use_aclgraph = self.runner._use_aclgraph() + + self.cudagraph_batch_sizes = (list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_aclgraph else []) + + # persistent buffers for aclgraph graph + self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.runner.max_num_tokens, - dtype=torch.int64, - device=self.device) + device=device) + self.uses_mrope = self.vllm_config.model_config.uses_mrope + if self.uses_mrope: + # M-RoPE need (3, max_num_tokens) + self.mrope_positions = torch.zeros((3, self.max_num_tokens), + dtype=torch.int64, + device=device) + else: + # RoPE need (max_num_tokens,) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) self.hidden_states = torch.zeros( - (self.runner.max_num_tokens, - vllm_config.model_config.get_hidden_size()), - dtype=self.runner.dtype, - device=self.device) - self.torchair_compiled_model = None # type: ignore - self.torchair_compiled_models = {} # type: ignore - self.torchair_graph_enabled = get_ascend_config( - ).torchair_graph_config.enabled - self.enable_shared_expert_dp = get_ascend_config( - ).enable_shared_expert_dp + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=self.runner.device, + max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, dtype=torch.int32) + + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) self.use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") @@ -89,14 +136,8 @@ def load_model(self, model) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - if self.torchair_graph_enabled or ( - self.enable_shared_expert_dp - and self.vllm_config.model_config.use_mla): - self.model = TorchairDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - else: - self.model = DeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) + self.model = DeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) draft_attn_layer_names = (get_layers_from_vllm_config( self.vllm_config, AttentionLayerBase).keys() - @@ -121,34 +162,17 @@ def dummy_run(self, num_tokens_across_dp=None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None) -> None: - if not self.torchair_graph_enabled: - ( - num_tokens, - num_tokens_across_dp, - with_prefill, - ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) + + ( + num_tokens, + num_tokens_across_dp, + with_prefill, + ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) moe_comm_type = self.runner._select_moe_comm_method( num_tokens, with_prefill) - is_running_torchair = self.torchair_graph_enabled and \ - not with_prefill - - if is_running_torchair: - skip_attn = False - if skip_attn: - attn_metadata = None - else: - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - decode_token_per_req=self.runner.decode_token_per_req, - ) - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - common_attn_metadata) + attn_metadata = None input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] @@ -166,40 +190,14 @@ def dummy_run(self, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor): - if is_running_torchair: - assert attn_metadata is not None - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(previous_hidden_states) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static( - attn_metadata.decode.input_positions) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_tokens) - torchair_compiled_model( - input_ids=input_ids, - positions=positions, - hidden_states=previous_hidden_states, - inputs_embeds=None, - intermediate_tensors=None, - attn_metadata=attn_metadata, - kv_caches=self.runner.kv_caches[-1:], - spec_step_idx=0) - else: - self.model(input_ids=input_ids, - positions=positions, - hidden_states=previous_hidden_states) + self.model(input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states) if with_prefill: break def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], + sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, @@ -208,235 +206,240 @@ def generate_token_ids(self, hidden_states: torch.Tensor = None, attn_metadata=None, aux_hidden_states: torch.Tensor = None): + common_attn_metadata = self.runner.spec_decode_common_attn_metadata if attn_metadata is not None and isinstance(attn_metadata, dict): attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.runner.input_batch.req_ids[i] - req_state = self.runner.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - accepted_token_indices = None + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." + next_token_ids = self.prepare_next_token_ids_cpu( + sampled_token_ids, self.runner.requests, + self.runner.input_batch, scheduler_output.num_scheduled_tokens) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ + self.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.runner.requests, + self.runner.input_batch, + self.runner.discard_request_indices.gpu, + self.runner.num_discarded_requests + ) + if spec_decode_metadata is None: + token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.runner.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - cu_num_tokens, accepted_token_indices, target_token_ids, \ - target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs( - attn_metadata.query_start_loc, - num_rejected_tokens, - self.runner.input_ids[:num_scheduled_tokens], - positions[:num_scheduled_tokens], - hidden_states[:num_scheduled_tokens], - attn_metadata.slot_mapping[:num_scheduled_tokens], - is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(), - ) + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self._prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + target_token_ids = self.runner.input_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] draft_token_ids = self._propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_tables, + last_token_indices=token_indices_to_sample, + common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, - token_indices=accepted_token_indices) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + ) + + return draft_token_ids def _prepare_inputs( self, - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - token_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - slot_mapping: torch.Tensor, - is_torchair_graph: bool = False - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - if is_torchair_graph: - cu_num_tokens = cu_target_query_lens - relative_index = query_len_per_req - num_rejected_tokens - 1 - token_indices = cu_num_tokens[:-1] + relative_index - # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model - target_token_ids = token_ids - target_positions = positions - target_hidden_states = hidden_states - target_slot_mapping = slot_mapping - else: - cu_num_tokens = torch.empty_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. - num_tokens = cu_num_tokens[-1].item() - token_indices = torch.zeros( - num_tokens, - dtype=torch.int32, - device=cu_num_tokens.device, - ) - - BLOCK_SIZE = 1024 - self._prepare_input_kernel( - token_indices, - cu_target_query_lens, - cu_num_tokens, - block_size=BLOCK_SIZE, - ) - target_token_ids = token_ids[token_indices] - target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = slot_mapping[token_indices] - return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + ) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = (self.token_arange_np[:total_num_tokens] - + new_query_start_locs_expanded) + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + spec_common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + positions=common_attn_metadata.positions[token_indices], + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + graph_pad_size=self.runner.graph_pad_size, + decode_token_per_req=self.runner.decode_token_per_req, + ) + return spec_common_attn_metadata, token_indices def _propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, - sampling_metadata: SamplingMetadata, - token_indices=None) -> torch.Tensor: + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], + common_attn_metadata: CommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, + ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - if token_indices is not None and self.torchair_graph_enabled: - last_token_indices = token_indices - self.input_ids[last_token_indices] = next_token_ids - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - # FIXME: reorder_batch() needs to be called before build() - # because fields of attn_metadata_builder needs to be updated. - # However, currently reorder_batch() takes input_batch and - # scheduler_output as arguments, we should probably refactor - # the method to use new data structures which are independent - # from input_batch and scheduler_output. - # self.runner.attn_metadata_builder.reorder_batch( - # input_batch=self.runner.input_batch, - # scheduler_output=self.runner.scheduler_output, - # ) - is_running_torchair = self.torchair_graph_enabled and \ - not self.runner.with_prefill - - if is_running_torchair: - # Torchair graph mode, padding is same as the main model - num_input_tokens = self.runner.graph_pad_size - elif (self.runner.use_aclgraph - and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): + assert self.runner is not None + + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp + + if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: # Acl graph mode, add padding to the batch size num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: # Eager mode, no padding needed num_input_tokens = num_tokens - seq_lens = target_positions[last_token_indices] + 1 - seq_lens = seq_lens.int() - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=cu_num_tokens[:batch_size + 1], - query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), - seq_lens_cpu=seq_lens.cpu(), - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping=target_slot_mapping, - positions=target_positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - graph_pad_size=self.runner.graph_pad_size, - decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=None, - seq_lens=None) - - if not self.torchair_graph_enabled: - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_mtp = builder.build(0, common_attn_metadata, - self.runner.get_model()) - - attn_metadata = {} - for layer_name in self.attn_layer_name: - attn_metadata[layer_name] = attn_metadata_mtp - - else: - attn_metadata = self.runner.attn_metadata_builder.build( - 0, common_attn_metadata, self.runner.get_model()) - + # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - - if not self.torchair_graph_enabled: - # torch mode need to update num_tokens_across_dp - (num_input_tokens, num_tokens_across_dp, - with_prefill) = self.runner._sync_metadata_across_dp( - num_input_tokens, self.runner.with_prefill) - else: - # torchair mode can reuse self.runner.num_tokens_across_dp - num_tokens_across_dp = self.runner.num_tokens_across_dp - with_prefill = self.runner.with_prefill + # eager/acl piecewise mode need to update num_tokens_across_dp + (num_input_tokens, num_tokens_across_dp, + with_prefill) = self.runner._sync_metadata_across_dp( + num_input_tokens, self.runner.with_prefill) moe_comm_type = self.runner._select_moe_comm_method( num_input_tokens, with_prefill) @@ -444,6 +447,15 @@ def _propose( uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) + if aclgraph_runtime_mode not in [ + CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE + ]: + # Fallback to piecewise graph, when acl full graph is enabled + logger.debug( + "Currently the eagle proposer only supports cudagraph_mode " + f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} " + "to CUDAGraphMode.PIECEWISE") + aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE for step in range(self.num_speculative_tokens): with set_ascend_forward_context( @@ -461,26 +473,11 @@ def _propose( with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] - if is_running_torchair: - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_input_tokens) - hidden_states = torchair_compiled_model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self. - hidden_states[:num_input_tokens], - inputs_embeds=None, - intermediate_tensors=None, - spec_step_idx=0, - **model_kwargs) - else: - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens] - ) + + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens]) num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): @@ -515,10 +512,7 @@ def _propose( if step == self.num_speculative_tokens - 1 or with_prefill: break - if not self.torchair_graph_enabled: - attn_metadata_i = attn_metadata[self.attn_layer_name[0]] - else: - attn_metadata_i = attn_metadata + attn_metadata_i = attn_metadata[self.attn_layer_name[0]] if step == 0: positions = target_positions[last_token_indices] @@ -529,21 +523,16 @@ def _propose( last_token_indices = self.arange[:batch_size] if attn_metadata_i.num_decode_tokens != 0: attn_metadata_i.num_decode_tokens = batch_size - if is_running_torchair: - attn_metadata_i.num_actual_tokens = batch_size - attn_metadata_i.query_lens = [1] * batch_size input_ids = draft_token_ids_list[-1].int() positions += 1 - if not self.torchair_graph_enabled: - attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ - 1:batch_size + 1].tolist() - attn_metadata_i.decode.cos = builder.cos_cache[ - positions].unsqueeze(1).unsqueeze(2) - attn_metadata_i.decode.sin = builder.sin_cache[ - positions].unsqueeze(1).unsqueeze(2) - + attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + 1:batch_size + 1].tolist() + attn_metadata_i.decode.cos = builder.cos_cache[ + positions].unsqueeze(1).unsqueeze(2) + attn_metadata_i.decode.sin = builder.sin_cache[ + positions].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch @@ -601,61 +590,6 @@ def _propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ - -1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - patch_for_hcom() - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.runner.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=not self.use_sparse, - fullgraph=True, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=not self.use_sparse, - fullgraph=True, - cache_dir=TORCHAIR_CACHE_DIR, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - # TODO Using torch instead of triton may result in poor performance def _prepare_input_kernel(self, out_ptr: torch.Tensor, cu_query_lens: torch.Tensor, @@ -676,3 +610,160 @@ def _prepare_input_kernel(self, out_ptr: torch.Tensor, global_indices_flat = global_indices[mask] values_flat = values[mask] out_ptr[global_indices_flat] = values_flat + + def prepare_next_token_ids_cpu( + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[ + req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = discard_request_indices[: + num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. + """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1], + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + positions=common_attn_metadata.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + graph_pad_size=self.runner.graph_pad_size, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + seq_lens=common_attn_metadata.seq_lens) + + token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] - + 1 - num_rejected_tokens_gpu) + + return spec_common_attn_metadata, token_indices, token_indices_to_sample diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 56befcc112..419d5f45a5 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -34,6 +34,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.platform import NPUPlatform +from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.torchair.utils import ( TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata, check_torchair_cache_exist, converting_weight_acl_format, @@ -83,6 +84,20 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self._check_batch_sizes_consistency() + def _set_up_drafter(self): + super()._set_up_drafter() + if self.speculative_config: + # Torchair do not support disable_padded_drafter_batch + # Enforce to disable this feature + self.speculative_config.disable_padded_drafter_batch = True + + def _get_drafter(self): + return get_spec_decode_method(self.speculative_config.method, + self.vllm_config, + self.device, + self, + is_torchair_graph=True) + def _may_pad_kv_consumer_num_seq(self): # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens # self.max_num_reqs here is greater than the actual maximum request number diff --git a/vllm_ascend/torchair/torchair_mtp_proposer.py b/vllm_ascend/torchair/torchair_mtp_proposer.py new file mode 100644 index 0000000000..c26b8dd401 --- /dev/null +++ b/vllm_ascend/torchair/torchair_mtp_proposer.py @@ -0,0 +1,554 @@ +import types + +import torch +import torch.nn as nn +import torchair +from torchair import patch_for_hcom +from vllm.config import (CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.utils import \ + process_weights_after_loading +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.spec_decode import MtpProposer +from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ + TorchairDeepSeekMTP +from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, + TorchairCommonAttentionMetadata) +from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, + vllm_version_is) + +if vllm_version_is("0.11.0"): + from vllm.model_executor.model_loader.utils import set_default_torch_dtype +else: + from vllm.utils.torch_utils import set_default_torch_dtype + +PADDING_SLOT_ID = -1 + + +class TorchairMtpProposer(MtpProposer): + + def __init__( + self, + vllm_config: VllmConfig, + device, + runner, + ): + super().__init__(vllm_config, device, runner) + self.torchair_compiled_model = None # type: ignore + self.torchair_compiled_models = {} # type: ignore + + def load_model(self, model) -> None: + loader = get_model_loader(self.vllm_config.load_config) + + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase).keys()) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_device = self.vllm_config.device_config.device + + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + + self.model = TorchairDeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + + draft_attn_layer_names = (get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase).keys() - + target_attn_layer_names) + + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = list(draft_attn_layer_names) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + process_weights_after_loading(self.model, draft_model_config, + target_device) + + @torch.inference_mode() + def dummy_run(self, + num_tokens: int, + with_prefill: bool = False, + skip_attn: bool = False, + num_reqs: int = 0, + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None) -> None: + moe_comm_type = self.runner._select_moe_comm_method( + num_tokens, with_prefill) + + if not with_prefill: + skip_attn = False + if skip_attn: + attn_metadata = None + else: + common_attn_metadata = TorchairCommonAttentionMetadata( + num_reqs=num_reqs, + num_actual_tokens=1, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + decode_token_per_req=self.runner.decode_token_per_req, + ) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + common_attn_metadata) + + input_ids = self.input_ids[:num_tokens] + positions = self.positions[:num_tokens] + previous_hidden_states = self.hidden_states[:num_tokens] + for _ in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=0): + if not with_prefill: + assert attn_metadata is not None + torch._dynamo.mark_static(input_ids) + torch._dynamo.mark_static(positions) + torch._dynamo.mark_static(previous_hidden_states) + torch._dynamo.mark_static(attn_metadata.decode.block_table) + torch._dynamo.mark_static( + attn_metadata.decode.input_positions) + if hasattr(attn_metadata.decode, "sin"): + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static(get_forward_context().mc2_mask) + torch._dynamo.mark_static(attn_metadata.slot_mapping) + torch._dynamo.mark_static(attn_metadata.decode.attn_mask) + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_tokens) + torchair_compiled_model( + input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states, + inputs_embeds=None, + intermediate_tensors=None, + attn_metadata=attn_metadata, + kv_caches=self.runner.kv_caches[-1:], + spec_step_idx=0) + else: + self.model(input_ids=input_ids, + positions=positions, + hidden_states=previous_hidden_states) + if with_prefill: + break + + def generate_token_ids(self, + valid_sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata = None, + scheduler_output: SchedulerOutput = None, + spec_decode_metadata: SpecDecodeMetadata = None, + positions: torch.Tensor = None, + num_scheduled_tokens: int = 0, + hidden_states: torch.Tensor = None, + attn_metadata=None, + aux_hidden_states: torch.Tensor = None): + if attn_metadata is not None and isinstance(attn_metadata, dict): + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.runner.input_batch.req_ids[i] + req_state = self.runner.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + accepted_token_indices = None + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, accepted_token_indices, target_token_ids, \ + target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + self.runner.input_ids[:num_scheduled_tokens], + positions[:num_scheduled_tokens], + hidden_states[:num_scheduled_tokens], + attn_metadata.slot_mapping[:num_scheduled_tokens], + ) + + draft_token_ids = self._propose_torchair( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + token_indices=accepted_token_indices) + spec_token_ids = draft_token_ids.tolist() + return spec_token_ids + + def _torchair_prepare_inputs( + self, + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + token_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + + cu_num_tokens = cu_target_query_lens + relative_index = query_len_per_req - num_rejected_tokens - 1 + token_indices = cu_num_tokens[:-1] + relative_index + # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model + target_token_ids = token_ids + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = slot_mapping + + return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping + + def _propose_torchair( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + token_indices=None) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + if token_indices is not None: + last_token_indices = token_indices + + self.input_ids[last_token_indices] = next_token_ids + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + # FIXME: reorder_batch() needs to be called before build() + # because fields of attn_metadata_builder needs to be updated. + # However, currently reorder_batch() takes input_batch and + # scheduler_output as arguments, we should probably refactor + # the method to use new data structures which are independent + # from input_batch and scheduler_output. + # self.runner.attn_metadata_builder.reorder_batch( + # input_batch=self.runner.input_batch, + # scheduler_output=self.runner.scheduler_output, + # ) + + if not self.runner.with_prefill: + # Torchair graph mode, padding is same as the main model + num_input_tokens = self.runner.graph_pad_size + elif (self.runner.use_aclgraph + and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): + # Acl graph mode, add padding to the batch size + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + # Eager mode, no padding needed + num_input_tokens = num_tokens + + seq_lens = target_positions[last_token_indices] + 1 + seq_lens = seq_lens.int() + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=cu_num_tokens[:batch_size + 1], + query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), + seq_lens_cpu=seq_lens.cpu(), + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor(), + slot_mapping=target_slot_mapping, + positions=target_positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + graph_pad_size=self.runner.graph_pad_size, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=None, + seq_lens=None) + + attn_metadata = self.runner.attn_metadata_builder.build( + 0, common_attn_metadata, self.runner.get_model()) + + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states + + # torchair mode can reuse self.runner.num_tokens_across_dp + num_tokens_across_dp = self.runner.num_tokens_across_dp + with_prefill = self.runner.with_prefill + + moe_comm_type = self.runner._select_moe_comm_method( + num_input_tokens, with_prefill) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) + + for step in range(self.num_speculative_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + with_prefill=with_prefill, + num_tokens_across_dp=num_tokens_across_dp, + reserved_mc2_mask=self.runner.reserved_mc2_mask, + moe_comm_type=moe_comm_type, + aclgraph_runtime_mode=aclgraph_runtime_mode, + in_profile_run=self.runner.in_profile_run, + num_actual_tokens=num_tokens): + with ProfileExecuteDuration().capture_async('mtp_forward'): + model_kwargs = {} + model_kwargs["attn_metadata"] = attn_metadata + + model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] + if not self.runner.with_prefill: + torchair_compiled_model = self._get_torchair_lazy_compiled_model( + num_input_tokens) + hidden_states = torchair_compiled_model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self. + hidden_states[:num_input_tokens], + inputs_embeds=None, + intermediate_tensors=None, + spec_step_idx=0, + **model_kwargs) + else: + hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens] + ) + + num_indices = last_token_indices.shape[0] + if lmhead_tp_enable(): + if not self.runner.with_prefill: + max_num_reqs_across_dp = num_input_tokens + else: + max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs + last_token_indices = nn.functional.pad( + last_token_indices, + (0, max_num_reqs_across_dp - num_indices)) + + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states) + if lmhead_tp_enable() and num_indices < logits.shape[0]: + logits = logits[:num_indices] + draft_token_ids = logits.argmax(dim=-1) + + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + if step == 0: + draft_token_ids_list = [draft_token_ids] + else: + draft_token_ids_list.append(draft_token_ids) + + # prepare next mtp inputs + # mtp>1: prefill skip or decode skip last loop + if with_prefill: + for _ in range(self.num_speculative_tokens - 1): + draft_token_ids_list.append(draft_token_ids) + if step == self.num_speculative_tokens - 1 or with_prefill: + break + + attn_metadata_i = attn_metadata + + if step == 0: + positions = target_positions[last_token_indices] + hidden_states = hidden_states[last_token_indices] + slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] + attn_metadata_i.slot_mapping.fill_(-1) + attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] + last_token_indices = self.arange[:batch_size] + if attn_metadata_i.num_decode_tokens != 0: + attn_metadata_i.num_decode_tokens = batch_size + if not self.runner.with_prefill: + attn_metadata_i.num_actual_tokens = batch_size + attn_metadata_i.query_lens = [1] * batch_size + + input_ids = draft_token_ids_list[-1].int() + positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions >= self.runner.model_config.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + # Increment the sequence lengths. + attn_metadata_i.seq_lens[:batch_size] += 1 + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + exceeds_max_model_len_cpu = exceeds_max_model_len.to( + attn_metadata_i.seq_lens.device, non_blocking=True) + attn_metadata_i.seq_lens[:batch_size].masked_fill_( + exceeds_max_model_len_cpu, 1) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping += 1 + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:hidden_states.shape[0]] = hidden_states + attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + + if attn_metadata_i.prefill is not None: + attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( + ) + attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.prefill.max_seq_lens += 1 + attn_metadata_i.prefill.max_seq_lens = min( + attn_metadata_i.prefill.max_seq_lens, + self.runner.model_config.max_model_len) + if attn_metadata_i.decode is not None: + attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( + ) + attn_metadata_i.decode.input_positions = self.positions[: + num_input_tokens] + attn_metadata_i.decode.max_seq_lens += 1 + attn_metadata_i.decode.max_seq_lens = min( + attn_metadata_i.decode.max_seq_lens, + self.runner.model_config.max_model_len) + + # mtp>1: [batch_size, k] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def _get_torchair_lazy_compiled_model(self, batch_size: int): + if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ + -1]: + raise ValueError( + f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" + ) + + compiled_model = self.torchair_compiled_models.get( + batch_size + ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model + + if compiled_model: + return compiled_model + + patch_for_hcom() + config = torchair.CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + config.experimental_config.enable_view_optimize = \ + get_ascend_config().torchair_graph_config.enable_view_optimize + torch.npu.set_compile_mode(jit_compile=False) + if not self.runner.use_cached_npu_graph: + npu_backend = torchair.get_npu_backend(compiler_config=config) + self.torchair_compiled_model = torch.compile( + self.model, + dynamic=not self.use_sparse, + fullgraph=True, + backend=npu_backend) + return self.torchair_compiled_model + else: + # Generate a new forward proxy code object to prevent the invalidation of + # compilation cache caused by dynamo retracing + forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" + forward_fn = self.model.forward + code = forward_fn.__code__ + # Mark code object with a new proxy name + modified_code = code.replace(co_name=forward_proxy_name, ) + + modified_func = types.FunctionType(modified_code, + forward_fn.__globals__, + name=forward_proxy_name, + argdefs=forward_fn.__defaults__) + + self.model.__dict__[forward_proxy_name] = modified_func.__get__( + self.model, nn.Module) + self.torchair_compiled_models[ + batch_size] = torchair.inference.cache_compile( + self.model.__dict__[forward_proxy_name], + dynamic=not self.use_sparse, + fullgraph=True, + cache_dir=TORCHAIR_CACHE_DIR, + config=config, + ge_cache=False) + return self.torchair_compiled_models[batch_size] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5e13da5307..3033c46570 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -133,6 +133,7 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer +from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, enable_sp, get_ascend_soc_version, is_310p, @@ -369,32 +370,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.attn_mask_builder = AttentionMaskBuilder( self.model_config.max_model_len, self.dtype) - # Set up speculative decoding. - self.spec_attn_mask = None - self.drafter: Optional[Union[NgramProposer, EagleProposer, - MtpProposer]] = None - self.actual_seq_lengths_q: list[int] = [] - self.decode_token_per_req = 1 - if self.speculative_config: - spec_token_num = self.speculative_config.num_speculative_tokens - assert spec_token_num > 0 - self.decode_token_per_req = 1 + spec_token_num - self.spec_attn_mask = torch.triu(torch.ones(2048, - 2048, - dtype=torch.bool), - diagonal=1).to(self.device) - if get_pp_group().is_last_rank: - self.drafter = get_spec_decode_method( - self.speculative_config.method, self.vllm_config, - self.device, self) - if vllm_version_is("0.11.0"): - self.rejection_sampler = AscendRejectionSampler() - else: - self.rejection_sampler = AscendRejectionSampler( - self.sampler) - self.actual_seq_lengths_q = list( - range(self.decode_token_per_req, self.max_num_tokens + 1, - self.decode_token_per_req)) + self._set_up_drafter() # kv role self.is_kv_producer = False @@ -590,6 +566,39 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # TODO: EVS Support (Video tokens pruning) (see vllm#22980) self.is_multimodal_pruning_enabled = False + def _set_up_drafter(self): + # Set up speculative decoding. + self.spec_attn_mask = None + self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, + TorchairMtpProposer]] = None + self.actual_seq_lengths_q: list[int] = [] + self.decode_token_per_req = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + assert spec_token_num > 0 + self.decode_token_per_req = 1 + spec_token_num + self.spec_attn_mask = torch.triu(torch.ones(2048, + 2048, + dtype=torch.bool), + diagonal=1).to(self.device) + if get_pp_group().is_last_rank: + self.drafter = self._get_drafter() + if vllm_version_is("0.11.0"): + self.rejection_sampler = AscendRejectionSampler() + else: + self.rejection_sampler = AscendRejectionSampler( + self.sampler) + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_discarded_requests = 0 + + def _get_drafter(self): + return get_spec_decode_method(self.speculative_config.method, + self.vllm_config, self.device, self) + def _may_pad_kv_consumer_num_seq(self): # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario, # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid @@ -609,7 +618,7 @@ def _init_mc2_tokens_capacity(self): tp_size = self.parallel_config.tensor_parallel_size # Use integer arithmetic for ceiling division. num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size - self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size + self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size def _make_buffer(self, *size: Union[int, torch.SymInt], @@ -1522,6 +1531,20 @@ def _prepare_inputs( self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) attn_metadata: dict[str, Any] = {} + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + num_reqs = self.input_batch.num_reqs + discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( + discard_request_indices) + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + # _prepare_inputs may reorder the batch, so we must gather # multi-modal outputs after that to ensure the correct order if self.is_multimodal_model: @@ -1615,7 +1638,7 @@ def _prepare_inputs( # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) - spec_decode_common_attn_metadata = None + self.spec_decode_common_attn_metadata = None if use_spec_decode and self.need_accepted_tokens: self.num_accepted_tokens.np[:num_reqs] = ( self.input_batch.num_accepted_tokens_cpu[:num_reqs]) @@ -1676,7 +1699,7 @@ def _prepare_inputs( common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens_cpu, + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], seq_lens=self.seq_lens_cpu[:num_reqs], num_reqs=num_reqs, num_actual_tokens=slot_mapping_size, @@ -1700,8 +1723,8 @@ def _prepare_inputs( ) if self.speculative_config and \ - spec_decode_common_attn_metadata is None: - spec_decode_common_attn_metadata = common_attn_metadata + self.spec_decode_common_attn_metadata is None: + self.spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: common_prefix_len = 0 @@ -1998,7 +2021,7 @@ def apply_grammar_bitmask( def propose_draft_token_ids( self, - valid_sampled_token_ids: list[list[int]], + valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]], sampling_metadata: SamplingMetadata, scheduler_output: "SchedulerOutput", spec_decode_metadata: SpecDecodeMetadata, @@ -2255,6 +2278,7 @@ def execute_model( logits = self.apply_grammar_bitmask( scheduler_output, logits) + with ProfileExecuteDuration().capture_async("Sample"): # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: @@ -2296,21 +2320,12 @@ def execute_model( if self.need_accepted_tokens: self._update_states_after_model_execute(output_token_ids) - discard_sampled_tokens_req_indices: list[int] = [] - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - discard_sampled_tokens_req_indices.append(i) + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] + for i in discard_sampled_tokens_req_indices: + generator = self.input_batch.generators.get(int(i)) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. @@ -2346,10 +2361,11 @@ def execute_model( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] - invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices = discard_sampled_tokens_req_indices.tolist( + ) invalid_req_indices_set = set(invalid_req_indices) assert sampled_token_ids.shape[-1] == 1 @@ -2394,18 +2410,33 @@ def execute_model( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + def propose_draft_token_ids(sampled_token_ids): + assert self.spec_decode_common_attn_metadata is not None + self._draft_token_ids = self.propose_draft_token_ids( + sampled_token_ids, + sampling_metadata, + scheduler_output, + spec_decode_metadata, + positions, + scheduler_output.total_num_scheduled_tokens, + hidden_states, + attn_metadata, + aux_hidden_states, + ) + + with ProfileExecuteDuration().capture_async("Draft"): if self.speculative_config: - self._draft_token_ids = self.propose_draft_token_ids( - valid_sampled_token_ids, - sampling_metadata, - scheduler_output, - spec_decode_metadata, - positions, - scheduler_output.total_num_scheduled_tokens, - hidden_states, - attn_metadata, - aux_hidden_states, - ) + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.method == "deepseek_mtp" and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch_for_eagle: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampler_output.sampled_token_ids) + if self.speculative_config and not use_padded_batch_for_eagle: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. + propose_draft_token_ids(valid_sampled_token_ids) if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 51972d0dc6..48c712b15c 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -92,8 +92,10 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: + elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] + else: + return -1 class InputBatch: