diff --git a/.github/workflows/vllm_ascend_test_long_term.yaml b/.github/workflows/vllm_ascend_test_long_term.yaml index e249849e19..b4138964cb 100644 --- a/.github/workflows/vllm_ascend_test_long_term.yaml +++ b/.github/workflows/vllm_ascend_test_long_term.yaml @@ -100,7 +100,7 @@ jobs: # spec decode test VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed - # VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py + VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process pytest -sv tests/e2e/long_term/spec_decode --ignore=tests/e2e/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/e2e/long_term/spec_decode/e2e/test_v1_mtp_correctness.py pytest -sv tests/e2e/long_term/test_accuracy.py diff --git a/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py b/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py index 19ab0bc220..35cb19a14e 100644 --- a/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py +++ b/tests/e2e/long_term/spec_decode/e2e/test_v1_spec_decode.py @@ -11,7 +11,7 @@ @pytest.fixture def test_prompts(): prompt_types = ["repeat", "sentence"] - num_prompts = 100 + num_prompts = 10 prompts = [] random.seed(0) @@ -69,6 +69,7 @@ def test_ngram_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using ngram speculative decoding. ''' + pytest.skip("Not current support for the test.") with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -116,11 +117,12 @@ def test_eagle_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. ''' - pytest.skip("Not current support for the test.") + if not use_eagle3: + pytest.skip("Not current support for the test.") with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - ref_llm = LLM(model=model_name, max_model_len=2048) + ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -129,13 +131,17 @@ def test_eagle_correctness( spec_llm = LLM( model=model_name, trust_remote_code=True, + enable_chunked_prefill=True, + max_num_seqs=1, + max_num_batched_tokens=2048, + gpu_memory_utilization=0.6, speculative_config={ "method": "eagle3" if use_eagle3 else "eagle", "model": spec_model_name, - "num_speculative_tokens": 3, - "max_model_len": 2048, + "num_speculative_tokens": 2, + "max_model_len": 128, }, - max_model_len=2048, + max_model_len=128, enforce_eager=True, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) diff --git a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py b/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py index 27986cb149..3a9068ff6b 100644 --- a/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py +++ b/tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py @@ -38,7 +38,7 @@ def run_test(model_name, queue, more_args=None): - model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4" + model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4,enforce_eager=True" if more_args is not None: model_args = f"{model_args},{more_args}" results = lm_eval.simple_evaluate( diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py new file mode 100644 index 0000000000..3a82018402 --- /dev/null +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -0,0 +1,429 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import torch +import torch.nn as nn +from vllm.attention.layer import Attention +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.v1.sample.metadata import SamplingMetadata + +from vllm_ascend.attention.attention import AttentionMaskBuilder +from vllm_ascend.attention.attention_v1 import AscendAttentionState + +logger = init_logger(__name__) + +PADDING_SLOT_ID = -1 + + +class EagleProposer: + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + self.vllm_config = vllm_config + self.speculative_config = vllm_config.speculative_config + self.draft_model_config = self.speculative_config.draft_model_config + self.method = self.speculative_config.method + self.runner = runner + self.model_config = vllm_config.model_config + self.dtype = vllm_config.model_config.dtype + self.max_model_len = vllm_config.model_config.max_model_len + self.block_size = vllm_config.cache_config.block_size + self.num_speculative_tokens = ( + self.speculative_config.num_speculative_tokens) + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + self.device = device + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=device, + dtype=torch.int32) + mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) + self.attn_mask_len = min(self.model_config.max_model_len, + int(mask_len)) + self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + self.attn_mask_len, self.dtype) + + def _make_attention_mask( + self, + seq_lens, + query_lens, + position, + ) -> torch.Tensor: + return self.attn_mask_builder.get_splitfuse_attn_mask( + seq_lens, query_lens, position, self.dtype, self.device) + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + device = cu_num_tokens.device + cu_num_tokens = cu_num_tokens.cpu() + block_table = block_table.cpu() + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + target_positions = target_positions.cpu() + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids[0] + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + attn_metadata = self.runner.attn_metadata_builder.build( + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + common_prefix_len=0, + ) + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions.to(device) + self.hidden_states[:num_tokens] = target_hidden_states + attn_metadata.block_tables = block_table.to(device) + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + ) + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + draft_token_ids = logits.argmax(dim=-1) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. + draft_token_ids_tensor = torch.zeros( + (self.num_speculative_tokens, *draft_token_ids.shape), + dtype=draft_token_ids.dtype) + draft_token_ids_tensor[0] = draft_token_ids + + positions_cpu = target_positions[last_token_indices].cpu().to( + torch.int64) + hidden_states = hidden_states[last_token_indices] + if self.use_cuda_graph and \ + batch_size <= self.cudagraph_batch_sizes[-1]: + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + else: + input_batch_size = batch_size + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + + if self.num_speculative_tokens > 2: + raise ValueError("Speculative tokens > 2 are not supported yet.") + + attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + for now_speculative in range(self.num_speculative_tokens - 1): + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_tensor[now_speculative].to(device) + positions_cpu += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = positions_cpu >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, + positions_cpu) + clamped_positions = clamped_positions_cpu.to(device) + + # TODO: Increment the sequence lengths. + + attn_metadata.seq_lens += 1 + # TODO: Consider max model length. + # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + # self.max_model_len) + # For the requests that exceed the max model length, we set the + # TODO: sequence length to 1 to minimize their overheads in attention. + + # Compute the slot mapping. + block_numbers = (clamped_positions_cpu // self.block_size) + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + slot_mapping_cpu = (block_ids * self.block_size + + clamped_positions_cpu % self.block_size) + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping_cpu.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + # NOTE: ASCEND slot_mapping must on cpu + attn_metadata.slot_mapping = slot_mapping_cpu.to( + torch.int32).to(device) + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + positions = positions_cpu.to(device) + attn_mask = self._make_attention_mask( + seq_lens=attn_metadata.seq_lens, + query_lens=attn_metadata.max_query_len, + position=positions, + ) + attn_metadata.attn_mask = attn_mask + attn_metadata.block_tables = block_table.to(device) + # Run the model. + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=input_batch_size): + + last_hidden_states, hidden_states = self.model( + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + ) + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size], + None) + + # TODO(wenlong): get more than one token for tree attention + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() + + # [batch_size, num_speculative_tokens] + draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) + return draft_token_ids + + @staticmethod + def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + # [a - n1, b - n2, c - n3] -> + # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + cu_num_tokens = torch.zeros_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_target_query_lens.device, + ) + BLOCK_SIZE = 1024 + prepare_eagle_input_sequential( + token_indices, + cu_target_query_lens, + cu_num_tokens, + block_size=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def load_model(self, target_model: nn.Module) -> None: + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) + + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + + self.attn_layer_names = list(draft_attn_layer_names) + self.attn_layer_name = next(iter(draft_attn_layer_names)) + # share embed_tokens with the target model if needed + if get_pp_group().world_size == 1: + logger.info( + "The EAGLE head shares the same vocab embedding" \ + " with the target model." + ) + self.model.model.embed_tokens = target_model.model.embed_tokens + else: + logger.info( + "Since PP > 1, the EAGLE head loaded its own vocab embedding" \ + " weights instead of sharing them with the target model." + ) + + # share lm_head with the target model if needed + # some model definition do not define lm_head explicitly + # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM + if self.vllm_config.speculative_config.method != "eagle3" and \ + hasattr(target_model, "lm_head"): + logger.info("Loading EAGLE LM head weights from the target model.") + if supports_multimodal(target_model): + self.model.lm_head = target_model.get_language_model().lm_head + else: + self.model.lm_head = target_model.lm_head + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + ) -> None: + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + + +def prepare_eagle_input_sequential(out_tensor: torch.Tensor, + cu_query_lens: torch.Tensor, + cu_num_tokens: torch.Tensor, + block_size: int): + num_programs = len(cu_num_tokens) - 1 + for pid in range(num_programs): + start_pos = cu_num_tokens[pid].item() + end_pos = cu_num_tokens[pid + 1].item() + num_tokens = end_pos - start_pos + index_start = cu_query_lens[pid].item() + num_blocks = int( + torch.ceil(torch.tensor(num_tokens / block_size)).item()) + + for i in range(num_blocks): + offset_tensor = torch.arange(0, + block_size, + dtype=torch.int32, + device=out_tensor.device) + global_start_offset = i * block_size + target_indices = torch.tensor( + start_pos + global_start_offset, + dtype=torch.int32, + device=out_tensor.device) + offset_tensor + values_to_store = torch.tensor( + index_start, dtype=torch.int32, + device=out_tensor.device) + offset_tensor + mask = (target_indices >= start_pos) & \ + (target_indices < end_pos) & \ + (offset_tensor < num_tokens) + out_tensor[target_indices[mask]] = values_to_store[mask] + + +# NOTE(woosuk): Currently, the below code is not used and we always use argmax +# to sample the draft tokens. We will use this after we find a way to manage +# the draft prob tensor. +# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. +def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. + q = torch.empty_like(probs) + q.exponential_() + # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs + # will be used later for rejection sampling. + next_token_ids = probs.div(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 89f30bc43c..7f29562c27 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -57,7 +57,6 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported @@ -70,11 +69,13 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.attention_v1 import (AscendAttentionState, + AscendMetadata) from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is +from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer if TYPE_CHECKING: @@ -206,8 +207,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} # Set up speculative decoding. + self.use_aux_hidden_state_outputs = False self.use_spec_decode = False self.spec_attn_mask = None + self.use_eagle = False if self.speculative_config: self.use_spec_decode = True self.spec_attn_mask = torch.triu(torch.ones(2048, @@ -217,9 +220,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle": - self.drafter = EagleProposer(self.vllm_config, - self.device) # type: ignore + elif self.speculative_config.method in ["eagle", "eagle3"]: + self.use_eagle = True + self.drafter = EagleProposer(self.vllm_config, self.device, + self) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == 'deepseek_mtp': self.drafter = MtpProposer(self.vllm_config, self) else: @@ -589,6 +595,140 @@ def _get_forward_metadata_across_dp( group=get_dp_group().cpu_group) return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + def get_eagle_atten_dict( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, AscendMetadata]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) + self.query_lens = torch.from_numpy(num_scheduled_tokens) + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + + # Get positions. + positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table = self.input_batch.block_table[kv_cache_group_id] + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + # Copy the tensors to the NPU. + self.input_ids[:total_num_scheduled_tokens].copy_( + self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + else: + # Common case (1D positions) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.seq_lens[num_reqs:].fill_(0) + self.query_start_loc[num_reqs + 1:].fill_(-1) + + attn_metadata: dict[str, AscendMetadata] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + + attn_metadata_i = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + ) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + return attn_metadata + def get_model(self) -> nn.Module: return self.model @@ -776,7 +916,7 @@ def _process_reqs( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata, - torch.Tensor, int, torch.Tensor]: + torch.Tensor, int, torch.Tensor, torch.Tensor]: # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -874,7 +1014,10 @@ def _process_reqs( attn_state = AscendAttentionState.DecodeOnly # Speculative decoding. elif np.all(num_valid_tokens == 1): - attn_state = AscendAttentionState.SpecDecoding + if self.use_eagle: + attn_state = AscendAttentionState.ChunkedPrefill + else: + attn_state = AscendAttentionState.SpecDecoding # splitfuse elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: attn_state = AscendAttentionState.ChunkedPrefill @@ -1051,8 +1194,32 @@ def _process_reqs( num_draft_tokens, cu_num_tokens) sample_indices = spec_decode_metadata.logits_indices + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states + return (attn_metadata, hidden_states, spec_decode_metadata, positions, - total_num_scheduled_tokens, sample_indices) + total_num_scheduled_tokens, sample_indices, aux_hidden_states) + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange def _calc_spec_decode_metadata( self, @@ -1193,6 +1360,7 @@ def _get_spec_token_ids( num_scheduled_tokens: int, hidden_states: torch.Tensor, attn_metadata: SpecDecodeMetadata, + aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: if not self.use_spec_decode: # Speculative decoding is not enabled. @@ -1202,9 +1370,85 @@ def _get_spec_token_ids( spec_token_ids = self._generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) elif self.speculative_config.method == "eagle": - raise NotImplementedError( - "eagle method for spec decode doesn't work on vllm-ascend currently" - ) + raise NotImplementedError("Eagle Is Not Supported Yet.") + elif self.speculative_config.method == "eagle3": + assert isinstance(self.drafter, EagleProposer) + if self.speculative_config.use_eagle(): + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + eagle_attn_metadata = attn_metadata[ + self.drafter.attn_layer_name] + num_input_tokens = scheduler_output.total_num_scheduled_tokens + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat([ + h[:num_scheduled_tokens] for h in aux_hidden_states + ], + dim=-1) + else: + target_hidden_states = hidden_states[: + num_scheduled_tokens] + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc + else: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + num_tokens = num_scheduled_tokens - sum( + num_rejected_tokens) + cu_num_tokens, token_indices = self.drafter.prepare_inputs( + eagle_attn_metadata.query_start_loc, + num_rejected_tokens, num_tokens) + target_token_ids = self.input_ids[token_indices] + target_positions = positions[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] + + positions = self.positions[:num_input_tokens] + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=eagle_attn_metadata.block_tables, + sampling_metadata=sampling_metadata, + ) + spec_token_ids = draft_token_ids.tolist() elif self.speculative_config.method == 'deepseek_mtp': assert isinstance(self.drafter, MtpProposer) spec_token_ids = self._generate_mtp_token_ids( @@ -1226,14 +1470,16 @@ def execute_model( # Return empty ModelRunnerOuptut if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, - sample_indices) = (self._process_reqs(scheduler_output, - intermediate_tensors)) + num_scheduled_tokens, sample_indices, + aux_hidden_states) = (self._process_reqs(scheduler_output, + intermediate_tensors)) with ProfileExecuteDuration().capture_async("post process"): + logits = self.model.compute_logits(hidden_states[sample_indices], None) - + if self.use_eagle: + attn_metadata = self.get_eagle_atten_dict(scheduler_output) # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: logits = self.apply_grammar_bitmask(scheduler_output, logits) @@ -1272,6 +1518,7 @@ def execute_model( ) sampler_output.sampled_token_ids = output_token_ids + discard_sampled_tokens_req_indices: list[int] = [] # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. discard_sampled_tokens_req_indices = [] @@ -1318,6 +1565,7 @@ def execute_model( num_scheduled_tokens, hidden_states, attn_metadata, + aux_hidden_states, ) if vllm_version_is("0.9.1"): model_runner_output = ModelRunnerOutput( @@ -1440,6 +1688,7 @@ def _dummy_run( num_tokens: int, is_compile: bool = False, with_prefill: bool = True, + skip_attn: bool = True, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -1454,6 +1703,16 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + if skip_attn: + attn_metadata = None + else: + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + ) + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1519,7 +1778,15 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) - return hidden_states + if self.use_aux_hidden_state_outputs: + hidden_states, _ = hidden_states + else: + hidden_states = hidden_states + if self.use_spec_decode and \ + self.speculative_config.method in ('eagle', 'eagle3'): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + return hidden_states def profile_run(self) -> None: # FIXME Profile with multimodal encoder & encoder cache. @@ -1567,7 +1834,13 @@ def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) if hasattr(self, "drafter"): logger.info("Loading drafter model...") - self.drafter.load_model() + if self.use_aux_hidden_state_outputs: + self.drafter.load_model(self.model) + else: + self.drafter.load_model() + if self.use_aux_hidden_state_outputs: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -1640,6 +1913,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + self.kv_cache_config = kv_cache_config import torch_npu kv_caches: Dict[str, torch.Tensor] = {}