|
34 | 34 | from vllm.sampling_params import SamplingType |
35 | 35 | from vllm.sequence import IntermediateTensors |
36 | 36 | from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, |
37 | | - GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, |
38 | | - check_use_alibi, is_pin_memory_available) |
| 37 | + GiB_bytes, LayerBlockType, LazyLoader, |
| 38 | + async_tensor_h2d, cdiv, check_use_alibi, |
| 39 | + is_pin_memory_available) |
39 | 40 | from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata |
40 | 41 | from vllm.v1.attention.backends.utils import CommonAttentionMetadata |
41 | 42 | from vllm.v1.core.encoder_cache_manager import compute_encoder_budget |
@@ -898,7 +899,6 @@ def _calc_spec_decode_metadata( |
898 | 899 | target_logits_indices=target_logits_indices, |
899 | 900 | bonus_logits_indices=bonus_logits_indices, |
900 | 901 | logits_indices=logits_indices, |
901 | | - total_num_scheduled_tokens=cu_num_scheduled_tokens[-1], |
902 | 902 | ) |
903 | 903 | return metadata |
904 | 904 |
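Context for the removal above: `total_num_scheduled_tokens` was just the last entry of the cumulative-sum array built inside `_calc_spec_decode_metadata`, so the field duplicated a total the runner already tracks. A minimal sketch of that equivalence, with invented per-request counts (not the runner's real state):

```python
# Sketch only: hypothetical per-request scheduled-token counts.
import numpy as np

num_scheduled_tokens_per_req = np.array([4, 2, 3])
cu_num_scheduled_tokens = np.cumsum(num_scheduled_tokens_per_req)

# The removed metadata field held cu_num_scheduled_tokens[-1], which is
# always equal to the plain sum of the scheduled-token counts.
assert cu_num_scheduled_tokens[-1] == num_scheduled_tokens_per_req.sum() == 9
```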
|
@@ -1397,8 +1397,7 @@ def execute_model( |
1397 | 1397 | dtype=torch.int32, |
1398 | 1398 | target_device=self.device, |
1399 | 1399 | pin_memory=True) |
1400 | | - num_tokens = spec_decode_metadata.total_num_scheduled_tokens - \ |
1401 | | - sum(num_rejected_tokens) |
| 1400 | + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) |
1402 | 1401 | cu_num_tokens, token_indices = self.drafter.prepare_inputs( |
1403 | 1402 | eagle_attn_metadata.query_start_loc, |
1404 | 1403 | num_rejected_tokens_tensor, |
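At the call site, the total is now recomputed from the scheduler's count rather than read off the metadata. A hedged sketch of the new arithmetic, assuming `num_scheduled_tokens` holds the step's total scheduled-token count (the values below are invented):

```python
# Illustrative values only -- not the runner's real state.
num_scheduled_tokens = 9          # total tokens scheduled this step
num_rejected_tokens = [1, 0, 2]   # per-request draft tokens rejected by sampling

# New computation: subtract rejections from the scheduler's total instead of
# reading spec_decode_metadata.total_num_scheduled_tokens (now removed).
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
assert num_tokens == 6
```

Both forms agree by construction; dropping the field removes one place where the two counts could silently drift apart.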