Commit 0f8a914

[Core] Ignore infeasible swap requests. (#4557)

rkooo567 authored May 2, 2024
1 parent 9b5c9f9 commit 0f8a914
Showing 12 changed files with 187 additions and 42 deletions.
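
At a glance: BlockSpaceManager.can_swap_in() now returns a three-valued AllocStatus (OK / LATER / NEVER) instead of a bool, and the scheduler uses the NEVER case to finish infeasible requests as ignored rather than retrying them forever. Below is a minimal, self-contained sketch of that decision; ToyBlockManager and its parameters are illustrative stand-ins, not vLLM internals.

from enum import Enum, auto

class AllocStatus(Enum):  # mirrors vllm.core.interfaces.AllocStatus
    OK = auto()      # enough free blocks right now
    LATER = auto()   # not enough free blocks yet; retry on a later scheduling step
    NEVER = auto()   # the request can never fit on this GPU; ignore it

class ToyBlockManager:
    """Illustrative stand-in for the watermark-based check in BlockSpaceManagerV1."""

    def __init__(self, total_blocks: int, free_blocks: int, watermark_blocks: int = 1):
        self.total_blocks = total_blocks
        self.free_blocks = free_blocks
        self.watermark_blocks = watermark_blocks

    def can_swap_in(self, num_required_blocks: int) -> AllocStatus:
        if num_required_blocks > self.total_blocks:
            return AllocStatus.NEVER   # no amount of freeing will ever help
        if self.free_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK      # safe to swap in now
        return AllocStatus.LATER       # wait for blocks to free up

bm = ToyBlockManager(total_blocks=8, free_blocks=3)
assert bm.can_swap_in(2) == AllocStatus.OK
assert bm.can_swap_in(5) == AllocStatus.LATER
assert bm.can_swap_in(20) == AllocStatus.NEVER
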
85 changes: 85 additions & 0 deletions tests/basic_correctness/test_preemption.py
@@ -7,6 +7,7 @@
"""
import pytest

from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)

@@ -136,3 +137,87 @@ def test_swap(
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]

vllm_model = vllm_runner(
model,
dtype=dtype,
swap_space=10,
block_size=BLOCK_SIZE,
# Since beam search has more than one sequence, prefill + decode blocks
# are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
)
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
ignore_eos=True)
req_outputs = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model
# Verify the request is ignored and does not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
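
Rough block arithmetic behind this test (a back-of-the-envelope sketch, not vLLM's exact accounting):

BLOCK_SIZE = 16
max_tokens = 96
beam_width = 4
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE         # 6
gpu_blocks = prefill_blocks + decode_blocks      # 8, the num_gpu_blocks_override above
# The beam-search group holds beam_width sequences, each needing its own decode
# blocks, so swapping the whole group back in needs on the order of
# prefill_blocks + beam_width * decode_blocks = 26 blocks, far more than the 8
# GPU blocks that exist in total. can_swap_in() therefore reports NEVER and the
# request finishes with finish_reason == "length" instead of hanging.
assert prefill_blocks + beam_width * decode_blocks > gpu_blocks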


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
vllm_model = vllm_runner(
model,
dtype=dtype,
block_size=BLOCK_SIZE,
# Not enough GPU blocks to complete a single sequence.
# Preemption should happen, and the sequence should be
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
)
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
req_outputs = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params,
)

assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model
# Verify the request is ignored and does not hang.
for req_output in req_outputs:
outputs = req_output.outputs
assert len(outputs) == 1
assert outputs[0].finish_reason == "length"
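
The same kind of back-of-the-envelope arithmetic for the preemption case (illustrative only, not exact accounting):

BLOCK_SIZE = 16
max_tokens = 96
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE              # 6
gpu_blocks = prefill_blocks + decode_blocks // 2      # 5, the num_gpu_blocks_override above
# Even a single sequence eventually needs about prefill_blocks + decode_blocks = 8
# blocks, more than the 5 GPU blocks that exist in total, so the sequence is
# preempted and then ignored (finish_reason == "length") rather than retried forever.
assert prefill_blocks + decode_blocks > gpu_blocks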
2 changes: 1 addition & 1 deletion tests/core/test_block_manager.py
@@ -224,7 +224,7 @@ def test_swap():

# Swap seq group from CPU -> GPU.
cpu_blocks = block_manager.get_block_table(prompt)
assert block_manager.can_swap_in(seq_group)
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
5 changes: 3 additions & 2 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -4,6 +4,7 @@
import pytest # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler
from vllm.sequence import Logprob, SequenceGroup

@@ -410,7 +411,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):

# Add 1 more task. Swap is not possible, so prefill is running.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = False
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

_, seq_group2 = create_dummy_prompt("2", prompt_length=60)
scheduler.add_seq_group(seq_group2)
@@ -423,7 +424,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
assert out.scheduled_seq_groups[0].seq_group == seq_group2

# Now although swap is possible, running prefill is prioritized.
scheduler.block_manager.can_swap_in.return_value = True
scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
30 changes: 29 additions & 1 deletion tests/core/test_scheduler.py
@@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in():

# The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = False
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
# Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
@@ -803,6 +803,34 @@ def test_schedule_swapped_cannot_swap_in():
assert len(output.prefill_seq_groups) == 0


def test_infeasible_swap():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)

# The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
# Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 0
assert len(output.infeasible_seq_groups) == 2
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(output.decode_seq_groups) == 0
assert len(output.prefill_seq_groups) == 0


def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler()
swapped = deque()
19 changes: 8 additions & 11 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -110,9 +110,8 @@ def __init__(
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator

def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
@@ -123,13 +122,10 @@ def allocate_mutable(self,
Returns:
Block: The newly allocated mutable block.
"""
assert device is not None
return self._allocators[device].allocate_mutable(prev_block)

def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
@@ -144,7 +140,6 @@ def allocate_immutable(self,
Block: The newly allocated immutable block containing the provided
token IDs.
"""
assert device is not None
return self._allocators[device].allocate_immutable(
prev_block, token_ids)

@@ -175,7 +170,7 @@ def fork(self, last_block: Block) -> List[Block]:
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block)

def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
def get_num_free_blocks(self, device: Device) -> int:
"""Returns the number of free blocks available on the specified device.
Args:
@@ -185,9 +180,11 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
Returns:
int: The number of free blocks available on the specified device.
"""
assert device is not None
return self._allocators[device].get_num_free_blocks()

def get_num_total_blocks(self, device: Device) -> int:
return self._allocators[device].get_num_total_blocks()

def clear_copy_on_writes(self) -> Dict[int, List[int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
21 changes: 13 additions & 8 deletions vllm/core/block/interfaces.py
@@ -108,6 +108,10 @@ def free(self, block: Block) -> None:
def fork(self, last_block: Block) -> List[Block]:
pass

@abstractmethod
def get_num_total_blocks(self) -> int:
pass

@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@@ -152,20 +156,21 @@ class NoFreeBlocksError(ValueError):
class DeviceAwareBlockAllocator(ABC):

@abstractmethod
def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass

@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
pass

@abstractmethod
def allocate_immutable(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
def get_num_free_blocks(self, device: Device) -> int:
pass

@abstractmethod
def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
def get_num_total_blocks(self, device: Device) -> int:
pass

@abstractmethod
6 changes: 4 additions & 2 deletions vllm/core/block/naive_block.py
@@ -133,10 +133,12 @@ def fork(self, last_block: Block) -> List[Block]:

return forked_blocks

def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
def get_num_free_blocks(self) -> int:
return len(self._free_block_indices)

def get_num_total_blocks(self) -> int:
return len(self._all_block_indices)

def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
3 changes: 3 additions & 0 deletions vllm/core/block/prefix_caching_block.py
@@ -285,6 +285,9 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks

def get_num_total_blocks(self) -> int:
return self._hashless_allocator.get_num_total_blocks()

@property
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
19 changes: 17 additions & 2 deletions vllm/core/block_manager_v1.py
@@ -47,6 +47,10 @@ def free(self, block: PhysicalTokenBlock) -> None:
def get_num_free_blocks(self) -> int:
pass

@abstractmethod
def get_num_total_blocks(self) -> int:
pass

@abstractmethod
def contains_block(self, block_hash: int) -> bool:
pass
@@ -131,6 +135,9 @@ def get_num_free_blocks(self) -> int:
return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks)

def get_num_total_blocks(self) -> int:
return self.num_blocks

def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor

@@ -190,6 +197,9 @@ def free(self, block: PhysicalTokenBlock) -> None:
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)

def get_num_total_blocks(self) -> int:
return self.num_blocks

def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
@@ -444,7 +454,7 @@ def _get_physical_blocks(

def can_swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
num_lookahead_slots: int = 0) -> AllocStatus:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group)
@@ -454,7 +464,12 @@ def can_swap_in(self,
# at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append_slot().
num_required_blocks = len(blocks) + num_swapped_seqs
return num_free_blocks - num_required_blocks >= self.watermark_blocks
if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
return AllocStatus.NEVER
elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER

def swap_in(self,
seq_group: SequenceGroup,
4 changes: 2 additions & 2 deletions vllm/core/block_manager_v2.py
@@ -238,8 +238,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_tables[child_seq.seq_id] = src_block_table.fork()

def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return False
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.LATER

def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
2 changes: 1 addition & 1 deletion vllm/core/interfaces.py
@@ -63,7 +63,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:

@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
num_lookahead_slots: int) -> AllocStatus:
pass

@abstractmethod