Skip to content

Commit

Permalink
FIX #7592: keep chunked prefill performance untouched
Browse files Browse the repository at this point in the history
  • Loading branch information
noooop committed Aug 30, 2024
1 parent 2148441 commit 5e8eda1
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,16 +1027,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:

# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)

# Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self.running.extend([s.seq_group for s in prefills.seq_groups])

# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
Expand Down

0 comments on commit 5e8eda1

Please sign in to comment.