[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
3 people authored May 16, 2024
1 parent 30e7543 commit 973617a
Showing 12 changed files with 297 additions and 182 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -42,6 +42,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py

- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"
6 changes: 6 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@@ -28,6 +30,7 @@
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
@@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
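
The two new flags feed straight into vLLM's LLM constructor, which benchmark_latency.py wraps. A minimal offline sketch of the same configuration, using the model names that appear elsewhere in this commit (the prompt and output length are illustrative, not taken from the diff):

from vllm import LLM, SamplingParams

# Target model plus a small draft model; use_v2_block_manager is required for
# speculative decoding, mirroring the flags benchmark_latency.py now forwards.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          speculative_model="JackFram/llama-68m",
          num_speculative_tokens=5,
          use_v2_block_manager=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
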
50 changes: 0 additions & 50 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -5,56 +5,6 @@
from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

try:
with pytest.raises(
AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
44 changes: 44 additions & 0 deletions tests/spec_decode/e2e/test_integration.py
@@ -0,0 +1,44 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
65 changes: 65 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist.py
@@ -0,0 +1,65 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""

import pytest
import torch

from vllm.utils import is_hip

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
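
The conftest fixtures merge common_llm_kwargs, per_test_common_llm_kwargs, and test_llm_kwargs into one engine configuration (and, because of use_async above, run it through the async engine in its own process). Written out by hand, the ngram parametrization is roughly equivalent to the following sketch; this illustrates the effective kwargs only and is not the fixture code itself:

from vllm import LLM

llm_under_test = LLM(model="JackFram/llama-68m",
                     enforce_eager=True,            # skip cuda graph capture
                     use_v2_block_manager=True,     # required for spec decode
                     tensor_parallel_size=2,        # TP>1 path enabled by this commit
                     speculative_model="[ngram]",   # prompt-lookup proposer, no draft model
                     num_speculative_tokens=5,
                     ngram_prompt_lookup_max=3)

The baseline generator uses the same kwargs minus the speculative ones, and run_greedy_equality_correctness_test checks that both produce identical greedy outputs.
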
37 changes: 0 additions & 37 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
10 changes: 5 additions & 5 deletions vllm/distributed/communication_op.py
@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict

group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"

# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict

rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
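
The reorder matters because the single-GPU bypass used to run after the group bookkeeping, and the new guard additionally checks torch.distributed.is_initialized(). A driver running without a process group (the plain single-GPU case) now returns immediately instead of reaching get_cpu_world_group/get_process_group_ranks, which assume an initialized group. A condensed sketch of the pattern (the function name is illustrative):

import torch

def maybe_broadcast(tensor_dict, group=None):
    # Return the caller's dict untouched when there is no one to broadcast to.
    # The `or` short-circuits, so get_world_size is never called on an
    # uninitialized default process group (which would raise).
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict
    # ... group/metadata-group setup and the actual broadcast follow here ...
    return tensor_dict
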
71 changes: 17 additions & 54 deletions vllm/executor/gpu_executor.py
@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):

def _init_executor(self) -> None:
"""Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()

def _get_worker_kwargs(
self,
@@ -45,66 +44,30 @@ def _get_worker_kwargs(
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
)

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):

if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"

wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker

def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()

def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None

from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

target_worker = self._create_worker()

draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=self.speculative_config.
ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.speculative_config.
ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)

spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)

assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = spec_decode_worker

# Load model handled in spec decode worker.
self.driver_worker.init_device()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
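
The executor no longer assembles a SpecDecodeWorker inline; it only tells WorkerWrapperBase which module and attribute to instantiate, and create_spec_worker does the draft/scorer wiring. A factory function name can stand in for a class name because the wrapper resolves the attribute and calls it with the worker kwargs. A self-contained sketch of that dispatch pattern (WorkerWrapperSketch is an illustrative stand-in, not the vLLM class):

import importlib

class WorkerWrapperSketch:
    """Lazily import a worker module and call the named class or factory."""

    def __init__(self, worker_module_name: str, worker_class_name: str):
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None

    def init_worker(self, **kwargs):
        module = importlib.import_module(self.worker_module_name)
        factory = getattr(module, self.worker_class_name)  # class or function
        self.worker = factory(**kwargs)

# Selection mirrors _create_worker() above: a plain Worker when no speculative
# config is set, create_spec_worker when one is present.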