diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 6e3b679cb81b2..e1639d85f8833 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
+        draft_model=args.draft_model,
         tokenizer=args.tokenizer,
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
@@ -26,11 +27,13 @@ def main(args: argparse.Namespace):
         enforce_eager=args.enforce_eager,
         kv_cache_dtype=args.kv_cache_dtype,
         device=args.device,
+        use_flash_attn=args.use_flash_attn,
+        parallel_decoding_lookahead=args.parallel_decoding_lookahead,
     )

     sampling_params = SamplingParams(
         n=args.n,
-        temperature=0.0 if args.use_beam_search else 1.0,
+        temperature=0.0 if args.use_beam_search else args.temperature,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
         ignore_eos=True,
@@ -89,6 +92,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
         description='Benchmark the latency of processing a single batch of '
         'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument("--draft-model", type=str, default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
@@ -103,6 +107,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         default=1,
                         help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument('--num-iters',
                         type=int,
                         default=3,
@@ -145,5 +150,14 @@ def run_to_completion(profile_dir: Optional[str] = None):
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
+    parser.add_argument(
+        "--use-flash-attn",
+        action="store_true",
+        help="Use flash attention (requires flash-attn >= 2.5.0).")
+    parser.add_argument(
+        "--parallel-decoding-lookahead",
+        type=int,
+        default=1,
+        help="Number of lookahead steps for speculative decoding.")
     args = parser.parse_args()
     main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1ad502526c97c..10d5996023200 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -61,22 +61,27 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    draft_model: str,
     tokenizer: str,
     quantization: Optional[str],
     tensor_parallel_size: int,
     seed: int,
     n: int,
     use_beam_search: bool,
+    temperature: float,
     trust_remote_code: bool,
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
     kv_cache_dtype: str,
     device: str,
+    use_flash_attn: Optional[bool] = False,
+    parallel_decoding_lookahead: Optional[int] = 1,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
         model=model,
+        draft_model=draft_model,
         tokenizer=tokenizer,
         quantization=quantization,
         tensor_parallel_size=tensor_parallel_size,
@@ -87,13 +92,15 @@ def run_vllm(
         enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
         device=device,
+        use_flash_attn=use_flash_attn,
+        parallel_decoding_lookahead=parallel_decoding_lookahead,
     )

     # Add the requests to the engine.
for prompt, _, output_len in requests: sampling_params = SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=0.0 if use_beam_search else temperature, top_p=1.0, use_beam_search=use_beam_search, ignore_eos=True, @@ -206,12 +213,13 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, args.device) + elapsed_time = run_vllm( + requests, args.model, args.draft_model, args.tokenizer, + args.quantization, args.tensor_parallel_size, args.seed, args.n, + args.use_beam_search, args.temperature, args.trust_remote_code, + args.dtype, args.max_model_len, args.enforce_eager, + args.kv_cache_dtype, args.device, args.use_flash_attn, + args.parallel_decoding_lookahead) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -248,6 +256,7 @@ def main(args: argparse.Namespace): help="Output length for each request. Overrides the " "output length from the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', @@ -259,6 +268,7 @@ def main(args: argparse.Namespace): default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--num-prompts", type=int, default=1000, @@ -302,6 +312,15 @@ def main(args: argparse.Namespace): default="cuda", choices=["cuda"], help='device type for vLLM execution, supporting CUDA only currently.') + parser.add_argument( + "--use-flash-attn", + action="store_true", + help="Use flash attention (requires flash-attn >= 2.5.0).") + parser.add_argument( + "--parallel-decoding-lookahead", + type=int, + default=1, + help="Number of lookahead steps for speculative decoding.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_attention.py b/benchmarks/kernels/benchmark_attention.py new file mode 100644 index 0000000000000..07fabb5028986 --- /dev/null +++ b/benchmarks/kernels/benchmark_attention.py @@ -0,0 +1,258 @@ +from typing import Optional +import argparse +import random +import time + +import numpy as np +import torch + +try: + from flash_attn import flash_attn_func, flash_attn_with_kvcache +except ImportError: + flash_attn_func, flash_attn_with_kvcache = None, None + +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + +from vllm._C import cache_ops +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random + +NUM_BLOCKS = 1024 + + +@torch.inference_mode() +def main( + version: str, + num_seqs: int, + context_len: int, + num_query_heads: int, + num_kv_heads: int, + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + do_profile: bool, + device: str = "cuda", + kv_cache_dtype: Optional[str] = None, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + use_flash_attn = version in ["flash-attn", "flash-attn-kvcache"] + 
if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + raise ValueError( + "skip: flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + + context_lens = [context_len for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=device) + zero_context_lens_tensor = torch.zeros_like(context_lens_tensor) + + scale = float(1.0 / (head_size**0.5)) + qkv = torch.empty(num_seqs, + max_context_len, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype, + device=device) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=2) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device=device) + + # Create the block tables. + if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables, slot_mapping = [], [] + for seq_idx in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + slot_mapping.append([]) + for i in range(context_lens[seq_idx]): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping[-1].append(slot) + for _ in range(max_context_len - context_lens[seq_idx]): + slot_mapping[-1].append(-1) + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) + + # Create the KV cache. 
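# A minimal standalone sketch (toy sizes, independent of the benchmark above)
# of the slot-mapping construction above: token i of a sequence lives in
# logical block i // block_size, and its flat slot index is
# physical_block_number * block_size + (i % block_size).
block_size = 16
block_table = [7, 3, 9]  # physical block ids assigned to one sequence
context_len = 40
slot_mapping = [
    block_table[i // block_size] * block_size + i % block_size
    for i in range(context_len)
]
assert slot_mapping[0] == 7 * block_size   # token 0 -> block 7, offset 0
assert slot_mapping[16] == 3 * block_size  # token 16 -> block 3, offset 0

# For the flash-attn paths, the benchmark first rounds the block size up to a
# multiple of 256, i.e. ((block_size + 256 - 1) // 256) * 256.
assert ((16 + 256 - 1) // 256) * 256 == 256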
+ key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + use_flash_attn=use_flash_attn) + key_cache, value_cache = key_caches[0], value_caches[0] + + if version == "xformers": + attn_bias = BlockDiagonalCausalMask.from_seqlens(context_lens) + if num_queries_per_kv > 1: + # Handle MQA and GQA + key_repeated = torch.repeat_interleave(key, + num_queries_per_kv, + dim=2) + value_repeated = torch.repeat_interleave(value, + num_queries_per_kv, + dim=2) + else: + key_repeated = key + value_repeated = value + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if version == "xformers": + cache_ops.reshape_and_cache( + key.reshape(-1, *key.shape[2:]), + value.reshape(-1, *key.shape[2:]), + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + ) + output = xops.memory_efficient_attention_forward( + query.reshape(1, -1, *query.shape[2:]), + key_repeated.reshape(1, -1, *key_repeated.shape[2:]), + value_repeated.reshape(1, -1, *value_repeated.shape[2:]), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.reshape(query.shape) + elif version == "flash-attn": + flat_slot_mapping = slot_mapping.flatten() + slot_block_index = flat_slot_mapping // block_size + slot_block_offset = flat_slot_mapping % block_size + key_cache[slot_block_index, + slot_block_offset, :, :] = key.reshape( + -1, *key.shape[2:]) + value_cache[slot_block_index, + slot_block_offset, :, :] = value.reshape( + -1, *key.shape[2:]) + output = flash_attn_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + elif version == "flash-attn-kvcache": + output = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + k=key, + v=value, + cache_seqlens=zero_context_lens_tensor, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + else: + raise ValueError(f"Invalid version: {version}") + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=3, profile=False) + + # Benchmark. 
+ if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=100, profile=False) + print( + f"Version: {version}, Context Length: {context_len}, Batch size: {num_seqs}, Kernel running time: {latency * 1000000:.3f} us" + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Benchmark the paged attention kernel.") + parser.add_argument( + "--version", + type=str, + choices=["xformers", "flash-attn", "flash-attn-kvcache"], + default="xformers") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--num-query-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, + choices=[64, 80, 96, 112, 128, 256], + default=128) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--use-alibi", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8_e5m2"], + default="auto", + help= + 'Data type for kv cache storage. If "auto", will use model data type.') + parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") + args = parser.parse_args() + print(args) + + if args.num_query_heads % args.num_kv_heads != 0: + raise ValueError("num_query_heads must be divisible by num_kv_heads") + main( + version=args.version, + num_seqs=args.batch_size, + context_len=args.context_len, + num_query_heads=args.num_query_heads, + num_kv_heads=args.num_kv_heads, + head_size=args.head_size, + block_size=args.block_size, + use_alibi=args.use_alibi, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + kv_cache_dtype=args.kv_cache_dtype, + ) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d921dea1220e1..dc30d8a1dcc88 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,6 +5,11 @@ import torch +try: + from flash_attn import flash_attn_with_kvcache +except ImportError: + flash_attn_with_kvcache = None + from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops @@ -33,6 +38,14 @@ def main( if torch.cuda.is_available(): torch.cuda.manual_seed(seed) + use_flash_attn = version == "flash-attn" + if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + raise ValueError( + "skip: flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, num_query_heads, @@ -53,6 +66,8 @@ def main( context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) # Create the block tables. + if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): @@ -64,14 +79,16 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. 
- key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + use_flash_attn=use_flash_attn) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -131,6 +148,17 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: alibi_slopes, kv_cache_dtype, ) + elif version == "flash-attn": + flash_attn_with_kvcache( + q=query.reshape(num_seqs, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=context_lens, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -158,7 +186,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, - choices=["v1", "v2"], + choices=["v1", "v2", "flash-attn"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--context-len", type=int, default=4096) diff --git a/requirements.txt b/requirements.txt index de08bd29beaf9..44ce399aab311 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ aioprometheus[starlette] pynvml == 11.5.0 triton >= 2.1.0 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. +packaging +flash-attn >= 2.5.0 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fb571de63d4e1..add9c992939f9 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,6 +6,11 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache +except ImportError: + flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache = None, None, None + from vllm._C import ops, cache_ops from vllm.utils import get_max_shared_memory_bytes from vllm.utils import is_hip @@ -111,7 +116,7 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("version", ["v1", "v2", "flash-attn"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -144,6 +149,16 @@ def test_paged_attention( query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) query.uniform_(-scale, scale) + use_flash_attn = version == "flash-attn" + if use_flash_attn: + if dtype not in [torch.half, torch.bfloat16 + ] or kv_cache_dtype != "auto": + pytest.skip( + "flash-attn requires dtype and kv_cache_dtype to be half or bfloat16" + ) + if head_size >= 128: + pytest.skip("flash-attn tests may OOM due to larger block size") + assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads alibi_slopes = None @@ -156,6 +171,8 @@ def test_paged_attention( context_lens = torch.tensor(context_lens, dtype=torch.int) # Create the block tables. 
+ if use_flash_attn: + block_size = ((block_size + 256 - 1) // 256) * 256 max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): @@ -170,7 +187,7 @@ def test_paged_attention( key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, dtype, seed, - device) + device, use_flash_attn) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -221,13 +238,30 @@ def test_paged_attention( alibi_slopes, kv_cache_dtype, ) + elif version == "flash-attn": + output = flash_attn_with_kvcache( + q=query.reshape(num_seqs, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=context_lens, + block_table=block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + output = output.reshape_as(query) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. + x = 16 // torch.tensor([], dtype=dtype).element_size() + if use_flash_attn: + key_cache = key_cache.unflatten(-1, (head_size // x, x)).permute( + 0, 2, 4, 1, 3) + value_cache = value_cache.permute(0, 2, 3, 1) + if kv_cache_dtype == "fp8_e5m2": # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) dequantized_key_cache = torch.empty(size=key_cache_shape, @@ -266,7 +300,9 @@ def test_paged_attention( # so we use a relaxed tolerance for the test. if kv_cache_dtype == "fp8_e5m2": atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + if use_flash_attn and use_alibi: + atol, rtol = 2e-1, 5e-2 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) def ref_multi_query_kv_attention( @@ -303,6 +339,8 @@ def ref_multi_query_kv_attention( # TODO(woosuk): Add tests for USE_ALIBI=True. 
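# Shape-only sketch (random CPU tensors, no flash-attn needed) of the two
# cache layouts converted between above, assuming flash-attn's paged layout
# [num_blocks, block_size, num_kv_heads, head_size] and vLLM's paged layouts
# [num_blocks, num_kv_heads, head_size // x, block_size, x] for keys and
# [num_blocks, num_kv_heads, head_size, block_size] for values.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 256, 2, 128
x = 16 // torch.tensor([], dtype=torch.half).element_size()  # 8 for fp16

flash_key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
paged_key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x,
                              block_size, x)
paged_value_cache = torch.empty(num_blocks, num_kv_heads, head_size,
                                block_size)

# Both hold the same number of elements per block; only the ordering differs.
assert flash_key_cache[0].numel() == paged_key_cache[0].numel()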
+@pytest.mark.parametrize("version",
+                         ["xformers", "flash-attn", "flash-attn-varlen"])
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -311,6 +349,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
+    version: str,
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
@@ -329,6 +368,17 @@ def test_multi_query_kv_attention(
     max_len = min(MAX_SEQ_LEN, 4096)
     seq_lens = random.sample(range(1, max_len), num_seqs)
     num_tokens = sum(seq_lens)
+    max_seq_len = max(seq_lens)
+
+    use_flash_attn = version in ["flash-attn", "flash-attn-varlen"]
+    if use_flash_attn and dtype not in [torch.half, torch.bfloat16]:
+        pytest.skip(
+            "flash-attn requires dtype to be half or bfloat16")
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device=device)

     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -343,30 +393,81 @@ def test_multi_query_kv_attention(
     num_queries_per_kv = num_query_heads // num_kv_heads
     if num_queries_per_kv > 1:
         # Handle MQA and GQA
-        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
-        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+        key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
+        value_repeated = torch.repeat_interleave(value,
+                                                 num_queries_per_kv,
+                                                 dim=1)
+    else:
+        key_repeated = key
+        value_repeated = value
+
+    if version == "xformers":
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key_repeated.unsqueeze(0),
+            value_repeated.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
+    elif version == "flash-attn":
+        # Pad the inputs, using the same logic as batched prefill
+        # in attention.py.
+        qs, ks, vs = [], [], []
+        for i, seq_len in enumerate(seq_lens):
+            left, right = cu_seq_lens[i], cu_seq_lens[i + 1]
+            qs.append(
+                torch.nn.functional.pad(
+                    query[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len)))
+            ks.append(
+                torch.nn.functional.pad(
+                    key[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len)))
+            vs.append(
+                torch.nn.functional.pad(
+                    value[left:right], (0, 0, 0, 0, 0, max_seq_len - seq_len)))
+        query_padded = torch.stack(qs, dim=0)
+        key_padded = torch.stack(ks, dim=0)
+        value_padded = torch.stack(vs, dim=0)
+
+        output = flash_attn_func(
+            query_padded,
+            key_padded,
+            value_padded,
+            softmax_scale=scale,
+            causal=True,
+        )
+        outputs = []
+        for i, seq_len in enumerate(seq_lens):
+            outputs.append(output[i, :seq_len])
+        output = torch.cat(outputs, dim=0)
+    elif version == "flash-attn-varlen":
+        # We test `flash_attn_varlen_func` here (which is closer to xformers's
+        # memory_efficient_attention_forward), but it is not actually used in
+        # attention.py for prefilling, as inputs are padded in vLLM.
+ output = flash_attn_varlen_func( + query, + key, + value, + cu_seqlens_q=cu_seq_lens, + cu_seqlens_k=cu_seq_lens, + max_seqlen_q=max_seq_len, + max_seqlen_k=max_seq_len, + softmax_scale=scale, + causal=True, + ) + else: + raise AssertionError(f"Unknown version: {version}") - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, - key, - value, + key_repeated, + value_repeated, scale, dtype, ) atol = get_default_atol(output) if is_hip() else 1e-3 rtol = get_default_rtol(output) if is_hip() else 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index ac93b32588cca..c068b38a66910 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -8,7 +8,8 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -NUM_HEADS = [12] +NUM_HEADS = [64] +NUM_QUERIES_PER_KV = [1, 8, 64] HEAD_SIZES = [128] DTYPES = [torch.float16] CUDA_DEVICES = [ @@ -17,12 +18,14 @@ @pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, + num_queries_per_kv: int, head_size: int, dtype: torch.dtype, device: str, @@ -41,28 +44,29 @@ def test_contexted_kv_attention( subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(subquery_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype) + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) k_cache = torch.zeros(cache_size, block_size, - num_heads, + num_kv_heads, head_size, dtype=dtype) v_cache = torch.zeros(cache_size, block_size, - num_heads, + num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) + k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( @@ -93,19 +97,21 @@ def test_contexted_kv_attention( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + 
value[start_loc:end_loc]) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8).permute(0, 2, 3, 1, 4).contiguous() # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_heads, + v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() # Warm up the Triton kernel by calling it once before actually measuring generation time @@ -123,12 +129,29 @@ def test_contexted_kv_attention( attn_op = xops.fmha.cutlass.FwOp() + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( subquery_lens, seq_lens) output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), + query, + key, + value, attn_bias=attn_bias, p=0.0, scale=scale, @@ -137,9 +160,9 @@ def test_contexted_kv_attention( torch.cuda.synchronize() start_time = time.time() output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), + query, + key, + value, attn_bias=attn_bias, p=0.0, scale=scale, @@ -148,5 +171,5 @@ def test_contexted_kv_attention( torch.cuda.synchronize() end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - output_ref = output_ref.squeeze(0) + output_ref = output_ref.squeeze(0, 2) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 8d74509fea488..36fed1e205e5d 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -84,7 +84,7 @@ def create_worker(cls: type, ) (model_config, cache_config, parallel_config, scheduler_config, - device_config, _) = engine_args.create_engine_configs() + device_config, _, _) = engine_args.create_engine_configs() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) diff --git a/vllm/config.py b/vllm/config.py index 0b8a2a27f6d43..a587ae7dac5a5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,6 +6,11 @@ import torch from transformers import PretrainedConfig +try: + import flash_attn +except ImportError: + flash_attn = None + from vllm.logger import init_logger from vllm.transformers_utils.config import get_config from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version @@ -79,6 +84,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + use_flash_attn: Optional[bool] = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -93,6 +99,7 @@ def __init__( self.quantization = 
quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.use_flash_attn = use_flash_attn if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -117,6 +124,7 @@ def __init__( self._verify_tokenizer_mode() self._verify_quantization() self._verify_cuda_graph() + self._verify_flash_attn() def _verify_load_format(self) -> None: load_format = self.load_format.lower() @@ -193,6 +201,18 @@ def _verify_cuda_graph(self) -> None: self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len) + def _verify_flash_attn(self) -> None: + if flash_attn is None: + raise ValueError( + "flash-attn is not installed. Please install flash-attn>=2.5.0 to use " + "the flash attention kernel.") + if Version(flash_attn.__version__) < Version("2.5.0"): + raise ValueError( + "flash-attn >= 2.5.0 is required. Please upgrade flash-attn to " + "the latest version.") + if is_hip(): + raise ValueError("flash-attn cannot doesn't support ROCm.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -427,6 +447,7 @@ class SchedulerConfig: max_model_len: Maximum length of a sequence (including prompt and generated text). max_paddings: Maximum number of paddings to be added to a batch. + parallel_decoding_lookahead: Number of tokens to look ahead for parallel decoding. """ def __init__( @@ -435,6 +456,7 @@ def __init__( max_num_seqs: int, max_model_len: int, max_paddings: int, + parallel_decoding_lookahead: Optional[int] = 1, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -445,6 +467,7 @@ def __init__( self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.max_paddings = max_paddings + self.parallel_decoding_lookahead = parallel_decoding_lookahead self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3946096d4296a..a55fb184cb15a 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -170,6 +170,17 @@ def can_append_slot(self, seq_group: SequenceGroup) -> bool: num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) return num_seqs <= num_free_gpu_blocks + def can_append_slots(self, + seq_group: SequenceGroup, + reserve: Optional[int] = 1) -> bool: + # Simple heuristic: as the maximum possible parallel decoding lookahead + # is 8 (less than block size), if there is at least one free block for + # each sequence, we can append. + assert reserve <= self.block_size, f"Expect reserve <= block_size, got {reserve} > {self.block_size}" + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5e7cc3091d775..d96867e952ca1 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -39,6 +39,7 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + parallel_decoding_lookahead: Optional[int] = 1, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -49,6 +50,7 @@ def __init__( # Swap in and swap out should never happen at the same time. 
assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + self.parallel_decoding_lookahead = parallel_decoding_lookahead self.num_loras = len(self.lora_requests) if self.num_loras > 0: @@ -69,6 +71,17 @@ def _sort_by_lora_ids(self) -> bool: def lora_requests(self) -> Set[LoRARequest]: return {g.lora_request for g in self.scheduled_seq_groups} + def __str__(self) -> str: + return ( + f"SchedulerOutputs(scheduled_seq_groups={self.scheduled_seq_groups}, " + f"prompt_run={self.prompt_run}, " + f"num_batched_tokens={self.num_batched_tokens}, " + f"blocks_to_swap_in={self.blocks_to_swap_in}, " + f"blocks_to_swap_out={self.blocks_to_swap_out}, " + f"blocks_to_copy={self.blocks_to_copy}, " + f"ignored_seq_groups={self.ignored_seq_groups}, " + f"parallel_decoding_lookahead={self.parallel_decoding_lookahead})") + class Scheduler: @@ -279,7 +292,9 @@ def _schedule(self) -> SchedulerOutputs: preempted: List[SequenceGroup] = [] while self.running: seq_group = self.running.popleft() - while not self.block_manager.can_append_slot(seq_group): + while not self.block_manager.can_append_slots( + seq_group, + reserve=self.scheduler_config.parallel_decoding_lookahead): if self.running: # Preempt the lowest-priority sequence groups. victim_seq_group = self.running.pop() @@ -293,6 +308,9 @@ def _schedule(self) -> SchedulerOutputs: break else: # Append new slots to the sequence group. + self._reserve_logical_slots(seq_group, + lookahead=self.scheduler_config. + parallel_decoding_lookahead) self._append_slot(seq_group, blocks_to_copy) running.append(seq_group) self.running = running @@ -336,6 +354,9 @@ def _schedule(self) -> SchedulerOutputs: curr_loras.add(lora_int_id) self.swapped.popleft() self._swap_in(seq_group, blocks_to_swap_in) + self._reserve_logical_slots(seq_group, + lookahead=self.scheduler_config. + parallel_decoding_lookahead) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) @@ -349,14 +370,29 @@ def _schedule(self) -> SchedulerOutputs: seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running) + lookahead = 1 + for seq_group in self.running: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + k = self.scheduler_config.parallel_decoding_lookahead + k = min(k, self.scheduler_config.max_model_len - seq.get_len()) + if seq_group.sampling_params.max_tokens: + k = min( + k, seq_group.sampling_params.max_tokens - + seq.get_output_len()) + lookahead = max(lookahead, k) + + if lookahead > 1: + num_batched_tokens *= lookahead + scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=self.running, + scheduled_seq_groups=running, prompt_run=False, num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], + parallel_decoding_lookahead=lookahead, ) return scheduler_outputs @@ -388,6 +424,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: lora_request=seq_group.lora_request, prefix=seq_group.prefix, state=seq_group.state, + parallel_decoding_lookahead=scheduler_outputs. 
+ parallel_decoding_lookahead, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs @@ -407,6 +445,12 @@ def _allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING + def _reserve_logical_slots(self, + seq_group: SequenceGroup, + lookahead: int = 1) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.reserve_logical_blocks(lookahead - 1) + def _append_slot( self, seq_group: SequenceGroup, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4efd171b871d..9f05eb7cf7d4f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -11,6 +11,7 @@ class EngineArgs: """Arguments for vLLM engine.""" model: str + draft_model: Optional[str] = None tokenizer: Optional[str] = None tokenizer_mode: str = 'auto' trust_remote_code: bool = False @@ -45,6 +46,8 @@ class EngineArgs: lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'cuda' + use_flash_attn: Optional[bool] = False + parallel_decoding_lookahead: Optional[int] = 1 def __post_init__(self): if self.tokenizer is None: @@ -64,6 +67,12 @@ def add_cli_args( type=str, default='facebook/opt-125m', help='name or path of the huggingface model to use') + parser.add_argument( + '--draft-model', + type=str, + default=None, + help='name or path of the huggingface model to use for draft ' + 'generation.') parser.add_argument( '--tokenizer', type=str, @@ -271,6 +280,15 @@ def add_cli_args( choices=["cuda"], help=('Device type for vLLM execution. ' 'Currently, only CUDA-compatible devices are supported.')) + parser.add_argument( + '--use-flash-attn', + action='store_true', + help='Use flash attention (requires flash-attn >= 2.5.0).') + parser.add_argument( + '--parallel-decoding-lookahead', + type=int, + default=EngineArgs.parallel_decoding_lookahead, + help='Number of tokens to look ahead during speculative decoding') return parser @classmethod @@ -284,18 +302,33 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - DeviceConfig, Optional[LoRAConfig]]: + DeviceConfig, Optional[LoRAConfig], Optional[ModelConfig]]: + + assert self.parallel_decoding_lookahead == 1 or self.draft_model is not None, \ + 'parallel_decoding_lookahead > 1 requires draft_model to be specified.' + # discard the draft model if parallel_decoding_lookahead == 1 \ + # and draft model is not required + if self.parallel_decoding_lookahead == 1: + self.draft_model = None + + if self.use_flash_attn: + # flash-attn's flash_attn_with_kvcache requires block size must be + # a multiple of 256. 
+ self.block_size = ((self.block_size + 256 - 1) // 256) * 256 + device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, - self.enforce_eager, self.max_context_len_to_capture) + self.enforce_eager, self.max_context_len_to_capture, + self.use_flash_attn) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) + parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, @@ -304,7 +337,9 @@ def create_engine_configs( scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, - self.max_paddings) + self.max_paddings, + self.parallel_decoding_lookahead) + lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, @@ -312,8 +347,17 @@ def create_engine_configs( lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + draft_model_config = ModelConfig( + self.draft_model, self.tokenizer, self.tokenizer_mode, + self.trust_remote_code, self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, self.code_revision, + self.tokenizer_revision, self.max_model_len, self.quantization, + self.enforce_eager, self.max_context_len_to_capture, + self.use_flash_attn) if self.draft_model else None + return (model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config) + device_config, lora_config, draft_model_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7cba654602779..8b6195ae1b1c4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput logger = init_logger(__name__) @@ -186,7 +187,42 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. 
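# Control-flow sketch of the speculative step driven below, with hypothetical
# stand-ins (run_draft_step / run_target_step are placeholders, not vLLM
# APIs): the draft model proposes lookahead - 1 tokens one step at a time,
# the target model then scores all proposed positions in one forward pass,
# and an acceptance rule decides how many proposals survive. A greedy
# acceptance rule is shown for clarity; the engine itself applies vLLM's
# probabilistic RejectionSampler instead.
def speculative_step(context, lookahead, run_draft_step, run_target_step):
    draft_tokens = []
    ctx = list(context)
    for _ in range(lookahead - 1):
        token = run_draft_step(ctx)  # one draft token per step
        draft_tokens.append(token)
        ctx.append(token)
    # The target model scores every draft position in one pass and also emits
    # a bonus token after the last draft token (lookahead tokens in total).
    target_tokens = run_target_step(context, draft_tokens)
    accepted = []
    for draft, target in zip(draft_tokens, target_tokens):
        accepted.append(target)
        if draft != target:
            return accepted
    accepted.append(target_tokens[-1])
    return accepted

# Toy run: the draft agrees with the target on the first two positions only,
# so two draft tokens plus the target's correction are kept.
assert speculative_step(
    [1, 2, 3], 4,
    run_draft_step=lambda ctx: len(ctx) + 1,
    run_target_step=lambda prompt, drafts: [4, 5, 9, 8]) == [4, 5, 9]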
- all_outputs = await self._run_workers_async( + + # execute the draft model + if scheduler_outputs.prompt_run: + draft_step_func = 'execute_model' + kwargs = {} + else: + draft_step_func = 'execute_model_multi_step' + kwargs = { + # speculative decoding: sampling (k-1) token from the draft model + "num_steps": + scheduler_outputs.parallel_decoding_lookahead - 1, + } + all_draft_outputs: List[ + SamplerOutput] = await self._run_workers_async( + draft_step_func, + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": + scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": + scheduler_outputs.blocks_to_swap_out, + # blocks copy won't happen to draft model: no prefix cache and no beam search + "blocks_to_copy": {}, + **kwargs, + }, + driver_worker=self.draft_driver_worker, + workers=self.draft_workers) + draft_output = all_draft_outputs[0] if len( + all_draft_outputs) > 0 else None + + # add possible draft tokens to the sequences + self._apply_draft_output(seq_group_metadata_list, + scheduler_outputs, draft_output) + + # execute the target model + all_outputs: List[SamplerOutput] = await self._run_workers_async( "execute_model", driver_kwargs={ "seq_group_metadata_list": seq_group_metadata_list, @@ -197,6 +233,11 @@ async def step_async(self) -> List[RequestOutput]: # Only the driver worker returns the sampling results. output = all_outputs[0] + + # apply reject sampling + self._apply_reject_sampling(seq_group_metadata_list, + scheduler_outputs, draft_output, + output) else: output = [] @@ -254,26 +295,39 @@ async def _run_workers_async( *args, driver_args: Optional[List[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, + # use `Ellipsis` as the default argument value as `None` will be + # treat as given worker (usually the draft worker) is not available. + driver_worker: List[Any] = Ellipsis, # List[Worker] + workers: Any = Ellipsis, # List[Worker] **kwargs, ) -> Any: """Runs the given method on all workers.""" coros = [] + if driver_worker is Ellipsis: + driver_worker = self.driver_worker + if workers is Ellipsis: + workers = self.workers + if driver_args is None: driver_args = args if driver_kwargs is None: driver_kwargs = kwargs # Run the driver worker asynchronously. - driver_executor = getattr(self.driver_worker, method) - coros.append(asyncio.get_event_loop().run_in_executor( - None, partial(driver_executor, *driver_args, **driver_kwargs))) + if driver_worker is not None: + driver_executor = getattr(driver_worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(driver_executor, *driver_args, **driver_kwargs))) - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) + # Run the ray workers asynchronously. 
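# Small illustration of the Ellipsis-as-sentinel pattern used by
# _run_workers_async above: `None` is a meaningful value there (the draft
# worker may simply not exist), so `...` marks "argument not given".
# Sketch only, not engine code.
def run(workers=...):
    if workers is Ellipsis:  # caller did not pass anything -> use the default
        workers = ["default-worker"]
    if workers is None:      # caller explicitly said there is no worker
        return []
    return [f"ran {w}" for w in workers]

assert run() == ["ran default-worker"]
assert run(None) == []
assert run(["draft-worker"]) == ["ran draft-worker"]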
+ for worker in workers: + coros.append( + worker.execute_method.remote(method, *args, **kwargs)) - all_outputs = await asyncio.gather(*coros) + all_outputs = await asyncio.gather(*coros) + else: + all_outputs = [None] return all_outputs diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f0de40f54db61..d49163c17ce91 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) -from vllm.lora.request import LoRARequest +import torch + from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs @@ -14,12 +15,14 @@ from vllm.engine.metrics import StatLogger, Stats from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) + SequenceGroupMetadata, SequenceGroupOutput, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - TokenizerGroup) + get_tokenizer, TokenizerGroup) from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method if ray: @@ -73,12 +76,14 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + draft_model_config: Optional[ModelConfig], placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: logger.info( "Initializing an LLM engine with config: " f"model={model_config.model!r}, " + f"draft_model={draft_model_config.model if draft_model_config else None !r}, " f"tokenizer={model_config.tokenizer!r}, " f"tokenizer_mode={model_config.tokenizer_mode}, " f"revision={model_config.revision}, " @@ -103,21 +108,51 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.draft_model_config = draft_model_config self.log_stats = log_stats self._verify_args() + self._verify_vocab() self._init_tokenizer() self.seq_counter = Counter() + self.driver_worker = None + self.driver_dummy_worker = None + self.workers = [] + self.draft_driver_worker = None + self.draft_driver_dummy_worker = None + self.draft_workers = [] + + self.reject_sampler = None + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker + # Create the parallel GPU workers. if self.parallel_config.worker_use_ray: # Disable Ray usage stats collection. 
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") if ray_usage != "1": os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - self._init_workers_ray(placement_group) + + self.driver_worker, self.driver_dummy_worker, self.workers = self._init_workers_ray( + placement_group, Worker, model_config) + if self.draft_model_config is not None: + self.draft_driver_worker, self.draft_driver_dummy_worker, self.draft_workers = self._init_workers_ray( + placement_group, + MultiStepWorker, + model_config=self.draft_model_config) else: - self._init_workers() + self.driver_worker, self.workers = self._init_workers( + Worker, model_config=model_config) + if self.draft_model_config is not None: + self.draft_driver_worker, self.draft_workers = self._init_workers( + MultiStepWorker, model_config=draft_model_config) + + if self.draft_model_config is not None: + self._init_reject_sampler() # Profile the memory usage and initialize the cache. self._init_cache() @@ -137,19 +172,24 @@ def __init__( def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) - def _init_workers(self): + def _init_workers( + self, + worker_cls: Any, + model_config: ModelConfig, + ) -> Tuple[Any, Optional[Any], List[Any]]: # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker + from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker assert self.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") - self.workers: List[Worker] = [] + workers: List[Union[Worker, MultiStepWorker]] = [] distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = Worker( - self.model_config, + driver_worker = worker_cls( + model_config, self.parallel_config, self.scheduler_config, self.device_config, @@ -160,8 +200,19 @@ def _init_workers(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) - self._run_workers("init_model") - self._run_workers("load_model") + self._run_workers("init_model", + driver_worker=driver_worker, + workers=workers) + self._run_workers("load_model", + driver_worker=driver_worker, + workers=workers) + return driver_worker, workers + + def _init_reject_sampler(self): + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + + self.reject_sampler = RejectionSampler(strict_mode=False) + self.reject_sampler.init_gpu_tensors(rank=0) def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( @@ -175,15 +226,20 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): self.tokenizer: TokenizerGroup = TokenizerGroup( self.model_config.tokenizer, **init_kwargs) - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray( + self, + placement_group: "PlacementGroup", + worker_cls: Any, + model_config: ModelConfig, + **ray_remote_kwargs, + ) -> Tuple[RayWorkerVllm, RayWorkerVllm, List[RayWorkerVllm]]: if self.parallel_config.tensor_parallel_size == 1: num_gpus = self.cache_config.gpu_memory_utilization else: num_gpus = 1 - self.driver_dummy_worker: RayWorkerVllm = None - self.workers: List[RayWorkerVllm] = [] + driver_dummy_worker: RayWorkerVllm = None + workers: List[RayWorkerVllm] = [] driver_ip = get_ip() for bundle_id, bundle in enumerate(placement_group.bundle_specs): @@ -199,26 +255,26 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, 
scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + )(RayWorkerVllm).remote(model_config.trust_remote_code) worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: + if worker_ip == driver_ip and driver_dummy_worker is None: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. - self.driver_dummy_worker = worker + driver_dummy_worker = worker else: - self.workers.append(worker) + workers.append(worker) - if self.driver_dummy_worker is None: + if driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "GPU node.") driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + driver_dummy_worker.get_node_and_gpu_ids.remote()) worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + [worker.get_node_and_gpu_ids.remote() for worker in workers]) node_workers = defaultdict(list) node_gpus = defaultdict(list) @@ -234,29 +290,25 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Set CUDA_VISIBLE_DEVICES for the driver. set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + for worker, (node_id, _) in zip(workers, worker_node_and_gpu_ids): worker.set_cuda_visible_devices.remote(node_gpus[node_id]) distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - # Initialize torch distributed process group for the workers. - model_config = copy.deepcopy(self.model_config) + model_config = copy.deepcopy(model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) device_config = copy.deepcopy(self.device_config) for rank, (worker, (node_id, - _)) in enumerate(zip(self.workers, + _)) in enumerate(zip(workers, worker_node_and_gpu_ids), start=1): local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( - lambda rank=rank, local_rank=local_rank: Worker( + lambda rank=rank, local_rank=local_rank: worker_cls( model_config, parallel_config, scheduler_config, @@ -270,7 +322,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( + driver_worker = worker_cls( model_config, parallel_config, scheduler_config, @@ -283,12 +335,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - self._run_workers("init_model", cupy_port=get_open_port()) + self._run_workers("init_model", + cupy_port=get_open_port(), + driver_worker=driver_worker, + workers=workers) self._run_workers( "load_model", max_concurrent_workers=self.parallel_config. 
max_parallel_loading_workers, + driver_worker=driver_worker, + workers=workers, ) + return driver_worker, driver_dummy_worker, workers def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -298,6 +356,27 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) + def _verify_vocab(self) -> None: + if not self.draft_model_config: + return + + target_tokenizer = get_tokenizer(self.model_config.tokenizer) + draft_tokenizer = get_tokenizer(self.draft_model_config.tokenizer) + assert target_tokenizer.vocab_size == draft_tokenizer.vocab_size, ( + f"Target model's vocab size ({target_tokenizer.vocab_size}) " + f"does not match with draft model's vocab size " + f"({draft_tokenizer.vocab_size}).") + + # referred from llama.cpp + SPEC_VOCAB_CHECK_START_TOKEN_ID = 5 + for i in range(SPEC_VOCAB_CHECK_START_TOKEN_ID, + target_tokenizer.vocab_size): + target_token = target_tokenizer.convert_ids_to_tokens(i) + draft_token = draft_tokenizer.convert_ids_to_tokens(i) + assert target_token == draft_token, ( + f"Target model's vocab does not match with draft model's vocab " + f"at index {i}: {target_token} != {draft_token}") + def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache. @@ -319,6 +398,8 @@ def _init_cache(self) -> None: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameters. """ + from vllm.worker.cache_engine import CacheEngine + # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers( "profile_num_available_blocks", @@ -333,6 +414,26 @@ def _init_cache(self) -> None: # operators can be applied to all workers. num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) + + # redistribute kv-cache blocks between target and draft model + cache_block_size = CacheEngine.get_cache_block_size( + self.cache_config.block_size, + self.cache_config.cache_dtype, + self.model_config, + self.parallel_config, + ) + gpu_blocks_memory = num_gpu_blocks * self.cache_config.block_size * cache_block_size + + if self.draft_model_config is not None: + cache_block_size += CacheEngine.get_cache_block_size( + self.cache_config.block_size, + self.cache_config.cache_dtype, + self.draft_model_config, + self.parallel_config, + ) + num_gpu_blocks = gpu_blocks_memory // (self.cache_config.block_size * + cache_block_size) + # FIXME(woosuk): Change to debug log. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") @@ -359,6 +460,17 @@ def _init_cache(self) -> None: # if enforce_eager is False. self._run_workers("warm_up_model") + # Initialize the cache. + self._run_workers("init_cache_engine", + cache_config=self.cache_config, + driver_worker=self.draft_driver_worker, + workers=self.draft_workers) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model", + driver_worker=self.draft_driver_worker, + workers=self.draft_workers) + @classmethod def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -453,6 +565,10 @@ def add_request( prompt_token_ids=prompt_token_ids, lora_request=lora_request) + if self.scheduler_config.parallel_decoding_lookahead > 1 and sampling_params.use_beam_search: + raise NotImplementedError( + "Speculative decoding doesn't support beam search.") + # Create the sequences. 
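# Back-of-the-envelope version of the KV-block redistribution done in
# _init_cache above (illustrative, made-up sizes): once a draft model is
# loaded, every block must also hold the draft model's cache, so the number
# of GPU blocks shrinks proportionally.
block_size = 16
num_gpu_blocks = 1000
target_cache_block_size = 2 * 1024 * 1024   # cache bytes per block, target model (made up)
draft_cache_block_size = 256 * 1024         # cache bytes per block, draft model (made up)

gpu_blocks_memory = num_gpu_blocks * block_size * target_cache_block_size
combined = target_cache_block_size + draft_cache_block_size
new_num_gpu_blocks = gpu_blocks_memory // (block_size * combined)
assert new_num_gpu_blocks == 888  # 1000 blocks * 2 MiB / 2.25 MiB, rounded down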
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) @@ -573,6 +689,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, parent_child_dict[sample.parent_seq_id].append(sample) # List of (child, parent) child_seqs: List[Tuple[Sequence, Sequence]] = [] + child_seqs_accepted: List[int] = [] # Process the child samples for each parent sequence for parent in parent_seqs: @@ -590,19 +707,37 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) + if isinstance(child_sample.output_token, (list, tuple)): + for output_token, logprobs in zip( + child_sample.output_token, child_sample.logprobs): + child.append_token_id(output_token, logprobs) + child_seqs_accepted.append(len(child_sample.output_token)) + else: + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs_accepted.append(1) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) + if isinstance(last_child_sample.output_token, (list, tuple)): + for output_token, logprobs in zip( + last_child_sample.output_token, + last_child_sample.logprobs): + parent.append_token_id(output_token, logprobs) + child_seqs_accepted.append( + len(last_child_sample.output_token)) + else: + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs_accepted.append(1) child_seqs.append((parent, parent)) - for seq, _ in child_seqs: - self._decode_sequence(seq, seq_group.sampling_params) + for (seq, _), accepted in zip(child_seqs, child_seqs_accepted): + self._decode_sequence(seq, + seq_group.sampling_params, + accepted=accepted) self._check_stop(seq, seq_group.sampling_params) # Non-beam search case @@ -756,7 +891,6 @@ def _process_model_outputs( # Log stats. if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs)) - return request_outputs def step(self) -> List[RequestOutput]: @@ -814,7 +948,39 @@ def step(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. 
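The branch in `_process_sequence_group_outputs` above treats `output_token` as either a single id or a list of ids (the speculative case), and records how many ids were actually appended so that incremental detokenization later knows how far to advance. A self-contained sketch of that bookkeeping on plain lists (the helper name is made up for illustration):

from typing import Dict, List, Sequence, Union

TokenOrTokens = Union[int, Sequence[int]]
LogprobsArg = Union[Dict[int, float], Sequence[Dict[int, float]]]

def append_sample(token_ids: List[int],
                  output_token: TokenOrTokens,
                  logprobs: LogprobsArg) -> int:
    """Append one sampling result to `token_ids` and return how many
    tokens were accepted (1 for normal decoding, k for speculative)."""
    if isinstance(output_token, (list, tuple)):
        for tok, _lp in zip(output_token, logprobs):
            token_ids.append(tok)
        return len(output_token)
    token_ids.append(output_token)
    return 1

seq: List[int] = [5, 7]
accepted = append_sample(seq, [11, 13, 17], [{11: 0.0}, {13: 0.0}, {17: 0.0}])
assert seq == [5, 7, 11, 13, 17] and accepted == 3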
- all_outputs = self._run_workers( + + # execute the draft model + if scheduler_outputs.prompt_run: + draft_step_func = 'execute_model' + kwargs = {} + else: + draft_step_func = 'execute_model_multi_step' + kwargs = { + # speculative decoding: sampling (k-1) token from the draft model + "num_steps": + scheduler_outputs.parallel_decoding_lookahead - 1, + } + all_draft_outputs: List[SamplerOutput] = self._run_workers( + draft_step_func, + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + # blocks copy won't happen to draft model: no prefix cache and no beam search + "blocks_to_copy": {}, + **kwargs, + }, + driver_worker=self.draft_driver_worker, + workers=self.draft_workers) + draft_output = all_draft_outputs[0] if len( + all_draft_outputs) > 0 else None + + # add possible draft tokens to the sequences + self._apply_draft_output(seq_group_metadata_list, + scheduler_outputs, draft_output) + + # execute the target model + all_outputs: List[SamplerOutput] = self._run_workers( "execute_model", driver_kwargs={ "seq_group_metadata_list": seq_group_metadata_list, @@ -826,18 +992,189 @@ def step(self) -> List[RequestOutput]: # Only the driver worker returns the sampling results. output = all_outputs[0] + + # apply reject sampling + self._apply_reject_sampling(seq_group_metadata_list, + scheduler_outputs, draft_output, + output) else: output = [] return self._process_model_outputs(output, scheduler_outputs) + def _apply_draft_output( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + draft_output: List[SamplerOutput], + ): + if draft_output is None or \ + scheduler_outputs.prompt_run or \ + scheduler_outputs.parallel_decoding_lookahead == 1: + # clear the candidate token ids from previous step + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for _, seq_data in seq_group_metadata.seq_data.items(): + seq_data.candidate_token_ids = [] + return + + batch_size = len(seq_group_metadata_list) + num_draft_tokens = scheduler_outputs.parallel_decoding_lookahead - 1 + draft_token_ids = [[-1] * num_draft_tokens for _ in range(batch_size)] + + for j, seq_group_outputs in enumerate(draft_output): + for i, seq_group_output in enumerate(seq_group_outputs): + assert len(seq_group_output.samples + ) == 1, "no beam search when speculative decoding" + if isinstance(seq_group_output.samples[0].output_token, list): + draft_token_ids[i][j] = seq_group_output.samples[ + 0].output_token[-1] + else: + draft_token_ids[i][j] = seq_group_output.samples[ + 0].output_token + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for _, seq_data in seq_group_metadata.seq_data.items(): + seq_data.candidate_token_ids = draft_token_ids[i] + + def _apply_reject_sampling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + draft_output: List[SamplerOutput], + target_output: SamplerOutput, + ): + if draft_output is None or \ + scheduler_outputs.prompt_run or \ + scheduler_outputs.parallel_decoding_lookahead == 1: + return + + batch_size = len(seq_group_metadata_list) + num_draft_tokens = scheduler_outputs.parallel_decoding_lookahead - 1 + + target_probs = torch.empty(batch_size, + num_draft_tokens, + self.model_config.get_vocab_size(), + dtype=torch.float, + device='cuda') + bonus_token_ids = [[] for _ in range(batch_size)] + draft_probs = 
torch.empty(batch_size, + num_draft_tokens, + self.model_config.get_vocab_size(), + dtype=torch.float, + device='cuda') + draft_token_ids = [[-1] * num_draft_tokens for _ in range(batch_size)] + + for j, seq_group_outputs in enumerate(draft_output): + for i, seq_group_output in enumerate(seq_group_outputs): + assert len(seq_group_output.samples + ) == 1, "no beam search when speculative decoding" + if isinstance(seq_group_output.samples[0].output_token, list): + draft_token_ids[i][j] = seq_group_output.samples[ + 0].output_token[-1] + draft_probs[i, + j, :] = seq_group_output.decoding_probs[-1, :] + else: + draft_token_ids[i][j] = seq_group_output.samples[ + 0].output_token + draft_probs[i, j, :] = seq_group_output.decoding_probs + for i, seq_group_output in enumerate(target_output): + assert len(seq_group_output.samples + ) == 1, "no beam search when speculative decoding" + target_probs[i, :, :] = seq_group_output.decoding_probs[:-1, :] + bonus_token_ids[i] = seq_group_output.samples[0].output_token[-1:] + + bonus_token_ids = torch.tensor(bonus_token_ids, + dtype=torch.long, + device='cuda') + draft_token_ids = torch.tensor(draft_token_ids, + dtype=torch.long, + device='cuda') + + # apply the reject sampling + output_token_ids = self.reject_sampler(target_probs=target_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=draft_probs, + draft_token_ids=draft_token_ids) + output_token_ids = output_token_ids.tolist() + num_batch_tokens = 0 + for i, (seq_group_metadata, seq_group, sampled_token_ids, seq_group_output) in \ + enumerate(zip(seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups, + output_token_ids, + target_output)): + ignore_eos = seq_group.sampling_params.ignore_eos + eos_token_id = self.get_tokenizer_for_seq(seq_group).eos_token_id + target_output_tokens = seq_group_output.samples[0].output_token + + output_tokens, has_eos = [], False + # Apply reject sampling when random sampling is used, otherwise + # compare if the draft and target tokens are the same. + # + # see also: https://github.com/huggingface/transformers/blob/45244940725ec1b3e4c390b74dbafe65b298acca/src/transformers/generation/utils.py#L4566-L4597 + if seq_group.sampling_params.sampling_type == SamplingType.GREEDY: + for draft_token_id, token_id in zip(draft_token_ids[i], + target_output_tokens): + if draft_token_id != token_id: + break + output_tokens.append(token_id) + if not ignore_eos and token_id == eos_token_id: + has_eos = True + break + if not has_eos: + if len(output_tokens) == len(draft_token_ids[i]): + output_tokens.append(target_output_tokens[-1]) + else: + output_tokens.append( + target_output_tokens[len(output_tokens)]) + else: + for token_id in sampled_token_ids: + if token_id == -1: + break + output_tokens.append(token_id) + if not ignore_eos and token_id == eos_token_id: + has_eos = True + break + bonus = None if has_eos else output_tokens[-1] + + # When all tokens from the draft model are accepted, there + # would be two tokens kv cache are missed from the draft + # model: the last draft's output and the bonus token. + # + # In such case we add the bonus token as the "candidate" token + # for draft model to run the next decoding pass for computing + # the kv cache. + if len(output_tokens) == len( + target_output_tokens) and bonus is not None: + # set candidate token ids (there should be only one sequence in each sequence + # group if beam-search is not used). 
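When sampling is greedy, the code above does not go through the probabilistic rejection sampler: it keeps draft tokens while they match the target model's outputs, emits the target token at the first mismatch (or the bonus token if every draft token matched), and stops early on EOS. A standalone sketch of that comparison, with hypothetical argument names:

from typing import List, Optional

def greedy_accept(draft_tokens: List[int],
                  target_tokens: List[int],
                  eos_token_id: Optional[int] = None,
                  ignore_eos: bool = False) -> List[int]:
    """Tokens emitted for one speculative step under greedy sampling.
    `target_tokens` has len(draft_tokens) + 1 entries: one verification
    token per draft position plus the bonus token."""
    assert len(target_tokens) == len(draft_tokens) + 1
    emitted: List[int] = []
    for draft_tok, target_tok in zip(draft_tokens, target_tokens):
        if draft_tok != target_tok:
            break
        emitted.append(target_tok)
        if not ignore_eos and target_tok == eos_token_id:
            return emitted  # stop at EOS, no bonus token
    if len(emitted) == len(draft_tokens):
        emitted.append(target_tokens[-1])  # all accepted: take the bonus token
    else:
        emitted.append(target_tokens[len(emitted)])  # take target's token at the mismatch
    return emitted

# Draft guessed [3, 4, 9]; the target agrees on 3 and 4, then disagrees.
assert greedy_accept([3, 4, 9], [3, 4, 5, 6]) == [3, 4, 5]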
+ for _, seq_data in seq_group_metadata.seq_data.items(): + seq_data.candidate_token_ids = [bonus] + break + seq_group_metadata.parallel_decoding_lookahead = 2 + else: + # clear the candidate token ids in the sequence data + for _, seq_data in seq_group_metadata.seq_data.items(): + seq_data.candidate_token_ids = [] + break + seq_group_metadata.parallel_decoding_lookahead = 1 + + num_batch_tokens += len(output_tokens) + seq_group_output.samples[0].output_token = output_tokens + scheduler_outputs.num_batched_tokens = num_batch_tokens + + def reset_system_stats(self): + now = time.monotonic() + self.last_logging_time = now + def do_log_stats(self) -> None: """Forced log when no requests active.""" if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs=None)) - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: + def _get_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs], + ) -> Stats: """Get Stats to be Logged to Prometheus.""" now = time.monotonic() @@ -888,11 +1225,27 @@ def _get_stats(self, time_to_first_tokens = time_last_iters if prompt_run else [] time_per_output_tokens = [] if prompt_run else time_last_iters + if self.reject_sampler is not None: + num_draft_tokens = self.reject_sampler.num_draft_tokens + num_bouns_tokens = num_draft_tokens // self.scheduler_config.parallel_decoding_lookahead + num_emitted_tokens = self.reject_sampler.num_emitted_tokens + reject_sampler_accept_rate = ( + self.reject_sampler.num_accepted_tokens / + (self.reject_sampler.num_draft_tokens + num_bouns_tokens) * + 100) + else: + num_draft_tokens = 0 + num_emitted_tokens = 0 + reject_sampler_accept_rate = 0 + return Stats( now=now, num_running=num_running, num_swapped=num_swapped, num_waiting=num_waiting, + num_draft_tokens=num_draft_tokens, + num_emitted_tokens=num_emitted_tokens, + reject_sampler_accept_rate=reject_sampler_accept_rate, gpu_cache_usage=gpu_cache_usage, cpu_cache_usage=cpu_cache_usage, num_prompt_tokens=num_prompt_tokens, @@ -902,8 +1255,20 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: + def _decode_sequence(self, + seq: Sequence, + prms: SamplingParams, + accepted: int = 1) -> None: """Decodes the new token for a sequence.""" + + # statsh current detokenization progress + seq.detokenize_progress = ( + seq.prefix_offset, + seq.read_offset, + len(seq.tokens) if seq.tokens is not None else 0, + len(seq.output_text), + ) + (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( self.get_tokenizer_for_seq(seq), @@ -913,6 +1278,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: read_offset=seq.read_offset, skip_special_tokens=prms.skip_special_tokens, spaces_between_special_tokens=prms.spaces_between_special_tokens, + parallel_decoding_accepted=accepted, ) if seq.tokens is None: seq.tokens = new_tokens @@ -938,12 +1304,13 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: + if seq.get_len() >= self.scheduler_config.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return # Check if the sequence has reached max_tokens. 
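The accept rate logged by `_get_stats` above adds an approximate bonus-token count to the denominator, one extra token per `parallel_decoding_lookahead` drafted tokens. A plain-Python sketch of that formula (variable names are descriptive, not vLLM's):

def speculative_accept_rate(num_accepted_tokens: int,
                            num_draft_tokens: int,
                            parallel_decoding_lookahead: int) -> float:
    """Percentage of draft-plus-bonus tokens that were accepted."""
    num_bonus_tokens = num_draft_tokens // parallel_decoding_lookahead
    return num_accepted_tokens / (num_draft_tokens + num_bonus_tokens) * 100

# 400 drafted tokens at lookahead 4 contribute 100 bonus tokens to the denominator.
assert speculative_accept_rate(300, 400, 4) == 60.0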
- if seq.get_output_len() == sampling_params.max_tokens: + if sampling_params.max_tokens and seq.get_output_len( + ) >= sampling_params.max_tokens: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return @@ -985,11 +1352,20 @@ def _run_workers( driver_args: Optional[List[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, + # use `Ellipsis` as the default argument value as `None` will be + # treat as given worker (usually the draft worker) is not available. + driver_worker: List[Any] = Ellipsis, # List[Worker] + workers: Any = Ellipsis, # List[Worker] use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers.""" + if driver_worker is Ellipsis: + driver_worker = self.driver_worker + if workers is Ellipsis: + workers = self.workers + if max_concurrent_workers: raise NotImplementedError( "max_concurrent_workers is not supported yet.") @@ -1002,7 +1378,7 @@ def _run_workers( # Start the ray workers first. ray_worker_outputs = [ worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers + for worker in workers ] if driver_args is None: @@ -1011,11 +1387,15 @@ def _run_workers( driver_kwargs = kwargs # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) + if driver_worker is not None: + driver_worker_output = getattr(driver_worker, + method)(*driver_args, + **driver_kwargs) + else: + driver_worker_output = None # Get the results of the ray workers. - if self.workers: + if workers: if use_ray_compiled_dag: try: ray_worker_outputs = [ diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index e613b9f551b2f..72a61c41785fd 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -73,6 +73,11 @@ class Stats: num_running: int num_waiting: int num_swapped: int + + num_draft_tokens: int + num_emitted_tokens: int + reject_sampler_accept_rate: float + gpu_cache_usage: float cpu_cache_usage: float @@ -157,6 +162,15 @@ def log(self, stats: Stats) -> None: prompt_throughput=prompt_throughput, generation_throughput=generation_throughput) + if stats.num_draft_tokens: + specualtive_decoding_message = ( + f'Drafted: {stats.num_draft_tokens}, ' + f'Emitted: {stats.num_emitted_tokens}, ' + f'Accepted rate: {stats.reject_sampler_accept_rate:.2f}%, ' + ) + else: + specualtive_decoding_message = '' + # Log to stdout. logger.info( f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " @@ -164,6 +178,7 @@ def log(self, stats: Stats) -> None: f"Running: {stats.num_running} reqs, " f"Swapped: {stats.num_swapped} reqs, " f"Pending: {stats.num_waiting} reqs, " + f"{specualtive_decoding_message}" f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, " f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb6..e72327f3f15a6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -26,6 +26,9 @@ class LLM: Args: model: The name or path of a HuggingFace Transformers model. + draft_model: The name or path of a HuggingFace Transformers model for + draft generation. If None, we use the same model for draft and + refine generations. tokenizer: The name or path of a HuggingFace Transformers tokenizer. tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. 
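The `Ellipsis` default in `_run_workers` above is a sentinel: `None` already carries meaning there (the given worker, usually the draft worker, is not available and should be skipped), so a different object is needed to express "argument not supplied, fall back to self". A minimal illustration of the same pattern outside of the engine (class and method names are made up):

from typing import Optional

class Pool:
    def __init__(self) -> None:
        self.default_worker: Optional[str] = "target-driver"

    def run(self, method: str, worker: object = Ellipsis) -> Optional[str]:
        # Ellipsis -> caller did not pass anything: use the default worker.
        # None     -> caller explicitly said "no worker": skip execution.
        if worker is Ellipsis:
            worker = self.default_worker
        if worker is None:
            return None
        return f"{worker}.{method}()"

pool = Pool()
assert pool.run("load_model") == "target-driver.load_model()"
assert pool.run("load_model", worker=None) is None
assert pool.run("load_model", worker="draft-driver") == "draft-driver.load_model()"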
@@ -70,6 +73,7 @@ class LLM: def __init__( self, model: str, + draft_model: Optional[str] = None, tokenizer: Optional[str] = None, tokenizer_mode: str = "auto", trust_remote_code: bool = False, @@ -90,6 +94,7 @@ def __init__( kwargs["disable_log_stats"] = True engine_args = EngineArgs( model=model, + draft_model=draft_model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, trust_remote_code=trust_remote_code, diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f8..77a6b9a007c5a 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -27,6 +27,7 @@ def __init__( block_tables: Optional[torch.Tensor], use_cuda_graph: bool, kv_cache_dtype: str, + use_flash_attn: bool = False, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -38,6 +39,7 @@ def __init__( self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph self.kv_cache_dtype = kv_cache_dtype + self.use_flash_attn = use_flash_attn # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. @@ -48,7 +50,10 @@ def __repr__(self) -> str: f"is_prompt={self.is_prompt}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " + f"prompt_lens={self.prompt_lens}, " + f"start_loc={self.start_loc}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " f"use_cuda_graph={self.use_cuda_graph}, " - f"kv_cache_dtype={self.kv_cache_dtype})") + f"kv_cache_dtype={self.kv_cache_dtype}, " + f"use_flash_attn={self.use_flash_attn})") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 0622a54db1bc0..5f1c372abec64 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -7,6 +7,10 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) +try: + from flash_attn import flash_attn_func, flash_attn_with_kvcache +except ImportError: + flash_attn_func, flash_attn_with_kvcache = None, None from vllm._C import ops from vllm._C import cache_ops @@ -110,8 +114,10 @@ def forward( value: shape = [batch_size, seq_len, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] + or: [num_blocks, block_size, num_kv_heads, head_size] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] + or: [num_blocks, block_size, num_kv_heads, head_size] input_metadata: metadata for the inputs. Returns: shape = [batch_size, seq_len, num_heads * head_size] @@ -127,35 +133,73 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) + if input_metadata.use_flash_attn: + # Update kv-cache using tensor indexing. We don't use the kernel + # `flash_attn_with_kvcache` for kv-cache updating as it submitted + # many small kernels for each key/value and is slow. 
+ flatten_slot_mapping = input_metadata.slot_mapping.flatten() + slot_block_index = flatten_slot_mapping // key_cache.shape[1] + slot_block_offset = flatten_slot_mapping % key_cache.shape[1] + key_cache[slot_block_index, slot_block_offset, :, :] = key + value_cache[slot_block_index, slot_block_offset, :, :] = value + else: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) - if input_metadata.is_prompt: - # Prompt run. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) + if input_metadata.is_prompt and input_metadata.use_flash_attn: # normal attention + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) if (key_cache is None or value_cache is None - or input_metadata.block_tables.numel() == 0): + or not input_metadata.context_lens.any()): + output = flash_attn_func( + q=query, + k=key, + v=value, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ) + else: + output = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=input_metadata.context_lens + seq_len, + block_table=input_metadata.block_tables, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ) + elif input_metadata.is_prompt and not input_metadata.use_flash_attn: + # normal attention + if (key_cache is None or value_cache is None + or not input_metadata.context_lens.any()): + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -226,8 +270,12 @@ def forward( # Decoding run. 
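With flash-attn the caches use a [num_blocks, block_size, num_kv_heads, head_size] layout, so a flat slot id decomposes into (slot // block_size, slot % block_size) and new keys and values can be written with one advanced-indexing assignment instead of the paged reshape_and_cache kernel. A small sketch with toy shapes (requires torch; the shapes and tensor names here are illustrative):

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros_like(key_cache)

# Three new tokens whose flat slot ids were assigned by the block manager.
slot_mapping = torch.tensor([5, 6, 30])          # slot = block * block_size + offset
key = torch.randn(3, num_kv_heads, head_size)
value = torch.randn(3, num_kv_heads, head_size)

block_index = slot_mapping // key_cache.shape[1]   # which block each token lives in
block_offset = slot_mapping % key_cache.shape[1]   # position inside that block
key_cache[block_index, block_offset] = key
value_cache[block_index, block_offset] = value

assert torch.equal(key_cache[1, 2], key[1])        # slot 6  -> block 1, offset 2
assert torch.equal(value_cache[7, 2], value[2])    # slot 30 -> block 7, offset 2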
output = _paged_attention( query, + key, + value, key_cache, value_cache, + batch_size, + seq_len, input_metadata, self.num_kv_heads, self.scale, @@ -274,8 +322,12 @@ def _make_alibi_bias( def _paged_attention( query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, + num_seqs: int, + seq_len: int, input_metadata: InputMetadata, num_kv_heads: int, scale: float, @@ -284,7 +336,7 @@ def _paged_attention( output = torch.empty_like(query) block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape + _, num_heads, head_size = query.shape max_num_partitions = ( (input_metadata.max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) @@ -297,7 +349,53 @@ def _paged_attention( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: + + # NOTE: in `_prepare_prompt` and `_prepare_decode` is filled in different + # manner (which may needs to be fixed in the future): in the former + # `context_lens` is the length of contexts whose kv-cache has been stored + # (e.g., with prefix cache), however, in the later `context_lens` is the + # length of current attention context (includes the token whose kv-cache + # will be filled in this round). + # + # The kernel `flash_attn_with_kvcache` expects `cache_seqlens` to be the + # length of the context whose kv-cache has been stored, i.e., the value of + # `context_lens - seq_len` for decoding. + # + # The `context_attention_fwd` kernel expects the same samatics as the + # `flash_attn_with_kvcache` kernel. + # + # In the contrast, both `paged_attention_v1` and `paged_attention_v2` expects + # the `context_lens` to be the length of the current attention context. + + if input_metadata.use_flash_attn: + # see also: https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342 + output = flash_attn_with_kvcache( + q=query.reshape(num_seqs, -1, *query.shape[1:]), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=input_metadata.context_lens, + block_table=input_metadata.block_tables, + softmax_scale=scale, + causal=True, + alibi_slopes=alibi_slopes, + ) + elif seq_len > 1: + # prefix-enabled attention + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens - seq_len, + input_metadata.max_seq_len, + alibi_slopes, + ) + elif use_v1: # Run PagedAttention V1. ops.paged_attention_v1( output, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 3e1cfc783b8ef..a2589fb2c195f 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -30,6 +30,14 @@ def __init__(self, strict_mode: bool = False): # value in a variable for readability. self._num_bonus_tokens = 1 + # NOTE: the `num_accepted_tokens` is not the tokens actually "accepted" + # from the outputs of the draft model, but is used to measure the quality + # of the draft model. + # + # For the top-level metric on how many tokens are emitted by speculative + # decoding, one should use `num_emitted_tokens`. 
+ # + # See also: https://github.com/vllm-project/vllm/pull/2658 self.num_accepted_tokens: Optional[torch.Tensor] = None self.num_emitted_tokens: Optional[torch.Tensor] = None self.num_draft_tokens: int = 0 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 884d84387e505..7f5920b02db6b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -110,7 +110,7 @@ def forward( prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + prompt_logprobs, sample_logprobs, probs) def _prune_hidden_states( @@ -237,47 +237,71 @@ def _apply_min_p( def _greedy_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], samples: torch.Tensor, + subquery_lens: Optional[List[int]] = None, ) -> List[Tuple[List[int], List[int]]]: - samples = samples.tolist() + samples = samples.cpu() sample_idx = 0 results = [] - for seq_group in selected_seq_groups: + for i, (seq_group, + is_prompt) in enumerate(zip(selected_seq_groups, is_prompts)): seq_ids, _ = seq_group num_parent_seqs = len(seq_ids) assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples[sample_idx]] + if is_prompt or subquery_lens is None: + next_token_ids = [samples[sample_idx].tolist()] + parent_ids = list(range(num_parent_seqs)) + sample_idx += 1 + else: + next_token_ids = samples[sample_idx:sample_idx + + subquery_lens[i]].tolist() + sample_idx += subquery_lens[i] + parent_ids = [ + parent_id for parent_id in range(num_parent_seqs) + for _ in range(subquery_lens[i]) + ] results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs + assert sample_idx == samples.size(0) return results def _random_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], is_prompts: List[bool], - random_samples: torch.Tensor, + samples: torch.Tensor, + subquery_lens: Optional[List[int]] = None, ) -> List[Tuple[List[int], List[int]]]: # Find the maximum best_of value of the prompt phase requests. - random_samples = random_samples.cpu() + samples = samples.cpu() sample_idx = 0 results = [] for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): seq_ids, sampling_params = seq_group - num_parent_seqs = len(seq_ids) - if is_prompt: + if is_prompt or subquery_lens is None: # Prompt phase. - parent_ids = [0] * sampling_params.best_of - next_token_ids = random_samples[ + next_token_ids = samples[ sample_idx, :sampling_params.best_of].tolist() + parent_ids = [0] + sample_idx += 1 else: # Generation phase. 
- parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() + next_token_ids, parent_ids = [], [] + for idx, _ in enumerate(seq_ids): + subquery_len = subquery_lens[idx] + next_token_ids.extend(samples[ + sample_idx:sample_idx + + subquery_len, :sampling_params.best_of].flatten().tolist()) + parent_ids.extend([idx] * subquery_lens[idx]) + sample_idx += subquery_len + idx += 1 + parent_ids = [ + parent_id for parent_id in parent_ids + for _ in range(sampling_params.best_of) + ] results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs + assert sample_idx == samples.size(0) return results @@ -421,10 +445,13 @@ def _sample( seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ sampling_type] if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) + sample_results = _greedy_sample(seq_groups, is_prompts, + greedy_samples, + sampling_metadata.subquery_lens) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample(seq_groups, is_prompts, - multinomial_samples[sampling_type]) + multinomial_samples[sampling_type], + sampling_metadata.subquery_lens) elif sampling_type == SamplingType.BEAM: sample_results = _beam_search_sample(seq_groups, is_prompts, sampling_metadata.seq_data, @@ -448,31 +475,63 @@ def _get_logprobs( batched_logprobs_query_seq_indices: List[int] = [] batched_logprobs_query_token_indices: List[int] = [] largest_num_logprobs = 0 - sample_idx = 0 + sample_idx, subquery_len_index = 0, 0 for i, (seq_group, sample_result) in enumerate( zip(sampling_metadata.seq_groups, sample_results)): seq_ids, sampling_params = seq_group next_token_ids, parent_ids = sample_result num_parent_seqs = len(seq_ids) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - batched_logprobs_query_seq_indices.extend( - sample_idx + j for j in range(prompt_len - 1)) - batched_logprobs_query_token_indices.extend( - token_id for token_id in prompt_tokens[1:]) - sample_idx += prompt_len - 1 + num_subquery_samples = 0 + if i < sampling_metadata.num_prompts: + if sampling_params.prompt_logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.prompt_logprobs) + subquery_seqs = [ + j for j in range(sampling_metadata.prompt_lens[i] - 1) + ] + parent_ids + subquery_tokens = sampling_metadata.seq_data[ + seq_ids[0]].prompt_token_ids[1:] + next_token_ids + num_subquery_samples += sampling_metadata.prompt_lens[ + i] - 1 + num_parent_seqs + else: + subquery_seqs = parent_ids + subquery_tokens = next_token_ids + num_subquery_samples += num_parent_seqs + else: + if sampling_metadata.subquery_lens is None: + subquery_seqs = parent_ids + subquery_tokens = next_token_ids + num_subquery_samples += num_parent_seqs + else: + # When using beam search, the length of sample result will + # be twice of sequence length, see also: _beam_search_sample. + beam_factor = len(next_token_ids) // sum( + sampling_metadata. 
+ subquery_lens[subquery_len_index:subquery_len_index + + num_parent_seqs]) + subquery_seqs, subquery_tokens = [], [] + next_tokens_offset = 0 + for i in range(num_parent_seqs): + subquery_len = sampling_metadata.subquery_lens[ + subquery_len_index + i] + subquery_seqs.extend([ + k for k in parent_ids[i * beam_factor:(i + 1) * + beam_factor] + for _ in range(subquery_len) + ]) + subquery_tokens.extend( + next_token_ids[next_tokens_offset:next_tokens_offset + + subquery_len * beam_factor]) + next_tokens_offset += subquery_len * beam_factor + num_subquery_samples += subquery_len + subquery_len_index += num_parent_seqs batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) + [sample_idx + seq for seq in subquery_seqs]) + batched_logprobs_query_token_indices.extend(subquery_tokens) if sampling_params.logprobs is not None: largest_num_logprobs = max(largest_num_logprobs, sampling_params.logprobs) - sample_idx += num_parent_seqs + sample_idx += num_subquery_samples assert sample_idx == logprobs.size(0) # Batched query for logprobs of selected token @@ -507,7 +566,6 @@ def _get_logprobs( if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): num_logprobs = sampling_params.prompt_logprobs - prompt_len = sampling_metadata.prompt_lens[i] prompt_tokens = sampling_metadata.seq_data[ seq_ids[0]].prompt_token_ids group_prompt_logprobs: PromptLogprobs = [None] @@ -557,20 +615,39 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + decoding_probs: torch.Tensor, ) -> SamplerOutput: sampler_output = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): + sample_idx = 0 + decoding_probs = decoding_probs.reshape(-1, decoding_probs.shape[-1]) + for i, (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in enumerate( + zip(sampling_metadata.seq_groups, sample_results, + prompt_logprobs, sample_logprobs)): seq_ids, _ = seq_group next_token_ids, parent_ids = sample_result seq_outputs = [] + + # return multiple token ids as output for parallel decoding + if i >= sampling_metadata.num_prompts and \ + sampling_metadata.subquery_lens is not None: + subquery_len = sampling_metadata.subquery_lens[i] + else: + subquery_len = 1 + group_decoding_probs = decoding_probs[sample_idx:sample_idx + + subquery_len] + sample_idx += subquery_len + + if subquery_len > 1: + next_token_ids, group_sample_logprobs = \ + [next_token_ids], [group_sample_logprobs] + for parent_id, next_token_id, logprobs in zip(parent_ids, next_token_ids, group_sample_logprobs): seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + SequenceGroupOutput(seq_outputs, group_prompt_logprobs, + group_decoding_probs)) return sampler_output diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index a1a2ab0c4805c..70f09224f1cf6 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -45,6 +45,7 @@ def _fwd_kernel( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: 
tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -53,6 +54,8 @@ def _fwd_kernel( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) @@ -85,13 +88,14 @@ def _fwd_kernel( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -131,9 +135,9 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -232,6 +236,7 @@ def _fwd_kernel_flash_attn_v2( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -240,6 +245,8 @@ def _fwd_kernel_flash_attn_v2( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) @@ -272,13 +279,14 @@ def _fwd_kernel_flash_attn_v2( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -317,9 +325,9 @@ def _fwd_kernel_flash_attn_v2( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -420,6 +428,7 @@ def _fwd_kernel_alibi( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -429,6 +438,8 @@ def _fwd_kernel_alibi( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 @@ -468,13 +479,14 @@ def 
_fwd_kernel_alibi( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -522,9 +534,9 @@ def _fwd_kernel_alibi( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -628,6 +640,7 @@ def context_attention_fwd(q, sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, @@ -674,6 +687,7 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -721,6 +735,7 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d0ffeecd2d74d..a3d0fcdcfc0fa 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -30,6 +30,7 @@ def __init__( seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], seq_data: Optional[Dict[int, SequenceData]], prompt_lens: Optional[List[int]], + subquery_lens: Optional[List[int]], selected_token_indices: torch.Tensor, categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], generators: Optional[List[torch.Generator]] = None, @@ -38,6 +39,7 @@ def __init__( self.seq_groups = seq_groups self.seq_data = seq_data self.prompt_lens = prompt_lens + self.subquery_lens = subquery_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices self.generators = generators @@ -51,8 +53,9 @@ def __repr__(self) -> str: f"seq_groups={self.seq_groups}, " f"seq_data={self.seq_data}, " f"prompt_lens={self.prompt_lens}, " + f"subquery_lens={self.subquery_lens}, " f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices}), " + f"categorized_sample_indices={self.categorized_sample_indices}, " f"perform_sampling={self.perform_sampling})") @@ -87,6 +90,8 @@ def from_sampling_metadata( do_penalties = False do_top_p_top_k = False do_min_p = False + + subquery_lens_index = 0 for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -112,30 +117,38 @@ def from_sampling_metadata( or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is 
not None): - # For tokens in the prompt that we only need to get their logprobs - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - prompt_tokens.extend([] for _ in range(prompt_len - 1)) - output_tokens.extend([] for _ in range(prompt_len - 1)) - for seq_id in seq_ids: + if i < sampling_metadata.num_prompts: + if sampling_params.prompt_logprobs is not None: + # For tokens in the prompt that we only need to get their logprobs + subquery_lens = [sampling_metadata.prompt_lens[i] + ] * len(seq_ids) + else: + subquery_lens = [1] * len(seq_ids) + else: + if sampling_metadata.subquery_lens is None: + subquery_lens = [1] * len(seq_ids) + else: + subquery_lens = sampling_metadata.subquery_lens[ + subquery_lens_index:subquery_lens_index + len(seq_ids)] + # move to the next set of subquery_lens entries + subquery_lens_index += len(seq_ids) + + total_subquery_lens = sum(subquery_lens) + temperatures += [temperature] * total_subquery_lens + top_ps += [top_p] * total_subquery_lens + top_ks += [top_k] * total_subquery_lens + min_ps += [min_p] * total_subquery_lens + presence_penalties += [0] * total_subquery_lens + frequency_penalties += [0] * total_subquery_lens + repetition_penalties += [1] * total_subquery_lens + + for seq_id, subquery_len in zip(seq_ids, subquery_lens): + prompt_tokens.extend([] for _ in range(subquery_len - 1)) + output_tokens.extend([] for _ in range(subquery_len - 1)) + seq_data = sampling_metadata.seq_data[seq_id] prompt_tokens.append(seq_data.prompt_token_ids) output_tokens.append(seq_data.output_token_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..5bc481d44f858 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,10 +1,13 @@ """Sequence and its related classes.""" import copy -import enum from dataclasses import dataclass +import enum from typing import Dict, List, Optional, Union +import torch + from vllm.block import LogicalTokenBlock +from vllm.block import _BLANK_TOKEN_ID from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams from vllm.lora.request import LoRARequest @@ -87,6 +90,7 @@ def __init__( ) -> None: self.prompt_token_ids = prompt_token_ids self.output_token_ids: List[int] = [] + self.candidate_token_ids: List[int] = [] self.cumulative_logprob = 0.0 def append_token_id(self, token_id: int, logprob: float) -> None: @@ -110,10 +114,14 @@ def get_last_token_id(self) -> int: return self.prompt_token_ids[-1] return self.output_token_ids[-1] + def get_candidate_token_ids(self) -> List[int]: + return self.candidate_token_ids + def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self.prompt_token_ids}, " f"output_token_ids={self.output_token_ids}, " + f"candidate_token_ids={self.candidate_token_ids}, " f"cumulative_logprob={self.cumulative_logprob})") @@ -168,32 +176,40 @@ def _append_logical_block(self) -> 
None: ) self.logical_token_blocks.append(block) - def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: + def _append_tokens_to_blocks(self, + token_ids: List[int], + reserved: bool = False) -> None: cursor = 0 while cursor < len(token_ids): if not self.logical_token_blocks: self._append_logical_block() last_block = self.logical_token_blocks[-1] - if last_block.is_full(): - self._append_logical_block() - last_block = self.logical_token_blocks[-1] - num_empty_slots = last_block.get_num_empty_slots() - last_block.append_tokens(token_ids[cursor:cursor + - num_empty_slots]) + if not reserved: + last_block.append_tokens(token_ids[cursor:cursor + + num_empty_slots]) cursor += num_empty_slots + if cursor < len(token_ids): + self._append_logical_block() def append_token_id( self, token_id: int, logprobs: Dict[int, float], ) -> None: - assert token_id in logprobs + # When speculative decoding, the token may not be in the logprobs from sampler. + # assert token_id in logprobs + if token_id not in logprobs: + logprobs[token_id] = 0.0 self._append_tokens_to_blocks([token_id]) self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id]) + def reserve_logical_blocks(self, reserve: Optional[int] = 0): + self._append_tokens_to_blocks([_BLANK_TOKEN_ID] * reserve, + reserved=True) + def get_len(self) -> int: return self.data.get_len() @@ -209,6 +225,9 @@ def get_token_ids(self) -> List[int]: def get_last_token_id(self) -> int: return self.data.get_last_token_id() + def get_candidate_token_ids(self) -> List[int]: + return self.data.get_candidate_token_ids() + def get_output_token_ids(self) -> List[int]: return self.data.output_token_ids @@ -421,6 +440,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, prefix: Optional[Prefix] = None, state: Optional[SequenceGroupState] = None, + parallel_decoding_lookahead: Optional[int] = 1, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -430,11 +450,23 @@ def __init__( self.lora_request = lora_request self.prefix = prefix self.state = SequenceGroupState() if state is None else state + self.parallel_decoding_lookahead = parallel_decoding_lookahead @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def __repr__(self) -> str: + return ( + f"SequenceGroupMetadata(request_id={self.request_id}, " + f"is_prompt={self.is_prompt}, " + f"seq_data={self.seq_data}, " + f"sampling_params={self.sampling_params}, " + f"block_tables={self.block_tables}, " + f"lora_request={self.lora_request}, " + f"prefix={self.prefix}, " + f"parallel_decoding_lookahead={self.parallel_decoding_lookahead})") + class SequenceOutput: """The model output associated with a sequence. 
@@ -477,9 +509,11 @@ def __init__( self, samples: List[SequenceOutput], prompt_logprobs: Optional[PromptLogprobs], + decoding_probs: Optional[torch.Tensor] = None, ) -> None: self.samples = samples self.prompt_logprobs = prompt_logprobs + self.decoding_probs = decoding_probs def __repr__(self) -> str: return (f"SequenceGroupOutput(samples={self.samples}, " diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 6edc225cdfc80..4a6bd961c3186 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -190,8 +190,9 @@ def detokenize_incrementally( read_offset: int = 0, skip_special_tokens: bool = False, spaces_between_special_tokens: bool = True, + parallel_decoding_accepted: int = 1, ) -> Tuple[List[str], str, int, int]: - new_token_id = all_input_ids[-1] + new_token_ids = all_input_ids[-parallel_decoding_accepted:] # This is the first iteration for this sequence if prev_tokens is None: new_tokens = tokenizer.convert_ids_to_tokens( @@ -202,14 +203,15 @@ def detokenize_incrementally( # Subtract 1 extra to account for the generated token. prefix_offset = max(len(output_tokens) - 6, 0) # If the first new token is a special token, we can't skip 1 extra token - if skip_special_tokens and new_token_id in tokenizer.all_special_ids: + if skip_special_tokens and new_token_ids[ + -1] in tokenizer.all_special_ids: read_offset = max(len(output_tokens), 0) else: read_offset = max(len(output_tokens) - 1, 0) else: - # Put new_token_id in a list so skip_special_tokens is respected + # Put new_token_ids in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) + new_token_ids, skip_special_tokens=skip_special_tokens) output_tokens = prev_tokens + new_tokens # The prefix text is necessary only to defeat cleanup algorithms in diff --git a/vllm/utils.py b/vllm/utils.py index 6206879929061..e46206de152dc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -226,6 +226,7 @@ def create_kv_caches_with_random( model_dtype: Optional[Union[str, torch.dtype]] = None, seed: Optional[int] = 0, device: Optional[str] = "cuda", + use_flash_attn: Optional[bool] = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -251,8 +252,12 @@ def create_kv_caches_with_random( raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if use_flash_attn: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, + x) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -267,7 +272,10 @@ def create_kv_caches_with_random( f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) - value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if use_flash_attn: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) + else: + value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index bbe33989fc2a4..6861ff4b28af7 100644 --- 
a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -35,15 +35,21 @@ def __init__( self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) - self.block_size = cache_config.block_size - self.num_gpu_blocks = cache_config.num_gpu_blocks - self.num_cpu_blocks = cache_config.num_cpu_blocks - if cache_config.cache_dtype == "auto": self.dtype = model_config.dtype else: self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + if model_config.use_flash_attn and self.dtype not in [ + torch.half, torch.bfloat16 + ]: + raise ValueError( + "flash-attn requires cache_dtype to be half or bfloat16") + + self.block_size = cache_config.block_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + self.num_cpu_blocks = cache_config.num_cpu_blocks + # Initialize the cache. self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() @@ -55,21 +61,35 @@ def __init__( self.events = [torch.cuda.Event() for _ in range(self.num_layers)] def get_key_block_shape(self) -> Tuple[int, int, int, int]: - element_size = torch.tensor([], dtype=self.dtype).element_size() - x = 16 // element_size - return ( - self.num_heads, - self.head_size // x, - self.block_size, - x, - ) + if self.model_config.use_flash_attn: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + element_size = torch.tensor([], dtype=self.dtype).element_size() + x = 16 // element_size + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) def get_value_block_shape(self) -> Tuple[int, int, int]: - return ( - self.num_heads, - self.head_size, - self.block_size, - ) + if self.model_config.use_flash_attn: + return ( + self.block_size, + self.num_heads, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] @@ -152,6 +172,8 @@ def get_cache_block_size( model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: + ''' Returns the nbytes of kv cache for a single token. 
+ ''' head_size = model_config.get_head_size() num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b99a409e02d1e..b85159665d931 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -140,12 +140,22 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix - if prefix is not None and prefix.computed: - prefix_len = prefix.get_length() - prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_tables.append(prefix.get_block_numbers()) + if prefix is not None: + current_block_tables = prefix.get_block_numbers() + if prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] else: - prefix_block_tables.append([]) + current_block_tables = [] + + # append seq groups's block table as the key-value cache + # will be updated (cached) by the flash-attn kernels + if seq_group_metadata.block_tables: + current_block_tables.extend( + seq_group_metadata.block_tables[seq_id] + [len(current_block_tables):]) + prefix_block_tables.append(current_block_tables) + # actual prompt lens context_lens.append(prefix_len) subquery_lens.append(prompt_len - prefix_len) @@ -249,6 +259,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, + use_flash_attn=getattr(self.model_config, 'use_flash_attn', False), ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -269,9 +280,12 @@ def _prepare_decode( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + prompt_lens: List[int] = [] + subquery_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt + lookahead = seq_group_metadata.parallel_decoding_lookahead seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id @@ -281,21 +295,30 @@ def _prepare_decode( for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) + input_tokens.append( + [generation_token] + + seq_data.get_candidate_token_ids()[:lookahead - 1]) - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) + position = seq_data.get_len() - 1 + seq_len = seq_data.get_len() + lookahead - 1 + input_positions.append( + [position + i for i in range(lookahead)]) context_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) context_lens.append(context_len) + prompt_lens.append(seq_len) + subquery_lens.append(lookahead) block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) + slots = [] + for p in [position + i for i in range(lookahead)]: + block_number = block_table[p // self.block_size] + block_offset = p % self.block_size + slot = block_number * self.block_size + block_offset + slots.append(slot) + slot_mapping.append(slots) + lora_index_mapping.append([lora_id]) lora_prompt_mapping.append(lora_id) @@ -321,21 +344,24 @@ def _prepare_decode( input_positions.append([]) slot_mapping.append([]) context_lens.append(1) + prompt_lens.append(2) + subquery_lens.append(1) 
                 block_tables.append([])
             batch_size = graph_batch_size
 
+        max_prompt_len = max(subquery_lens)
         input_tokens = _make_tensor_with_pad(input_tokens,
-                                             max_len=1,
+                                             max_len=max_prompt_len,
                                              pad=0,
                                              dtype=torch.long,
                                              device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
-                                                max_len=1,
+                                                max_len=max_prompt_len,
                                                 pad=0,
                                                 dtype=torch.long,
                                                 device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
-                                             max_len=1,
+                                             max_len=max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.device)
@@ -366,20 +392,31 @@ def _prepare_decode(
             _pad_to_max(mapping, 1, pad=0)
             for mapping in lora_index_mapping
         ]
+        start_loc_tensor = torch.arange(0,
+                                        len(prompt_lens) * max_prompt_len,
+                                        max_prompt_len,
+                                        dtype=torch.long,
+                                        device=self.device)
+        prompt_lens_tensor = torch.tensor(prompt_lens,
+                                          dtype=torch.long,
+                                          device=self.device)
+
         input_metadata = InputMetadata(
             is_prompt=False,
             slot_mapping=slot_mapping,
-            prompt_lens=None,
-            max_seq_len=None,
-            start_loc=None,
+            prompt_lens=prompt_lens_tensor,
+            max_seq_len=max_prompt_len,
+            start_loc=start_loc_tensor,
             max_context_len=max_context_len,
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
+            use_flash_attn=getattr(self.model_config, 'use_flash_attn', False),
         )
-        return (input_tokens, input_positions, input_metadata,
-                lora_index_mapping, lora_prompt_mapping, lora_requests)
+        return (input_tokens, input_positions, input_metadata, prompt_lens,
+                subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                lora_requests)
 
     def _prepare_sample(
         self,
@@ -395,6 +432,7 @@ def _prepare_sample(
         categorized_sample_indices_start_idx = 0
 
         max_subquery_len = max(subquery_lens) if subquery_lens else 1
+        subquery_lens_index = 0
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
@@ -425,17 +463,27 @@ def _prepare_sample(
                     seq_group_metadata.state.generator = torch.Generator(
                         device="cuda").manual_seed(sampling_params.seed)
             else:
-                num_seqs = len(seq_ids)
-                selected_token_indices.extend(
-                    range(selected_token_start_idx,
-                          selected_token_start_idx + num_seqs))
-                selected_token_start_idx += num_seqs
+                for _ in seq_group_metadata.seq_data:
+                    if subquery_lens is not None:
+                        subquery_len = subquery_lens[subquery_lens_index]
+                    else:
+                        subquery_len = 1
+
+                    categorized_sample_indices[
+                        sampling_params.sampling_type].extend(
+                            range(
+                                categorized_sample_indices_start_idx,
+                                categorized_sample_indices_start_idx +
+                                subquery_len))
+                    categorized_sample_indices_start_idx += subquery_len
 
-                categorized_sample_indices[
-                    sampling_params.sampling_type].extend(
-                        range(categorized_sample_indices_start_idx,
-                              categorized_sample_indices_start_idx + num_seqs))
-                categorized_sample_indices_start_idx += num_seqs
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + subquery_len))
+                    selected_token_start_idx += max_subquery_len
+
+                    # move to the next entry in subquery_lens
+                    subquery_lens_index += 1
 
             if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)
@@ -460,6 +508,7 @@ def _prepare_sample(
             seq_groups=seq_groups,
             seq_data=seq_data,
             prompt_lens=prompt_lens,
+            subquery_lens=subquery_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
             generators=generators,
@@ -481,11 +530,10 @@ def prepare_input_tensors(
              subquery_lens, lora_index_mapping, lora_prompt_mapping,
              lora_requests) = self._prepare_prompt(seq_group_metadata_list)
         else:
-            (input_tokens, input_positions, input_metadata,
-             lora_index_mapping, lora_prompt_mapping,
+            (input_tokens, input_positions, input_metadata, prompt_lens,
+             subquery_lens, lora_index_mapping, lora_prompt_mapping,
              lora_requests) = self._prepare_decode(seq_group_metadata_list)
             prompt_lens = []
-            subquery_lens = None
 
         sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                  prompt_lens, subquery_lens)
@@ -515,6 +563,8 @@ def prepare_input_tensors(
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
+                "use_flash_attn": input_metadata.use_flash_attn,
+                "subquery_lens": sampling_metadata.subquery_lens,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
@@ -538,11 +588,13 @@ def prepare_input_tensors(
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
                 kv_cache_dtype=metadata_dict["kv_cache_dtype"],
+                use_flash_attn=metadata_dict["use_flash_attn"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
+                subquery_lens=metadata_dict["subquery_lens"],
                 selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
                 generators=None,
@@ -686,12 +738,32 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
-        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, 1,
-                                      dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(
+            max_batch_size,
+            self.scheduler_config.parallel_decoding_lookahead,
+            dtype=torch.long,
+            device='cuda')
+        input_positions = torch.zeros(
+            max_batch_size,
+            self.scheduler_config.parallel_decoding_lookahead,
+            dtype=torch.long,
+            device='cuda')
+        slot_mapping = torch.empty(
+            max_batch_size,
+            self.scheduler_config.parallel_decoding_lookahead,
+            dtype=torch.long,
+            device='cuda')
         slot_mapping.fill_(_PAD_SLOT_ID)
-        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
+        context_lens = torch.ones(max_batch_size,
+                                  dtype=torch.int32,
+                                  device='cuda')
+        max_prompt_len = self.scheduler_config.parallel_decoding_lookahead
+        prompt_lens = context_lens + max_prompt_len
+        start_loc = torch.arange(0,
+                                 len(prompt_lens) * max_prompt_len,
+                                 max_prompt_len,
+                                 dtype=torch.long,
+                                 device=self.device)
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
 
         graph_batch_size = _get_graph_batch_size(
@@ -714,14 +786,16 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
                 input_metadata = InputMetadata(
                     is_prompt=False,
                     slot_mapping=slot_mapping[:batch_size],
-                    prompt_lens=None,
-                    max_seq_len=None,
-                    start_loc=None,
+                    prompt_lens=prompt_lens,
+                    max_seq_len=max_prompt_len,
+                    start_loc=start_loc,
                     max_context_len=self.max_context_len_to_capture,
                     context_lens=context_lens[:batch_size],
                     block_tables=block_tables[:batch_size],
                     use_cuda_graph=True,
                     kv_cache_dtype=self.kv_cache_dtype,
+                    use_flash_attn=getattr(self.model_config, 'use_flash_attn',
+                                           False),
                 )
 
                 if self.lora_config:
diff --git a/vllm/worker/spec_decode/__init__.py b/vllm/worker/spec_decode/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/worker/spec_decode/multi_step_worker.py b/vllm/worker/spec_decode/multi_step_worker.py
index 591d1b1300c88..bca35d3f358d2 100644
--- a/vllm/worker/spec_decode/multi_step_worker.py
+++ b/vllm/worker/spec_decode/multi_step_worker.py
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import List, Dict, Optional
 import copy
 
 import torch
@@ -19,6 +19,27 @@ class MultiStepWorker(Worker):
     requires more thought for MultiStepWorker support.
     """
 
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
+                                   blocks_to_swap_out, blocks_to_copy)
+
+        # Shallow copy input data so the associated prefixes can be dropped.
+        copied_seq_group_metadata_list = self._shallow_copy_inputs(
+            seq_group_metadata_list)
+        return super().execute_model(
+            seq_group_metadata_list=copied_seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+
     @torch.inference_mode()
     def execute_model_multi_step(
         self,
@@ -34,6 +55,9 @@ def execute_model_multi_step(
         self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                    blocks_to_swap_out, blocks_to_copy)
 
+        if num_steps == 0:
+            return []
+
         # Shallow copy input data so modifications (such as appending tokens)
         # do not cause side-effects.
         copied_seq_group_metadata_list = self._shallow_copy_inputs(
@@ -54,13 +78,18 @@ def execute_model_multi_step(
 
             self._append_new_tokens(model_output,
                                     copied_seq_group_metadata_list)
+            # clear candidate tokens after the first step
+            for seq_group_metadata in copied_seq_group_metadata_list:
+                for _, seq_data in seq_group_metadata.seq_data.items():
+                    seq_data.candidate_token_ids = []
+                seq_group_metadata.parallel_decoding_lookahead = 1
             model_outputs.append(model_output)
 
         return model_outputs
 
     def _append_new_tokens(
             self, model_output: SamplerOutput,
-            seq_group_metadata_list: SequenceGroupMetadata) -> None:
+            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
@@ -73,10 +102,24 @@ def _append_new_tokens(
                 # NOTE: Beam search is not supported, so we can assume that
                 # parent_seq_id == seq_id.
                 seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
 
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
+                # The bonus token from the target model's output is the
+                # candidate token for the draft model's first input, and
+                # it should be appended to the draft's outputs.
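As an aside, here is a minimal standalone sketch of the append order described in the comment above (bonus/candidate token first, then the draft model's own sampled token). The `DraftSeq` class and the token values are hypothetical stand-ins for vLLM's `SequenceData`; the diff hunk continues right after this sketch.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DraftSeq:
    """Hypothetical stand-in for the per-sequence data used in the patch."""
    output_token_ids: List[int] = field(default_factory=list)
    candidate_token_ids: List[int] = field(default_factory=list)

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)


def append_new_tokens(seq: DraftSeq, sampled_token_id: int,
                      sampled_logprob: float) -> None:
    # First append the bonus token accepted from the target model, which was
    # the draft model's first input token for this step.
    if seq.candidate_token_ids:
        seq.append_token_id(seq.candidate_token_ids[0], 0.0)
    # Then append the token the draft model sampled itself.
    seq.append_token_id(sampled_token_id, sampled_logprob)


seq = DraftSeq(output_token_ids=[11, 12], candidate_token_ids=[13])
append_new_tokens(seq, sampled_token_id=14, sampled_logprob=-0.2)
assert seq.output_token_ids == [11, 12, 13, 14]
```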
+                if seq.candidate_token_ids:
+                    # only append the first candidate token
+                    token = seq.candidate_token_ids[0]
+                    token_logprob = seq_output.logprobs[0].get(token, 0.0)
+                    seq.append_token_id(token, token_logprob)
+                # append the last sampled token
+                token_id = seq_output.output_token
+                if isinstance(token_id, list):
+                    token_id = token_id[-1]
+                    logprobs = seq_output.logprobs[-1]
+                else:
+                    logprobs = seq_output.logprobs
+                token_logprob = logprobs[token_id]
                 seq.append_token_id(token_id, token_logprob)
 
     def _shallow_copy_inputs(
             self,
@@ -112,13 +155,25 @@ def _shallow_copy_inputs(
             new_seq_group_metadata_list.append(seq_group_metadata)
 
             # We must shallow-copy seq_data as we will append token ids
-            new_seq_data = {}
+            new_seq_data, parallel_decoding_lookahead = {}, 1
             for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                 new_seq_data[seq_id] = copy.copy(old_seq_data)
                 new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+                # pop the last token from the inputs if it is actually part of
+                # the candidate tokens (a token accepted from the target model)
+                if new_seq_data[seq_id].candidate_token_ids:
+                    assert new_seq_data[seq_id].candidate_token_ids[
+                        -1] == new_seq_data[seq_id].output_token_ids[-1]
+                    new_seq_data[seq_id].output_token_ids.pop()
+                    parallel_decoding_lookahead = 2
+
+            # the draft model does not share the prefix
+            seq_group_metadata.prefix = None
+            seq_group_metadata.seq_data = new_seq_data
+            seq_group_metadata.parallel_decoding_lookahead = parallel_decoding_lookahead
 
         return new_seq_group_metadata_list
@@ -169,7 +224,10 @@ def _raise_if_unsupported(
         """
         if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
             raise NotImplementedError(
-                "MultiStepWorker does not support cache operations")
+                "MultiStepWorker does not support cache operations: "
+                f"blocks_to_swap_in = {blocks_to_swap_in}, "
+                f"blocks_to_swap_out = {blocks_to_swap_out}, "
+                f"blocks_to_copy = {blocks_to_copy}.")
 
         if any(
                 len(seq_group_metadata.seq_data.keys()) != 1