
[Speculative decoding 4/9] Lookahead scheduling for speculative decoding #3250

Merged: 26 commits merged into vllm-project:main on Apr 1, 2024

Conversation

@cadedaniel (Collaborator) commented Mar 7, 2024

This PR introduces the concept of lookahead scheduling. With lookahead scheduling, each sequence in a decode batch is allocated KV slots that do not yet have any token assigned ("empty slots"). Speculative decoding fills these KV slots with the KV of speculative tokens when running the target model. Furthermore, if speculative decoding uses a proposal method that has its own KV cache, that method will also use these KV slots during normal autoregressive generation.


See these step-by-step examples explaining how Lookahead scheduling works.

Note: we could use scratch space for these KVs; however, in the case where tokens are accepted we would need to copy the accepted KV from the scratch space to the allocated KV slots. By allocating the slots ahead of time, we avoid the complexity of scheduling such a memcpy.
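
As a toy illustration of the slot accounting (hypothetical block size and function, not the actual vLLM block manager), allocating lookahead slots simply means reserving room for tokens that do not exist yet:

# Toy illustration of lookahead slot accounting (hypothetical; the real
# vLLM block manager tracks block tables and physical blocks).
BLOCK_SIZE = 16  # assumed number of KV slots per block

def num_blocks_needed(seq_len: int, num_lookahead_slots: int) -> int:
    """Blocks required for every generated token plus the empty lookahead
    slots that the target model will fill with speculative tokens' KV."""
    total_slots = seq_len + num_lookahead_slots
    return -(-total_slots // BLOCK_SIZE)  # ceiling division

# A 30-token sequence with 5 lookahead slots needs 3 blocks instead of 2,
# because the empty slots are allocated ahead of time.
assert num_blocks_needed(30, 0) == 2
assert num_blocks_needed(30, 5) == 3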

Testing

  • This PR finishes the copy-on-write scheduler integration from [Core] [Bugfix] Refactor block manager subsystem for better testability #3492 (since now there can be multiple CoW per append_slots).
    • A test is added where we verify v1/v2 block manager output equality when there are copy-on-writes: test_v1_v2_greedy_equality_with_cow
  • test_lookahead_greedy_equality_with_preemption tests equality of generation with lookahead enabled vs. disabled, and includes preemption (see the sketch below this list).
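
A minimal sketch of the greedy-equality pattern these tests follow (hypothetical harness and names, not the actual vLLM test code):

# Hypothetical equality harness: generate greedily with and without the
# feature under test (e.g. lookahead slots) and require identical output.
from typing import Callable, List

def assert_greedy_equality(
        baseline_generate: Callable[[List[str], int], List[str]],
        candidate_generate: Callable[[List[str], int], List[str]],
        prompts: List[str],
        max_tokens: int) -> None:
    baseline = baseline_generate(prompts, max_tokens)
    candidate = candidate_generate(prompts, max_tokens)
    # Greedy decoding is deterministic, so enabling the feature must not
    # change the generated tokens.
    assert baseline == candidate, f"{baseline} != {candidate}"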

Temporary flag

A temporary flag --num-lookahead-slots is added to facilitate testing. It will be removed in PR 6/9 of the speculative decoding oss plan.

@cadedaniel cadedaniel changed the title [WIP] [Speculative decoding 4/9] Scheduler allocates >1 slot per sequence per step [WIP] [Speculative decoding 4/9] Lookahead scheduling for speculative decoding Mar 28, 2024
@cadedaniel cadedaniel changed the title [WIP] [Speculative decoding 4/9] Lookahead scheduling for speculative decoding [Speculative decoding 4/9] Lookahead scheduling for speculative decoding Mar 28, 2024
@cadedaniel cadedaniel marked this pull request as ready for review March 28, 2024 02:16
@cadedaniel (Collaborator Author)

Ready for review cc @LiuXiaoxuanPKU

@LiuXiaoxuanPKU (Collaborator) left a comment

Thanks for the PR! Just left some questions.

num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
    block_table = self.block_tables[seq.seq_id]
    num_new_tokens = seq.get_len() - block_table.num_full_slots
Collaborator

what's the definition of num_new_tokens here?

Collaborator Author

This is the number of tokens that do not have a slot allocated in the block table. Will add a comment.

Collaborator

From my understanding, num_new_tokens == len(unseen_token_ids)?

Collaborator Author

Yes. I will rename them so the naming is consistent.

Collaborator Author

added
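
For intuition, a toy version of that relationship (simplified stand-ins, not the actual vLLM classes):

# Toy block table illustrating num_new_tokens == len(unseen_token_ids)
# (simplified stand-in; the real BlockTable maps slots to physical blocks).
class ToyBlockTable:
    def __init__(self) -> None:
        self.num_full_slots = 0  # slots that already hold a token's KV

    def append_token_ids(self, token_ids) -> None:
        self.num_full_slots += len(token_ids)

seq_token_ids = [101, 7, 42, 9, 13]        # seq.get_len() == 5
table = ToyBlockTable()
table.append_token_ids(seq_token_ids[:3])  # 3 tokens already have slots

unseen_token_ids = seq_token_ids[table.num_full_slots:]
num_new_tokens = len(seq_token_ids) - table.num_full_slots
assert num_new_tokens == len(unseen_token_ids) == 2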

    self,
    seq: Sequence,
    num_lookahead_slots: int,
) -> Optional[Tuple[int, int]]:
Collaborator

num_lookahead_slots is not used in this function; is that expected?

Collaborator Author

Not expected; I will add an ensure_num_empty_slots call with num_lookahead_slots and see why the test didn't catch this.

Collaborator Author

fixed; added test.
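
For reference, a toy sketch of what ensuring empty slots for lookahead amounts to (hypothetical and simplified; the real ensure_num_empty_slots in vllm/core allocates physical blocks):

# Toy sketch of ensure_num_empty_slots semantics (hypothetical, simplified).
BLOCK_SIZE = 16

class ToyBlockTable:
    def __init__(self, num_full_slots: int, num_allocated_slots: int) -> None:
        self.num_full_slots = num_full_slots            # slots holding token KV
        self.num_allocated_slots = num_allocated_slots  # slots backed by blocks

    def ensure_num_empty_slots(self, num_empty: int) -> None:
        # Allocate whole blocks until there is room for num_empty extra KVs.
        while self.num_allocated_slots < self.num_full_slots + num_empty:
            self.num_allocated_slots += BLOCK_SIZE

table = ToyBlockTable(num_full_slots=30, num_allocated_slots=32)
table.ensure_num_empty_slots(5)  # reserve room for 5 speculative tokens' KV
assert table.num_allocated_slots == 48  # a third block was allocated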

vllm/core/block_manager_v2.py (resolved)
        )
        return scheduler_outputs

    def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
Collaborator

Nit: I feel the two functions below, _can_append_slots and _can_swap_in, are a bit shallow and do not hide much complexity. Maybe we can call them directly above?

@cadedaniel (Collaborator Author) left a comment

PR feedback applied, @LiuXiaoxuanPKU


@cadedaniel cadedaniel enabled auto-merge (squash) April 1, 2024 18:25
@cadedaniel cadedaniel disabled auto-merge April 1, 2024 20:34
@cadedaniel cadedaniel enabled auto-merge (squash) April 1, 2024 22:02
@cadedaniel cadedaniel merged commit 93deb0b into vllm-project:main Apr 1, 2024
33 checks passed
@cadedaniel cadedaniel deleted the multi-step-scheduler branch April 1, 2024 22:59
@animan42

@cadedaniel really awesome series of changes! I assume the answer is no, but does the draft model also have its own KV cache? If yes, where is it created and updated?

@cadedaniel (Collaborator Author)

Thanks! Great question. The draft model can have KV, and by default it does. The spec decode worker has a proposer worker, which can be a normal vLLM Worker with a KV cache and everything. The number of blocks in the draft and target KV caches is calculated by this function:

def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of cache blocks to use.

    This is done by profiling the scorer model (which is typically the
    larger of the two). Then the total memory which would be used by the
    scorer cache is divided evenly between the proposer and scorer model KV,
    such that the number of blocks is equal in both KV caches.
    """
    num_gpu_blocks, num_cpu_blocks = (
        self.scorer_worker.determine_num_available_blocks())

    scorer_cache_block_size_bytes = (
        self.scorer_worker.get_cache_block_size_bytes())
    proposer_cache_block_size_bytes = (
        self.proposer_worker.get_cache_block_size_bytes())

    new_num_gpu_blocks = split_num_cache_blocks_evenly(
        scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
        num_gpu_blocks)
    return new_num_gpu_blocks, num_cpu_blocks
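
To make the even split concrete, here is a hedged sketch of the arithmetic (an assumption about what split_num_cache_blocks_evenly computes; the actual vLLM implementation may differ):

# Hedged sketch of the even split (assumed arithmetic, not the exact vLLM code).
def split_num_cache_blocks_evenly(scorer_block_bytes: int,
                                  proposer_block_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Pick N so that N scorer blocks plus N proposer blocks fit in the
    memory the scorer-only cache would have used."""
    scorer_cache_bytes = total_num_gpu_blocks * scorer_block_bytes
    return scorer_cache_bytes // (scorer_block_bytes + proposer_block_bytes)

# Example: blocks of 160 bytes (scorer) and 40 bytes (proposer); 1000
# scorer-only blocks become 800 blocks for each of the two caches.
assert split_num_cache_blocks_evenly(160, 40, 1000) == 800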

As for the mapping between token IDs and block indices, we keep things simple by using the same block mapping for the draft and target models. We can do this because of the code above, which ensures that the draft and target KV caches have the same amount of logical KV space. With this, the draft model's KV is populated in lockstep with the target model's.

There are more details, like how proposal KV are handled, but this is the gist!

@xunfeng1980

  File "/data/anaconda3/envs/qwen-q/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/anaconda3/envs/qwen-q/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/vllm/vllm/entrypoints/openai/api_server.py", line 157, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/data/vllm/vllm/engine/async_llm_engine.py", line 347, in from_engine_args
    engine = cls(
  File "/data/vllm/vllm/engine/async_llm_engine.py", line 311, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/data/vllm/vllm/engine/async_llm_engine.py", line 421, in _init_engine
    return engine_class(*args, **kwargs)
  File "/data/vllm/vllm/engine/llm_engine.py", line 119, in __init__
    self.model_executor = executor_class(
  File "/data/vllm/vllm/executor/gpu_executor.py", line 37, in __init__
    assert (not speculative_config
AssertionError: Speculative decoding not yet supported for GPU backend

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024