FIX #7592: keep chunked prefill performance untouched
noooop committed Aug 27, 2024
1 parent 029c71d commit dd12bc8
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vllm/core/scheduler.py
@@ -510,6 +510,16 @@ def _schedule_running(
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.

if enable_chunking:
# By default, the vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running = deque(
sorted(
self.running,
key=lambda seq_group: seq_group.metrics.arrival_time,
))

running_queue = self.running

while running_queue:
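
Below is a minimal, standalone sketch (not part of the commit) of what the added re-ordering does: sorting the running queue by arrival time puts the oldest sequence groups first, which under chunked prefill are typically the requests already in the decode phase. The SequenceGroup and metrics classes here are simplified stand-ins for the real vLLM types; the names and values are illustrative assumptions.

from collections import deque
from dataclasses import dataclass

@dataclass
class Metrics:
    arrival_time: float

@dataclass
class SequenceGroup:
    request_id: str
    metrics: Metrics

# Three in-flight requests; the oldest two are assumed to be decoding already.
running = deque([
    SequenceGroup("late-prefill", Metrics(arrival_time=3.0)),
    SequenceGroup("old-decode", Metrics(arrival_time=1.0)),
    SequenceGroup("mid-decode", Metrics(arrival_time=2.0)),
])

# Same re-ordering as the patch: oldest arrival first (FCFS).
running = deque(
    sorted(running, key=lambda seq_group: seq_group.metrics.arrival_time))

print([sg.request_id for sg in running])
# -> ['old-decode', 'mid-decode', 'late-prefill']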
