diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index da18ece7555a..0f6098d2b400 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -174,6 +174,7 @@ def allocate_slots( num_new_tokens: int, num_new_computed_tokens: int = 0, new_computed_blocks: Optional[KVCacheBlocks] = None, + num_draft_tokens: int = 0, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: @@ -273,7 +274,7 @@ def allocate_slots( # generated (accepted) tokens. self.single_type_manager.cache_blocks( request, self.req_to_block_hashes[request.request_id], - num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) + num_computed_tokens + num_new_tokens - num_draft_tokens) return KVCacheBlocks(new_blocks) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index efc0de350fba..4c6b3eea0cb7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -227,10 +227,15 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - + request.num_tokens, 0) + while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, + num_draft_tokens=num_draft_tokens, num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled.