[Speculative decoding 4/9] Lookahead scheduling for speculative decoding #3250
Conversation
Force-pushed from fe7d9e5 to b468716
Ready for review cc @LiuXiaoxuanPKU
Force-pushed from bcfcc83 to 301603e
Force-pushed from 301603e to d8837a0
Thanks for the PR! Just left some questions.
vllm/core/block_manager_v2.py
```python
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
    block_table = self.block_tables[seq.seq_id]
    num_new_tokens = seq.get_len() - block_table.num_full_slots
```
What's the definition of `num_new_tokens` here?
This is the number of tokens that do not have a slot allocated in the block table. Will add a comment.
From my understanding, `num_new_tokens == len(unseen_token_ids)`?
Yes. I will rename them so they're consistent.
added
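For reference, a minimal illustration of the equivalence discussed in this thread (attribute names follow the diff above; the snippet is a sketch, not code from the PR):

```python
# Tokens past num_full_slots have no slot in the block table yet.
unseen_token_ids = seq.get_token_ids()[block_table.num_full_slots:]
num_new_tokens = seq.get_len() - block_table.num_full_slots
assert num_new_tokens == len(unseen_token_ids)
```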
```python
        self,
        seq: Sequence,
        num_lookahead_slots: int,
    ) -> Optional[Tuple[int, int]]:
```
`num_lookahead_slots` is not used in this function, is it expected?
Not expected. Will add an `ensure_num_empty_slots` call with `num_lookahead_slots`, and see why the test didn't catch this.
fixed; added test.
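For context, a hedged sketch of the fix, assuming the `ensure_num_empty_slots` method named above plus an `append_token_ids` helper (both illustrative, not copied from the PR):

```python
# Reserve slots for the unseen tokens plus the empty lookahead slots
# before appending token ids, so speculative tokens have KV space.
block_table.ensure_num_empty_slots(
    num_empty_slots=len(unseen_token_ids) + num_lookahead_slots)
block_table.append_token_ids(unseen_token_ids)
```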
```python
        )
        return scheduler_outputs

    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
```
Nit: I just feel the two functions below, `_can_append_slots` and `_can_swap_in`, are a bit shallow and do not hide much complexity. Maybe we can call them directly above?
PR feedback applied @LiuXiaoxuanPKU
@cadedaniel really awesome series of changes! I assume the answer is no, but does the draft model also have its own KV cache? If yes, where is it created and updated?
Thanks! Great question. The draft model can have KV, and by default it does. The spec decode worker has a proposer worker, which can be a normal vLLM Worker with KV cache and everything. The number of blocks in the draft KV and target KV is calculated by this function: `vllm/spec_decode/spec_decode_worker.py`, lines 97 to 116 (at b3104b2).
As for the mapping between token ids and block indices, we keep things simple by using the same block mapping for the draft and target models. We can do this because of the code linked above, which ensures that the draft and target KV have the same amount of logical KV space. With this, the KV of the draft model is populated in lockstep with the target model. There are more details, like how proposal KV are handled, but this is the gist!
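If it helps, here is a hedged sketch of the block-splitting idea (the function name and signature are illustrative assumptions, not the exact vLLM code):

```python
def split_num_cache_blocks_evenly(target_block_size_bytes: int,
                                  draft_block_size_bytes: int,
                                  total_num_target_blocks: int) -> int:
    """Split GPU memory so draft and target KV caches hold the same
    number of blocks (i.e. the same logical KV space).

    Free memory is expressed as a count of target-model-only blocks;
    convert it to bytes, then divide by the combined per-block cost of
    keeping both caches side by side.
    """
    total_bytes = total_num_target_blocks * target_block_size_bytes
    return total_bytes // (target_block_size_bytes + draft_block_size_bytes)
```

For example, with 64 KiB target blocks, 16 KiB draft blocks, and room for 1000 target-only blocks, each model would get 800 shared blocks.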
This PR introduces the concept of lookahead scheduling. With lookahead scheduling, we allocate extra KV slots that do not yet have any token assigned ("empty slots") for each sequence in a decode batch. Speculative decoding fills these KV slots with the KV of speculative tokens when running the target model. Furthermore, speculative decoding with a proposal method that has its own KV cache will also use these KV slots in normal autoregressive generation.
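To make the slot accounting concrete, a hedged sketch (the function and parameter names are illustrative; only the `seq_len - num_full_slots` accounting follows the diff discussed above):

```python
def num_blocks_touched(seq_len: int, num_full_slots: int,
                       num_lookahead_slots: int, block_size: int) -> int:
    """Blocks a sequence touches when its unseen tokens plus its empty
    lookahead slots are allocated in one scheduling step."""
    num_unallocated_slots = (seq_len - num_full_slots) + num_lookahead_slots
    # Ceiling division: a partially filled block still counts as touched.
    return -(-num_unallocated_slots // block_size)
```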
See these step-by-step examples explaining how lookahead scheduling works.
Note: we could use scratch space for these KV; however, when tokens are accepted we would need to copy the accepted KV from the scratch space to the allocated KV slots. Allocating them ahead of time saves us the complexity of scheduling such a memcpy.
Testing
Temporary flag
A temporary flag, `--num-lookahead-slots`, is added to facilitate testing. It will be removed in PR 6/9 of the speculative decoding OSS plan.
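For example, a test invocation might look like the following (the entrypoint and model are illustrative assumptions; only the flag comes from this PR):

```
python -m vllm.entrypoints.api_server \
    --model facebook/opt-125m \
    --num-lookahead-slots 5
```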