Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling #3103

Merged
merged 73 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
d74ff5c
first test passes
cadedaniel Feb 24, 2024
74b5c48
test
cadedaniel Feb 24, 2024
f6a730b
test
cadedaniel Feb 24, 2024
aafebd0
test
cadedaniel Feb 29, 2024
415db01
test
cadedaniel Feb 29, 2024
76dfe1a
test
cadedaniel Feb 29, 2024
069b564
test
cadedaniel Feb 29, 2024
7f13ccd
test for metrics, separate out metrics functionality
cadedaniel Feb 29, 2024
c91a55b
metrics test
cadedaniel Feb 29, 2024
3a69d54
clean
cadedaniel Feb 29, 2024
73212b7
test
cadedaniel Feb 29, 2024
f796ab0
fixes
cadedaniel Feb 29, 2024
384dc9d
nvtx_range
cadedaniel Feb 29, 2024
2d1a192
profile and cache tests
cadedaniel Feb 29, 2024
bec5cba
test
cadedaniel Feb 29, 2024
273baea
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Feb 29, 2024
7a18f37
lint
cadedaniel Feb 29, 2024
4eb8e04
attempt add tests to ci
cadedaniel Feb 29, 2024
e0ec4b4
refactor outline
cadedaniel Mar 4, 2024
b7e580b
wip
cadedaniel Mar 4, 2024
665ed8e
WIP
cadedaniel Mar 5, 2024
8fcb257
sampler mock raw tensors
cadedaniel Mar 5, 2024
79a1f6c
wip
cadedaniel Mar 5, 2024
7a42183
asd
cadedaniel Mar 5, 2024
c86b44e
asd
cadedaniel Mar 5, 2024
f5e5d76
wip
cadedaniel Mar 5, 2024
68284ed
bugfix
cadedaniel Mar 5, 2024
87cc31a
wip
cadedaniel Mar 5, 2024
c142006
wip
cadedaniel Mar 5, 2024
c026de9
wip
cadedaniel Mar 5, 2024
1486a84
wip
cadedaniel Mar 5, 2024
264b5cb
wip
cadedaniel Mar 5, 2024
8cc8caf
wip
cadedaniel Mar 5, 2024
ee1efff
wip
cadedaniel Mar 5, 2024
cadac54
wip
cadedaniel Mar 5, 2024
2c2dd86
wip
cadedaniel Mar 5, 2024
b112463
wip
cadedaniel Mar 5, 2024
e09c666
wip
cadedaniel Mar 5, 2024
807aa02
fix
cadedaniel Mar 5, 2024
5545e73
clean
cadedaniel Mar 5, 2024
49c9798
wip
cadedaniel Mar 6, 2024
978a711
remove
cadedaniel Mar 6, 2024
d548db4
clean
cadedaniel Mar 6, 2024
a277fe0
wip
cadedaniel Mar 6, 2024
c1357d6
wip
cadedaniel Mar 6, 2024
136a59b
wip
cadedaniel Mar 6, 2024
5657feb
wip
cadedaniel Mar 6, 2024
059beba
rename
cadedaniel Mar 6, 2024
4ce7119
wip
cadedaniel Mar 6, 2024
db52bee
clean
cadedaniel Mar 6, 2024
20297f2
clean
cadedaniel Mar 6, 2024
13aebbf
wip
cadedaniel Mar 6, 2024
524ada4
clean
cadedaniel Mar 6, 2024
e3f57b0
wip
cadedaniel Mar 6, 2024
aff7a34
clean
cadedaniel Mar 6, 2024
c2beb94
rename
cadedaniel Mar 6, 2024
2c835de
first autoformat
cadedaniel Mar 6, 2024
292c34f
wip
cadedaniel Mar 6, 2024
c387d56
wip
cadedaniel Mar 6, 2024
39382b7
move
cadedaniel Mar 6, 2024
f78325c
move
cadedaniel Mar 6, 2024
5cafc12
move
cadedaniel Mar 6, 2024
3a5dcfb
name
cadedaniel Mar 6, 2024
8e7ee97
sequence test and docs
cadedaniel Mar 6, 2024
e5e334f
docs
cadedaniel Mar 6, 2024
2d27c57
lint
cadedaniel Mar 6, 2024
3c61a52
typo
cadedaniel Mar 6, 2024
be66076
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Mar 6, 2024
b165a73
lint
cadedaniel Mar 6, 2024
17725fb
fix test
cadedaniel Mar 6, 2024
364a415
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Mar 6, 2024
11b6f39
pr feedback
cadedaniel Mar 9, 2024
ab00fcf
better comment
cadedaniel Mar 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
command: pytest -v -s engine
command: pytest -v -s engine test_sequence.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
Expand All @@ -52,6 +52,9 @@ steps:
- label: Worker Test
command: pytest -v -s worker

- label: Speculative decoding tests
command: pytest -v -s spec_decode

- label: LoRA Test
command: pytest -v -s lora --forked

Expand Down
File renamed without changes.
95 changes: 95 additions & 0 deletions tests/spec_decode/test_batch_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import pytest

from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer

from .utils import mock_worker, create_seq_group_metadata_from_prompts


@pytest.mark.parametrize('num_target_seq_ids', [100])
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
    """Check that every target sequence id produced by the iterator is
    strictly greater than every input sequence id.
    """
    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)

    seq_id_cases = [
        [1, 3, 5, 7],
        list(range(100)) + [0],
        [100],
    ]

    for seq_ids in seq_id_cases:
        highest_input_id = max(seq_ids)
        target_ids = scorer._create_target_seq_id_iterator(seq_ids)  # pylint: disable=protected-access
        assert all(
            next(target_ids) > highest_input_id
            for _ in range(num_target_seq_ids))


@pytest.mark.parametrize('k', [1, 2, 6])
def test_get_token_ids_to_score(k: int):
    """Verify correct tokens are selected for scoring.
    """
    proposal_token_ids = torch.tensor(
        list(range(k)),
        dtype=torch.int64,
        device='cuda',
    )

    # The scorer should produce the empty prefix followed by every proposal
    # prefix of increasing length: [], [t0], [t0, t1], ...
    expected_output = [
        proposal_token_ids[:prefix_len].tolist()
        for prefix_len in range(k + 1)
    ]

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    raw_output = scorer._get_token_ids_to_score(proposal_token_ids)  # pylint: disable=protected-access

    normalized = [
        entry.tolist() if isinstance(entry, torch.Tensor) else entry
        for entry in raw_output
    ]

    assert normalized == expected_output


@pytest.mark.parametrize('k', [1, 2, 6])
def test_create_single_target_seq_group_metadata(k: int):
    """Verify correct creation of a batch-expanded seq group metadata.
    """
    prompt_tokens = [1, 2, 3]
    prev_output_tokens = [4, 5, 6]
    token_ids = list(range(k))

    num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
    final_seq_len = (len(prompt_tokens) + len(prev_output_tokens) +
                     len(token_ids))

    block_size = 32
    input_seq_group_metadata = create_seq_group_metadata_from_prompts(
        [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
        [prev_output_tokens], [num_tokens_processed])[0]

    input_seq_id = next(iter(input_seq_group_metadata.seq_data))
    target_seq_id = 100

    scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
    output = scorer._create_single_target_seq_group_metadata(  # pylint: disable=protected-access
        input_seq_group_metadata,
        input_seq_id,
        target_seq_id,
        token_ids,
    )

    # The request id carries over and exactly one target sequence is created.
    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1

    target_data = output.seq_data[target_seq_id]
    assert target_data.get_prompt_token_ids() == prompt_tokens
    assert target_data.get_output_token_ids() == (prev_output_tokens +
                                                  token_ids)

    # The target sequence reuses the input sequence's block table.
    assert len(output.block_tables) == 1
    assert (output.block_tables[target_seq_id] ==
            input_seq_group_metadata.block_tables[input_seq_id])
157 changes: 157 additions & 0 deletions tests/spec_decode/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import torch
import math
import pytest

from unittest.mock import MagicMock

from vllm.spec_decode.metrics import AsyncMetricsCollector


def test_initial_call_returns_none():
    """Expect first call to get metrics to return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=0)

    assert collector.maybe_collect_rejsample_metrics(k=5) is None


def test_second_call_returns_metrics():
    """Expect second call to not return None.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    # Fake clock: the first reading is at t=0 and each later reading is
    # already past the collection interval.
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    collector.maybe_collect_rejsample_metrics(k=5)
    assert collector.maybe_collect_rejsample_metrics(k=5) is not None


@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
    """Verify nonzero ranks don't collect metrics.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_draft_tokens = 0

    collector = AsyncMetricsCollector(rej_sampler)
    collector.init_gpu_tensors(rank=rank)

    collector.maybe_collect_rejsample_metrics(k=5)
    assert collector.maybe_collect_rejsample_metrics(k=5) is None


def test_noop_until_time():
    """Verify metrics aren't collected until enough time passes.
    """
    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(
        0, dtype=torch.long, device='cuda')
    rej_sampler.num_draft_tokens = 0

    collect_interval_s = 5.0
    # Fake clock: the first readings fall just short of the interval, the
    # later readings fall just past it.
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
        collect_interval_s + 0.1, collect_interval_s + 0.1
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)

    # Before the interval elapses, no metrics should be produced.
    collector.maybe_collect_rejsample_metrics(k=5)
    assert collector.maybe_collect_rejsample_metrics(k=5) is None

    # Once the interval has elapsed, metrics should be produced.
    collector.maybe_collect_rejsample_metrics(k=5)
    assert collector.maybe_collect_rejsample_metrics(k=5) is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
    """Test correctness of metrics data.
    """
    if has_data:
        num_accepted_tokens = 103
        num_emitted_tokens = 104
        num_draft_tokens = 105
    else:
        num_accepted_tokens = num_emitted_tokens = num_draft_tokens = 0
    k = 5

    num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
        num_draft_tokens, k)

    rej_sampler = MagicMock()
    rej_sampler.num_accepted_tokens = torch.tensor(
        num_accepted_tokens, dtype=torch.long, device='cuda')
    rej_sampler.num_emitted_tokens = torch.tensor(
        num_emitted_tokens, dtype=torch.long, device='cuda')
    rej_sampler.num_draft_tokens = num_draft_tokens

    collect_interval_s = 5.0
    # Fake clock: second and third readings land past the interval so the
    # second poll produces metrics.
    timer = MagicMock()
    timer.side_effect = [
        0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
    ]

    collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
                                      timer=timer,
                                      collect_interval_s=collect_interval_s)
    collector.init_gpu_tensors(rank=0)
    collector.maybe_collect_rejsample_metrics(k)
    metrics = collector.maybe_collect_rejsample_metrics(k)

    assert metrics.num_spec_tokens == k
    assert metrics.accepted_tokens == num_accepted_tokens
    assert metrics.draft_tokens == num_draft_tokens
    assert metrics.emitted_tokens == num_emitted_tokens

    if not has_data:
        # With zero draft tokens the rates are undefined (0/0 -> NaN).
        assert math.isnan(metrics.draft_acceptance_rate)
        assert math.isnan(metrics.system_efficiency)
    else:
        assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
        assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
Loading
Loading