
[Speculative decoding][Re-take] Enable TP>1 speculative decoding #4840

Merged
merged 20 commits on May 16, 2024
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -42,6 +42,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py

- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"
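The new CI step runs the tensor-parallel speculative-decoding integration test added by this PR on the multi-GPU agents. A roughly equivalent local invocation (an assumption: at least 2 GPUs, run from the repository's tests/ directory) in Python would be:

import pytest

# Mirrors the new Buildkite step; requires >= 2 GPUs.
pytest.main(["-v", "-s", "spec_decode/e2e/test_integration_dist.py"])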
6 changes: 6 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@@ -28,6 +30,7 @@
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
@@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
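The new benchmark flags above are passed straight through to the LLM constructor. A minimal programmatic equivalent is sketched below; the concrete target/draft models and the proposal length are illustrative assumptions, not values taken from this PR.

from vllm import LLM, SamplingParams

# Sketch only: the model choices and k=5 are assumptions for illustration.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    speculative_model="JackFram/llama-68m",  # draft model that proposes tokens
    num_speculative_tokens=5,                # proposal length per step
    use_v2_block_manager=True,               # spec decode requires the v2 block manager
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)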
50 changes: 0 additions & 50 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -5,56 +5,6 @@
from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0

prompts = [
"Hello, my name is",
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

try:
with pytest.raises(
AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
44 changes: 44 additions & 0 deletions tests/spec_decode/e2e/test_integration.py
@@ -0,0 +1,44 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
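run_greedy_equality_correctness_test (from tests/spec_decode/e2e/conftest.py) asserts that the speculative configuration emits exactly the same greedy tokens as the baseline. A simplified, self-contained sketch of that check follows; it is an approximation under assumed vLLM APIs, not the actual conftest helper.

from vllm import LLM, SamplingParams

def check_greedy_equality(baseline_kwargs, spec_kwargs, prompts, max_output_len):
    # Greedy decoding makes both configurations deterministic.
    params = SamplingParams(temperature=0.0, max_tokens=max_output_len,
                            ignore_eos=True)

    baseline = LLM(**baseline_kwargs)
    baseline_tokens = [out.outputs[0].token_ids
                       for out in baseline.generate(prompts, params)]
    del baseline  # release the first engine before building the second

    spec = LLM(**spec_kwargs)
    spec_tokens = [out.outputs[0].token_ids
                   for out in spec.generate(prompts, params)]

    # Lossless rejection sampling preserves the target-model distribution,
    # so greedy outputs must match token for token.
    assert baseline_tokens == spec_tokens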
65 changes: 65 additions & 0 deletions tests/spec_decode/e2e/test_integration_dist.py
@@ -0,0 +1,65 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""

import pytest
import torch

from vllm.utils import is_hip

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
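This is the test executed by the new Buildkite step. Note that the test itself drives the engine through the async path (use_async=True) for process isolation; the synchronous sketch below only illustrates the user-facing argument surface it exercises: a tensor-parallel target model with either a draft model or ngram prompt lookup proposing tokens. The kwargs mirror the parametrization above; everything else is an assumption.

from vllm import LLM

# Sketch: target model sharded across 2 GPUs with a draft model proposing
# 3 tokens per step. The "[ngram]" variant instead sets
# speculative_model="[ngram]" together with ngram_prompt_lookup_max.
llm = LLM(
    model="JackFram/llama-68m",
    tensor_parallel_size=2,        # TP>1 target model, enabled by this PR
    use_v2_block_manager=True,     # required for spec decode
    enforce_eager=True,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
)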
37 changes: 0 additions & 37 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
10 changes: 5 additions & 5 deletions vllm/distributed/communication_op.py
@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict

group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"

# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict

rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
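The reordering above matters because the old code resolved the process group and queried its ranks before checking the world size, which fails outright when torch.distributed has never been initialized and does needless work in the single-GPU case. A minimal sketch of the guard pattern, not the vLLM function itself:

import torch.distributed as dist

def broadcast_metadata(obj, src=0, group=None):
    """Sketch of the guard pattern used above; not the vLLM function itself."""
    # Bail out before touching any group state: either distributed was never
    # initialized (single-process run) or the group has only one rank, so
    # there is nobody to broadcast to.
    if not dist.is_initialized() or dist.get_world_size(group=group) == 1:
        return obj
    # Only past this point is it safe to resolve the group and query its ranks.
    group = group or dist.group.WORLD
    assert src in dist.get_process_group_ranks(group), f"Invalid src rank ({src})"
    payload = [obj]
    dist.broadcast_object_list(payload, src=src, group=group)
    return payload[0]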
71 changes: 17 additions & 54 deletions vllm/executor/gpu_executor.py
@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):

def _init_executor(self) -> None:
"""Initialize the worker and load the model.

If speculative decoding is enabled, we instead create the speculative
worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()

def _get_worker_kwargs(
self,
@@ -45,66 +44,30 @@ def _get_worker_kwargs(
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
)

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):

if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"

wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker

def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()

def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None

from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

target_worker = self._create_worker()

draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=self.speculative_config.
ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.speculative_config.
ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)

spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)

assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

self.driver_worker = spec_decode_worker

# Load model handled in spec decode worker.
self.driver_worker.init_device()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
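After this refactor, GPUExecutor always creates its single driver worker through _create_worker; when a speculative_config is set, it simply points WorkerWrapperBase at vllm.spec_decode.spec_decode_worker.create_spec_worker instead of vllm.worker.worker.Worker, and the draft/target wiring formerly done in _init_spec_worker becomes that factory's job. Below is a hedged reconstruction of such a factory, based only on the code removed above; the real signature and details in spec_decode_worker.py may differ.

def create_spec_worker(*, speculative_config, **worker_kwargs):
    """Hedged sketch of a spec-decode worker factory, reconstructed from the
    _init_spec_worker logic removed above; not the actual vLLM implementation."""
    from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
    from vllm.worker.worker import Worker

    # Target (scorer) worker keeps the original model/parallel configs.
    target_worker = Worker(**worker_kwargs)

    # Draft (proposer) worker overrides the configs with the draft-model
    # versions carried by the speculative config.
    draft_worker_kwargs = dict(
        worker_kwargs,
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.speculative_disable_by_batch_size,
    )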