[Speculative decoding 4/9] Lookahead scheduling for speculative decoding #3250

Merged 26 commits on Apr 1, 2024

153 changes: 153 additions & 0 deletions tests/core/block/e2e/test_correctness.py
@@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",

# skip cuda graph creation for fast test.
"enforce_eager": True,

# Use a large block size to trigger more copy-on-writes.
"block_size": 32,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify beam search equality with block manager v1 and v2.

This requires copy-on-write; if the v1 and v2 outputs are the same, then
we have some confidence that CoW is working.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
use_beam_search=True,
best_of=2,
)

print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",

# Our prompts will generate 128 tokens; since the prompts themselves are
# small, we don't need much KV space beyond 128.
"max_model_len": 160,

# skip cuda graph creation for fast test.
"enforce_eager": True,

# Lookahead scheduling only supported in v2 block manager.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"block_size": 16,

# Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size
"forced_num_gpu_blocks": 2 * (8 + 1),
},
{
"block_size": 8,

# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"forced_num_gpu_blocks": 2 * (16 + 1),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"num_lookahead_slots": 0,
}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
# We run one test with block_size < lookahead_slots, one test with
# block_size > lookahead_slots
"num_lookahead_slots": 10,
}])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size):
"""Verify vLLM produces the same output with greedy sampling, when lookahead
scheduling is used vs. not.

Lookahead scheduling is not expected to modify the output, as it simply
allocates empty slots ahead of the known token ids in a sliding fashion.

This test constrains the total number of blocks to force preemption. It also
varies the block size so that the lookahead size is less than and greater
than the block size.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

print('Getting token ids without lookahead scheduling')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids with lookahead scheduling')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
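
The lookahead test above pins `forced_num_gpu_blocks` to `2 * (128/block_size + 1)` so that only two sequences fit at once, which forces preemption, while lookahead scheduling merely reserves `num_lookahead_slots` empty slots past the last known token. A minimal sketch of the block arithmetic behind those numbers (the `blocks_needed` helper and the prompt length of 6 are illustrative, not part of this PR):

import math

def blocks_needed(prompt_len: int, output_len: int,
                  num_lookahead_slots: int, block_size: int) -> int:
    """Blocks required to hold a sequence's known tokens plus its lookahead slots."""
    total_slots = prompt_len + output_len + num_lookahead_slots
    return math.ceil(total_slots / block_size)

# With block_size=16 and ~128 generated tokens per sequence, 8 blocks hold the
# tokens and one spare block absorbs the short prompt plus the lookahead slots,
# matching forced_num_gpu_blocks = 2 * (8 + 1) for two concurrent sequences.
print(blocks_needed(prompt_len=6, output_len=128,
                    num_lookahead_slots=10, block_size=16))  # 9
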
103 changes: 103 additions & 0 deletions tests/core/block/test_block_manager_v2.py
@@ -0,0 +1,103 @@
import pytest

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list

from ..utils import create_seq_group


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
num_gpu_blocks: int, watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
)
num_watermark_blocks = int(watermark * num_gpu_blocks)

num_output_blocks_per_seq = 1

# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
# the current implementation assumes all seqs are new prompts / don't have
# different output lens.
num_output_blocks = num_output_blocks_per_seq

for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
seq_group = create_seq_group(
seq_prompt_len=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
)

assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

can_allocate_result = block_manager.can_allocate(seq_group)

num_required_blocks = num_prompt_blocks + num_output_blocks

if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
assert can_allocate_result == AllocStatus.NEVER
elif num_gpu_blocks >= num_required_blocks:
assert can_allocate_result == AllocStatus.OK
else:
assert can_allocate_result == AllocStatus.LATER
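
The three-way assertion at the end of the loop encodes the expected watermark policy: a request that cannot fit even with the watermark reserved is rejected permanently, a request that fits now is admitted, and anything else is deferred. A sketch of the expectation as the test states it, for a freshly created block manager in which every GPU block is free (the `expected_alloc_status` name is illustrative; the allocator's full internal logic may differ):

from vllm.core.interfaces import AllocStatus

def expected_alloc_status(num_required_blocks: int, num_gpu_blocks: int,
                          num_watermark_blocks: int) -> AllocStatus:
    # NEVER: even an empty GPU cannot host the request without dipping
    # below the watermark.
    if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
        return AllocStatus.NEVER
    # OK: the request fits right now (all blocks are free in this test).
    if num_gpu_blocks >= num_required_blocks:
        return AllocStatus.OK
    # LATER: only reachable once some blocks are already in use.
    return AllocStatus.LATER

With every block free, the LATER branch is unreachable, which is why this test only ever observes NEVER or OK.
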


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
def test_append_slots(block_size, prompt_len, num_slots_to_append,
num_lookahead_slots):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""

num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
watermark=watermark,
)

seq_group = create_seq_group(
seq_prompt_len=prompt_len,
seq_output_lens=[0],
)

# Allocate seq
assert block_manager.can_allocate(seq_group)
block_manager.allocate(seq_group)

# Set seq to RUNNING
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING

# Append tokens to the sequence
for token_id in range(num_slots_to_append):
seq.append_token_id(token_id, {token_id: Logprob(0.0)})

# Append slots for new tokens and lookahead slots.
free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
block_manager.append_slots(seq, num_lookahead_slots)
num_consumed_blocks = (free_blocks_before_append -
block_manager.get_num_free_gpu_blocks())

# Expect consumed blocks to be new blocks required to support the new slots.
expected_consumed_blocks = len(
chunk_list(
list(
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
assert num_consumed_blocks == expected_consumed_blocks
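
The `expected_consumed_blocks` expression above is a ceiling-division difference: the blocks needed for the grown sequence (prompt plus appended tokens plus lookahead slots) minus the blocks already covering the prompt. A worked sketch of the same arithmetic without `chunk_list` (the helper name is illustrative):

import math

def expected_consumed_blocks(prompt_len: int, num_appended: int,
                             num_lookahead: int, block_size: int) -> int:
    blocks_before = math.ceil(prompt_len / block_size)
    blocks_after = math.ceil(
        (prompt_len + num_appended + num_lookahead) / block_size)
    return blocks_after - blocks_before

# One parametrization from the test: block_size=8, prompt_len=7,
# num_slots_to_append=8, num_lookahead_slots=10.
# ceil(25 / 8) - ceil(7 / 8) = 4 - 1 = 3 new blocks.
print(expected_consumed_blocks(7, 8, 10, 8))  # 3
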
50 changes: 0 additions & 50 deletions tests/core/block/test_block_space_manager.py

This file was deleted.

75 changes: 75 additions & 0 deletions tests/core/block/test_block_table.py
@@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,

# After free, expect all blocks to be freed.
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
num_new_tokens: int,
num_lookahead_slots: int,
allocator_type: str):
"""Verify correct calculation of get_num_blocks_touched_by_append_slots.

This is done by using copy-on-write, which requires any modified block to
be copied before write if the refcount > 1. We set the refcount > 1 by forking
a sequence, then measure the free blocks before and after an append. If the
number of consumed blocks equals what `get_num_blocks_touched_by_append_slots`
returns, then the calculation is correct.
"""

num_gpu_blocks = 1024

allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
block_size=block_size,
)

token_ids = list(range(sequence_len))
token_ids_to_append = list(range(num_new_tokens))

block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)

block_table.allocate(token_ids=token_ids, device=Device.GPU)

# Add lookahead before fork so both sequences have the same lookahead
# blocks.
block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)

# Fork sequence so that every block has refcount > 1.
_ = block_table.fork()

# Determine how many blocks should be touched.
expected_num_touched_blocks = (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=token_ids_to_append,
num_lookahead_slots=num_lookahead_slots))

# Measure how many blocks are touched by measuring num_free_blocks before
# and after the append.
#
# We expect append_token_ids to CoW all mutated blocks that have refcount>1.
num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
num_consumed_blocks = (num_free_blocks_before_append -
allocator.get_num_free_blocks(Device.GPU))

# TODO(cade) ensure equality when num_lookahead_slots > 0.
# The reason we have < is because lookahead blocks are not copied eagerly;
# they are copied on first write. This will cause issues for beam search +
# speculative decoding. This is acceptable for now as it is a large effort
# to combine the two. To fix this, we can ensure single sequence ownership
# of lookahead blocks by appending empty slots to each block, which will
# trigger the CoW.
#
# Until then, we can accept that the consumed blocks are <= the expected
# blocks when appending with lookahead.
if num_lookahead_slots > 0:
assert num_consumed_blocks <= expected_num_touched_blocks
else:
assert num_consumed_blocks == expected_num_touched_blocks
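
The docstring of this test leans on copy-on-write semantics: after a fork, every shared block has refcount > 1, so the first write to it must allocate a fresh block, and counting the drop in free blocks therefore measures how many blocks an append really touched. A toy sketch of that mechanism (the `RefCountedBlock`/`ToyAllocator` names are illustrative, not vLLM classes):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class RefCountedBlock:
    block_id: int
    refcount: int = 1
    token_ids: List[int] = field(default_factory=list)

class ToyAllocator:
    def __init__(self, num_blocks: int):
        self.free_ids = list(range(num_blocks))
        self.blocks: Dict[int, RefCountedBlock] = {}

    def allocate(self) -> RefCountedBlock:
        block = RefCountedBlock(block_id=self.free_ids.pop())
        self.blocks[block.block_id] = block
        return block

    def cow_if_shared(self, block: RefCountedBlock) -> RefCountedBlock:
        # A block shared by more than one sequence must be copied before it
        # can be mutated; the copy consumes one free block.
        if block.refcount == 1:
            return block
        block.refcount -= 1
        copy = self.allocate()
        copy.token_ids = list(block.token_ids)
        return copy

alloc = ToyAllocator(num_blocks=4)
shared = alloc.allocate()
shared.refcount += 1                      # simulate a fork: two owners
free_before = len(alloc.free_ids)
writable = alloc.cow_if_shared(shared)    # first write triggers the copy
assert len(alloc.free_ids) == free_before - 1

Lazily copied lookahead blocks are exactly the blocks such a scheme copies only on first write, which is why the final assert accepts `<=` when num_lookahead_slots > 0.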