diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index f5f6e28b5fd9..45534da7ad49 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -71,6 +73,7 @@ def parse_args():
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
+    parser.add_argument("--compilation-config", type=str, default="")
     return parser.parse_args()
 
 
@@ -139,6 +142,9 @@ def main(args):
         max_model_len=args.max_model_len,
         limit_mm_per_prompt={"image": 5},
         disable_chunked_mm_input=True,
+        compilation_config=(
+            json.loads(args.compilation_config) if args.compilation_config else None
+        ),
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index c84a060922e3..cf42a627340c 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -428,6 +428,36 @@ class CompilationConfig:
     max_num_seqs, and prevents capture of many large graphs (>512) that would
     greatly increase startup time with limited performance benefit.
     """
+    disable_cudagraph_uniform_alignment: bool = False
+    """Whether to disable uniform alignment of cudagraph capture sizes for
+    uniform decode batches with query length > 1 (i.e., for spec-decode). This
+    flag only takes effect when cudagraph_mode is FULL_DECODE_ONLY or
+    FULL_AND_PIECEWISE.
+
+    Uniform alignment ensures that all capture sizes for uniform-decode batches
+    are multiples of 1 + num_speculative_tokens. This alignment is typically
+    useful for padded speculation (see #21984 for details), and is needed by
+    some attention backends to reach their best performance: they support
+    uniform-decode batches, but not in a variable-length fashion. Note that
+    this is a trade-off: while alignment helps the attention layers, it may
+    introduce slight regressions in other layers if the aligned sizes are no
+    longer multiples of 8.
+
+    Note: for DP_size > 1, the uniformity of sizes may be broken after the
+    dp_padding sync. Therefore, the current rank only runs a full cudagraph for
+    its uniform-decode batch if all DP ranks have uniform-decode batches.
+    Otherwise, it falls back to piecewise cudagraphs, where eagerly executed
+    attention layers can still exploit the batch's uniformity before padding.
+    """
+    uniform_cudagraph_capture_sizes: list[int] | None = None
+    """
+    List of capture sizes for uniform decode for the main model. Its elements
+    should be multiples of uniform_decode_len (1 for common pure decode, or
+    1 + num_speculative_tokens for speculative decode).
+    Not configurable, computed after init.
+    """
+    max_uniform_capture_size: int = field(default=None, init=False)  # type: ignore
+    """not configurable, computed after init"""
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
     bs_to_padded_graph_size: list[int] = field(
         default=None,  # type: ignore
         init=False,
     )
     """
     Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size], we can optimize it to list[int] for better lookup performance.""" + bs_to_padded_graph_size_uniform: list[int] = field( + default=None, # type: ignore + init=False, + ) + """same as bs_to_padded_graph_size, but for uniform capture sizes""" # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) @@ -503,6 +538,7 @@ def __repr__(self) -> str: "disabled_custom_ops": True, "compilation_time": True, "bs_to_padded_graph_size": True, + "bs_to_padded_graph_size_uniform": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, @@ -718,7 +754,8 @@ def post_init_cudagraph_sizes(self) -> None: """To complete the initialization after cudagraph related configs are set. This includes: - initialize compile_sizes - - pre-compute the mapping bs_to_padded_graph_size + - pre-compute the mapping bs_to_padded_graph_size and + bs_to_padded_graph_size_uniform """ computed_compile_sizes = [] @@ -739,8 +776,14 @@ def post_init_cudagraph_sizes(self) -> None: # make sure the sizes are in ascending order self.cudagraph_capture_sizes.sort() + self.uniform_cudagraph_capture_sizes.sort() if self.cudagraph_capture_sizes: assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size + if self.uniform_cudagraph_capture_sizes: + assert ( + self.uniform_cudagraph_capture_sizes[-1] + == self.max_uniform_capture_size + ) # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = [ @@ -756,6 +799,20 @@ def post_init_cudagraph_sizes(self) -> None: else: self.bs_to_padded_graph_size[bs] = end + # pre-compute the mapping for uniform decode padding. + self.bs_to_padded_graph_size_uniform = [ + 0 for i in range(self.max_uniform_capture_size + 1) + ] + for end, start in zip( + self.uniform_cudagraph_capture_sizes + [self.max_uniform_capture_size + 1], + [0] + self.uniform_cudagraph_capture_sizes, + ): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size_uniform[bs] = start + else: + self.bs_to_padded_graph_size_uniform[bs] = end + def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called only when mode is # CompilationMode.VLLM_COMPILE diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ac4607886305..772831a70d4a 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -201,12 +201,23 @@ def compute_hash(self) -> str: ).hexdigest()[:10] return hash_str - def pad_for_cudagraph(self, batch_size: int) -> int: - # if batch_size > self.compilation_config.max_cudagraph_capture_size, - # it should raise an IndexError. - # the caller should make sure the batch_size is within the range, - # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size - return self.compilation_config.bs_to_padded_graph_size[batch_size] + def pad_for_cudagraph(self, batch_size: int, uniform_aligned: bool = False) -> int: + """Get the padded graph size for the batch size. + uniform_aligned: if True, means the padding batch size would be + divisible by the uniform_decode_len for the main model. + For drafter, caller should make sure uniform_aligned is False because + drafter's uniform_decode_len is 1. + """ + if self.compilation_config.disable_cudagraph_uniform_alignment: + uniform_aligned = False + # if batch_size > max_cudagraph_capture_size (uniform_aligned=False) + # or batch_size > max_uniform_capture_size (uniform_aligned=True), + # it would raise an IndexError. 
So the caller should make sure the + # batch_size is within the range + if not uniform_aligned: + return self.compilation_config.bs_to_padded_graph_size[batch_size] + else: + return self.compilation_config.bs_to_padded_graph_size_uniform[batch_size] def enable_trace_function_call_for_thread(self) -> None: """ @@ -800,7 +811,6 @@ def _set_cudagraph_sizes(self): - If batch size > largest `cudagraph_capture_sizes`, cudagraph will not be used. """ - if ( self.model_config is not None and not self.model_config.enforce_eager @@ -847,7 +857,29 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes += list( range(256, max_cudagraph_capture_size + 1, 16) ) - + uniform_decode_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) + max_num_decode_tokens = min( + max_num_tokens, + self.scheduler_config.max_num_seqs * uniform_decode_len, + ) + if ( + self.compilation_config.disable_cudagraph_uniform_alignment + or uniform_decode_len == 1 + ): + uniform_cudagraph_capture_sizes = [ + x for x in cudagraph_capture_sizes if x < max_num_decode_tokens + ] + else: + uniform_cudagraph_capture_sizes = [ + size * uniform_decode_len + for size in cudagraph_capture_sizes + if size >= uniform_decode_len + and size * uniform_decode_len <= max_num_decode_tokens + ] if ( self.parallel_config.tensor_parallel_size > 1 and self.compilation_config.pass_config.enable_sequence_parallelism @@ -855,6 +887,11 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( cudagraph_capture_sizes ) + uniform_cudagraph_capture_sizes = ( + self.update_sizes_for_sequence_parallelism( + uniform_cudagraph_capture_sizes + ) + ) # user-specific compilation_config.max_cudagraph_capture_size get # truncated to valid_max_size when they are inconsistent. @@ -899,10 +936,22 @@ def _set_cudagraph_sizes(self): # always write back the final sizes self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes + # set uniform_cudagraph_sizes related values. + self.compilation_config.max_uniform_capture_size = ( + uniform_cudagraph_capture_sizes[-1] + if uniform_cudagraph_capture_sizes + else 0 + ) + self.compilation_config.uniform_cudagraph_capture_sizes = ( + uniform_cudagraph_capture_sizes + ) + else: # no cudagraph in use self.compilation_config.max_cudagraph_capture_size = 0 self.compilation_config.cudagraph_capture_sizes = [] + self.compilation_config.max_uniform_capture_size = 0 + self.compilation_config.uniform_cudagraph_capture_sizes = [] # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ef37cf862c9f..0b62e42a0e06 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -40,6 +40,11 @@ class BatchDescriptor(NamedTuple): False can also be used for an uniform decode batch to dispatch to the cudagraph supporting non-uniform batches. """ + uniform_query_len: int = 0 + """ + For non-uniform batches, should set to 0 for uniquely identifying the batch. + For uniform batches, it is the max_query_len of a uniform batch. + """ has_lora: bool = False """ Whether this batch has active LoRA adapters. @@ -51,7 +56,10 @@ def non_uniform(self) -> "BatchDescriptor": Return a non-uniform version of current batch descriptor. 
""" return BatchDescriptor( - self.num_tokens, uniform_decode=False, has_lora=self.has_lora + self.num_tokens, + uniform_decode=False, + uniform_query_len=0, + has_lora=self.has_lora, ) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f57dfc1571b6..d6a277cf643d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -741,23 +741,6 @@ def _build_decode( dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> M: - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with MLA. - """ - m = common_attn_metadata - assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( - "MLA only supports decode-only full CUDAGraph capture. " - "Make sure all cudagraph capture sizes <= max_num_seq." - ) - - assert m.max_query_len <= self.reorder_batch_threshold # decode only - - return self.build(0, m) - def build( self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index b1d34dbfd172..2995ee4f61d1 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -86,16 +86,6 @@ def __init__( self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - attn_metadata = self.build(0, common_attn_metadata) - # When doing full graph capture, setting seq_lens to - # max_model_len will cause graph capture to be extremely - # slow, so here we set it to 1. - attn_metadata.seq_lens.fill_(1) - return attn_metadata - def build( self, common_prefix_len: int, diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index b480ac78f23c..ce27b6825bf8 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -4,6 +4,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor +from vllm.utils import round_up class CudagraphDispatcher: @@ -25,16 +26,23 @@ class CudagraphDispatcher: runnable without cudagraph (if the mode does not match or mode is NONE). """ - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, is_drafter: bool = False): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.is_drafter = is_drafter # Dict to store valid cudagraph dispatching keys. self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = { CUDAGraphMode.PIECEWISE: set(), CUDAGraphMode.FULL: set(), } + # Placeholder for capture sizes. Should be initialized in + # self.initialize_cudagraph_keys. 
+ self.cudagraph_capture_sizes: list[int] = [] + # map uniform_query_len to capture sizes + self.uniform_cudagraph_capture_sizes: dict[int, list[int]] = {} + self.uniform_query_lens: list[int] = [] not_use_piecewise_compilation = ( not self.cudagraph_mode.requires_piecewise_compilation() @@ -54,6 +62,7 @@ def __init__(self, vllm_config: VllmConfig): ) self.keys_initialized = False + self.lora_cases: list[bool] = [False] def add_cudagraph_key( self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor @@ -64,22 +73,46 @@ def add_cudagraph_key( self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys( - self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + self, cudagraph_mode: CUDAGraphMode, uniform_query_lens: int | list[int] ): # This should be called only after attention backend is initialized. # LoRA activation cases to specialize the cuda graphs on - if self.vllm_config.lora_config: + if self.vllm_config.lora_config and not self.is_drafter: if self.compilation_config.cudagraph_specialize_lora: lora_cases = [True, False] else: lora_cases = [True] else: lora_cases = [False] + self.lora_cases = lora_cases # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. + + # support multiple uniform_decode_query_lens for spec-decode + if isinstance(uniform_query_lens, int): + uniform_query_lens = [uniform_query_lens] + assert len(uniform_query_lens) > 0 and all( + isinstance(x, int) and x > 0 for x in uniform_query_lens + ), f"Invalid uniform_query_lens: {uniform_query_lens}" + self.uniform_query_lens = uniform_query_lens + + # we only have compilation_config.uniform_cudagraph_capture_sizes + # being aligned with one uniform_query_len that greater than 1, not + # multiple of them. Should verify this here. + if not self.compilation_config.disable_cudagraph_uniform_alignment: + for uniform_query_len in self.uniform_query_lens: + if ( + uniform_query_len > 1 + and self.compilation_config.uniform_cudagraph_capture_sizes + ): + assert all( + x % uniform_query_len == 0 + for x in self.compilation_config.uniform_cudagraph_capture_sizes + ), f"Invalid uniform_query_lens: {uniform_query_len}" + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: for bs, has_lora in product( self.compilation_config.cudagraph_capture_sizes, lora_cases @@ -90,30 +123,185 @@ def initialize_cudagraph_keys( num_tokens=bs, uniform_decode=False, has_lora=has_lora ), ) + self.cudagraph_capture_sizes = ( + self.compilation_config.cudagraph_capture_sizes + ) # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. + for uniform_query_len in self.uniform_query_lens: + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + uniform_query_len * self.vllm_config.scheduler_config.max_num_seqs + ) + # for uniform_query_len==1, we use the non-uniform + # capture sizes, this can be for main model without spec-decode + # or for the drafter. Otherwise, we use the uniform-aligned + # sizes. 
+                candidate_sizes = (
+                    self.compilation_config.cudagraph_capture_sizes
+                    if (
+                        uniform_query_len == 1
+                        or self.compilation_config.disable_cudagraph_uniform_alignment
+                    )
+                    else self.compilation_config.uniform_cudagraph_capture_sizes
+                )
+                cudagraph_capture_sizes_for_decode = [
+                    x
+                    for x in candidate_sizes
+                    if x <= max_num_tokens and x >= uniform_query_len
+                ]
+                for bs, has_lora in product(
+                    cudagraph_capture_sizes_for_decode, lora_cases
+                ):
+                    self.add_cudagraph_key(
+                        CUDAGraphMode.FULL,
+                        BatchDescriptor(
+                            num_tokens=bs,
+                            uniform_decode=True,
+                            uniform_query_len=uniform_query_len,
+                            has_lora=has_lora,
+                        ),
+                    )
+                self.uniform_cudagraph_capture_sizes[uniform_query_len] = (
+                    cudagraph_capture_sizes_for_decode
+                )
+
+        # update the cudagraph mode resolved from the runner
+        self.cudagraph_mode = cudagraph_mode
+        self.keys_initialized = True
+
+    def get_capture_cases(
+        self, uniform_decode: bool, uniform_query_len: int
+    ) -> tuple[list[int], list[BatchDescriptor], CUDAGraphMode]:
+        """Return capture sizes, dispatch keys, and runtime mode for a given case.
+        The capture sizes and keys are sorted in descending order.
+        """
+        if not uniform_decode:
+            runtime_mode = self.cudagraph_mode.mixed_mode()
+            uniform_query_len = 0
+            capture_sizes = sorted(self.cudagraph_capture_sizes, reverse=True)
+        else:
+            runtime_mode = self.cudagraph_mode.decode_mode()
+            assert uniform_query_len in self.uniform_cudagraph_capture_sizes
+            capture_sizes = sorted(
+                self.uniform_cudagraph_capture_sizes[uniform_query_len], reverse=True
+            )
+        # materialize the product so it can be iterated more than once
+        combos = list(product(capture_sizes, self.lora_cases))
+        keys = [
+            BatchDescriptor(
+                num_tokens=x,
+                uniform_decode=uniform_decode,
+                uniform_query_len=uniform_query_len,
+                has_lora=has_lora,
+            )
+            for x, has_lora in combos
+        ]
+        capture_sizes = [x for x, _ in combos]
+        return capture_sizes, keys, runtime_mode
+
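As an illustration of how the uniform-aligned capture sizes and the padding lookup behave, here is a small sketch (not part of the patch). It assumes a hypothetical uniform_decode_len of 4 (num_speculative_tokens = 3) and toy base capture sizes; the real sizes are derived in _set_cudagraph_sizes from the scheduler config.

# Illustrative sketch only; all values are hypothetical.
base_capture_sizes = [1, 2, 4, 8, 16, 24, 32]
uniform_decode_len = 4  # 1 + num_speculative_tokens (assumed)

# Mirrors _set_cudagraph_sizes: keep base sizes >= uniform_decode_len and scale
# them so every uniform capture size is a multiple of uniform_decode_len.
uniform_sizes = [
    s * uniform_decode_len for s in base_capture_sizes if s >= uniform_decode_len
]
# -> [16, 32, 64, 96, 128]

# Mirrors bs_to_padded_graph_size_uniform: a uniform-decode batch is padded up
# to the smallest captured size that fits, e.g. 5 requests * 4 tokens = 20 -> 32.
def pad_uniform(num_tokens: int) -> int:
    return next(s for s in uniform_sizes if s >= num_tokens)

assert pad_uniform(20) == 32 and pad_uniform(16) == 16

Because the padded size stays divisible by uniform_decode_len, an attention backend can still treat the padded batch as a uniform decode batch inside the full cudagraph.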
+    def cudagraph_padded_num_tokens(
+        self, num_tokens: int, uniform_decode: bool, uniform_query_len: int
+    ) -> tuple[int, bool]:
+        """Return tuple[num_tokens_after_padding, is_cudagraph_padded]."""
+        assert uniform_query_len == 0 or uniform_query_len in self.uniform_query_lens, (
+            f"Invalid uniform_query_len: {uniform_query_len}"
+        )
         if (
-            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
-            and cudagraph_mode.separate_routine()
+            uniform_query_len <= 1
+            and num_tokens <= self.compilation_config.max_capture_size
         ):
-            max_num_tokens = (
-                uniform_decode_query_len
-                * self.vllm_config.scheduler_config.max_num_seqs
+            # Common situation within the range of max_capture_size, for the
+            # main model or for a drafter. We ignore whether the batch is
+            # uniform-decode since it is always safe to pad here.
+            return self.vllm_config.pad_for_cudagraph(
+                num_tokens, uniform_aligned=False
+            ), True
+
+        if (
+            uniform_decode
+            and uniform_query_len > 1
+            and num_tokens <= self.compilation_config.max_uniform_capture_size
+        ):
+            # This path handles the uniform-decode alignment for the verification
+            # phase of spec-decode, or for the first iteration of the drafter once
+            # padded speculation is supported.
+            return self.vllm_config.pad_for_cudagraph(
+                num_tokens, uniform_aligned=True
+            ), True
+
+        # Otherwise, the batch is not padded for cudagraphs.
+        return num_tokens, False
+
+    def calculate_uniform_decode(
+        self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int
+    ) -> tuple[bool, int]:
+        uniform_decode = (max_query_len in self.uniform_query_lens) and (
+            num_scheduled_tokens == num_reqs * max_query_len
+        )
+        uniform_query_len = max_query_len if uniform_decode else 0
+        return uniform_decode, uniform_query_len
+
+    def get_local_batch_description(
+        self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int
+    ) -> tuple[int, bool, int]:
+        """
+        Return tuple[num_tokens_after_padding, uniform_decode, uniform_query_len].
+        """
+        uniform_decode, uniform_query_len = self.calculate_uniform_decode(
+            num_scheduled_tokens, num_reqs, max_query_len
+        )
+
+        # Compute padded tokens
+        cudagraph_padded = False
+        if self.cudagraph_mode != CUDAGraphMode.NONE:
+            num_input_tokens, cudagraph_padded = self.cudagraph_padded_num_tokens(
+                num_scheduled_tokens, uniform_decode, uniform_query_len
             )
-            cudagraph_capture_sizes_for_decode = [
-                x
-                for x in self.compilation_config.cudagraph_capture_sizes
-                if x <= max_num_tokens and x >= uniform_decode_query_len
-            ]
-            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
-                self.add_cudagraph_key(
-                    CUDAGraphMode.FULL,
-                    BatchDescriptor(
-                        num_tokens=bs, uniform_decode=True, has_lora=has_lora
-                    ),
-                )
-        self.keys_initialized = True
+        else:
+            num_input_tokens = num_scheduled_tokens
+
+        if not cudagraph_padded and not self.is_drafter:
+            # Eager mode: pad tokens to a multiple of tensor_parallel_size when
+            # collective fusion for sequence parallelism is enabled.
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            if (
+                self.compilation_config.pass_config.enable_sequence_parallelism
+                and tp_size > 1
+            ):
+                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
+        return num_input_tokens, uniform_decode, uniform_query_len
+
+    def fast_plan(
+        self,
+        num_scheduled_tokens: int,
+        num_reqs: int,
+        max_query_len: int,
+        use_cascade_attn: bool = False,
+    ) -> tuple[CUDAGraphMode, BatchDescriptor | None, int]:
+        """Plan cudagraph execution in a single call, without considering DP.
+
+        Returns (runtime_mode, batch_descriptor, num_input_tokens).
+ """ + num_input_tokens, uniform_decode, uniform_query_len = ( + self.get_local_batch_description( + num_scheduled_tokens, num_reqs, max_query_len + ) + ) + + # Build initial descriptor and then dispatch + descriptor = BatchDescriptor( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, + ) + runtime_mode, descriptor = self.dispatch(descriptor, use_cascade_attn) + return runtime_mode, descriptor, num_input_tokens def dispatch( self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75a4140fd655..abee747684d1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,19 +3,15 @@ import ast from dataclasses import replace from importlib.util import find_spec +from typing import Any import numpy as np import torch import torch.nn as nn -from vllm.config import ( - CompilationMode, - CUDAGraphMode, - VllmConfig, - get_layers_from_vllm_config, -) +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model @@ -35,6 +31,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, ) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS @@ -84,32 +81,11 @@ def __init__( self.attn_layer_names: list[str] = [] self.indexer_layer_names: list[str] = [] - self.use_cuda_graph = False - - compilation_config = self.vllm_config.compilation_config - if compilation_config.mode == CompilationMode.VLLM_COMPILE: - cudagraph_mode = compilation_config.cudagraph_mode - if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( - CUDAGraphMode.PIECEWISE - ): - logger.warning( - "Currently the eagle proposer only supports cudagraph_mode " - "PIECEWISE, if you want the drafter to use cuda graphs, " - "please set compilation_config.cudagraph_mode to PIECEWISE " - "or FULL_AND_PIECEWISE" - ) - self.use_cuda_graph = ( - cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) - - self.cudagraph_batch_sizes = ( - (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes)) - if self.use_cuda_graph - else [] + self.use_cuda_graph = ( + self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not self.speculative_config.enforce_eager ) - self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes) # persistent buffers for cuda graph self.input_ids = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=device @@ -187,6 +163,12 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) + # Cudagraph dispatcher for runtime cudagraph dispatching of drafter, + # which is independent of the dispatcher of the model runner. 
+ self.cudagraph_dispatcher = CudagraphDispatcher( + self.vllm_config, is_drafter=True + ) + def _get_positions(self, num_tokens: int): if self.uses_mrope: return self.mrope_positions[:, :num_tokens] @@ -218,6 +200,7 @@ def propose( if last_token_indices is None: last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + max_query_len = common_attn_metadata.max_query_len if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -262,12 +245,14 @@ def propose( assert draft_indexer_metadata is not None per_layer_attn_metadata[layer_name] = draft_indexer_metadata - cudagraph_runtime_mode = CUDAGraphMode.NONE - if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens + # dispatcher planning for drafter + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = ( + self.cudagraph_dispatcher.fast_plan( + num_scheduled_tokens=num_tokens, + num_reqs=batch_size, + max_query_len=max_query_len, + ) + ) # copy inputs to buffer for cudagraph self._set_positions(num_tokens, target_positions) self.hidden_states[:num_tokens] = target_hidden_states @@ -292,6 +277,7 @@ def propose( self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ): ret_hidden_states = self.model( input_ids=input_ids, @@ -353,12 +339,12 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - input_batch_size = batch_size - cudagraph_runtime_mode = CUDAGraphMode.NONE + # dispatcher plans only once for the remaining loop + cudagraph_runtime_mode, batch_descriptor, input_batch_size = ( + self.cudagraph_dispatcher.fast_plan( + num_scheduled_tokens=batch_size, num_reqs=batch_size, max_query_len=1 + ) + ) common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 @@ -459,6 +445,7 @@ def propose( self.vllm_config, num_tokens=input_batch_size, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ): ret_hidden_states = self.model( input_ids=input_ids, @@ -763,18 +750,25 @@ def propose_tree( self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens - cudagraph_runtime_mode = CUDAGraphMode.NONE + # Note: decode phase of TreeAttentionBackend does not have an + # unique uniform decode query length (1 for the root level and + # total_num_drafts for subsequent levels). Here we may support + # this situation once full cudagraph of TreeAttention is supported. + cudagraph_runtime_mode, batch_descriptor, num_input_tokens = ( + self.cudagraph_dispatcher.fast_plan( + num_scheduled_tokens=num_tokens, + num_reqs=batch_size, + max_query_len=attn_metadata.max_query_len, + ) + ) + # Run the model. 
with set_forward_context( per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -1054,20 +1048,91 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, - use_cudagraphs=True, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + force_attention: bool = False, + uniform_decode: bool = False, + uniform_query_len: int = 0, + **other_kwargs, # unused but may get passed from caller ) -> None: - # Determine if CUDA graphs should be used for this run. - cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph - if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]: - num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + assert cudagraph_runtime_mode != CUDAGraphMode.FULL, ( + "Eagle drafter doesn't support full cudagraphs at this moment" + ) + # overwrite runtime mode to NONE if enforce_eager + if self.speculative_config.enforce_eager: + cudagraph_runtime_mode = CUDAGraphMode.NONE + + max_query_len = uniform_query_len if uniform_decode else num_tokens + + max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + if uniform_decode: + assert num_tokens % max_query_len == 0, ( + "num_tokens must be divisible by max_query_len for uniform decode" + ) + num_reqs = min(num_tokens // max_query_len, max_num_reqs) + else: + assert uniform_query_len == 0 + num_reqs = min(num_tokens, max_num_reqs) + + per_layer_attn_metadata: dict[str, Any] | None = None + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + per_layer_attn_metadata = {} + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[: num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc_cpu[: num_reqs + 1], + seq_lens=self.runner.seq_lens[:num_reqs], + seq_lens_cpu=self.runner.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.runner.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=self.max_model_len, + block_table_tensor=self.runner.input_batch.block_table[ + 0 + ].get_device_tensor()[:num_reqs], + slot_mapping=self.runner.input_batch.block_table[0].slot_mapping[ + :num_tokens + ], + causal=True, + ) + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_groups[0][ + 0 + ].metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + batch_descriptor = None + else: + batch_descriptor = BatchDescriptor( + num_tokens=num_tokens, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, + has_lora=False, + ) + # sanity check + _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( + batch_descriptor + ) + assert cudagraph_runtime_mode == _cg_mode, ( + f"Cudagraph runtime mode mismatch at dummy_run. " + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
+            )
 
         with set_forward_context(
-            None,
+            per_layer_attn_metadata,
             self.vllm_config,
             num_tokens=num_tokens,
-            cudagraph_runtime_mode=(
-                CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
-            ),
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
+            batch_descriptor=batch_descriptor,
         ):
             if self.supports_mm_inputs:
                 input_ids = None
diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py
index 464fbf11a21a..aaa155b2188f 100644
--- a/vllm/v1/worker/dp_utils.py
+++ b/vllm/v1/worker/dp_utils.py
@@ -39,16 +39,20 @@ def _run_ar(
     should_dp_pad: bool,
     orig_num_tokens_per_ubatch: int,
     padded_num_tokens_per_ubatch: int,
+    disable_padding_extend: bool,
+    num_tokens_padded_extended: int,
     parallel_config: ParallelConfig,
 ) -> torch.Tensor:
     dp_size = parallel_config.data_parallel_size
     dp_rank = parallel_config.data_parallel_rank
     device, group = _get_device_and_group(parallel_config)
-    tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32)
+    tensor = torch.zeros(6, dp_size, device=device, dtype=torch.int32)
     tensor[0][dp_rank] = orig_num_tokens_per_ubatch
     tensor[1][dp_rank] = padded_num_tokens_per_ubatch
     tensor[2][dp_rank] = 1 if should_ubatch else 0
     tensor[3][dp_rank] = 1 if should_dp_pad else 0
+    tensor[4][dp_rank] = 1 if disable_padding_extend else 0
+    tensor[5][dp_rank] = num_tokens_padded_extended
     dist.all_reduce(tensor, group=group)
     return tensor
 
@@ -76,6 +80,11 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
 def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor:
     num_tokens_across_dp = tensor[1, :]
     if should_dp_pad:
+        # Replace num_tokens_across_dp with the extended version when at least
+        # one DP rank has not disabled padding extension.
+        disable_padding_extend = bool(torch.all(tensor[4] == 1).item())
+        if not disable_padding_extend:
+            num_tokens_across_dp = tensor[5, :]
         # If DP padding is enabled, ensure that each rank is processing the same number
         # of tokens
         max_num_tokens = int(num_tokens_across_dp.max().item())
@@ -93,6 +102,8 @@ def _synchronize_dp_ranks(
     num_tokens_padded: int,
     should_attempt_ubatching: bool,
     should_attempt_dp_padding: bool,
+    try_disable_padding_extend: bool,
+    num_tokens_padded_extended: int,
     parallel_config: ParallelConfig,
 ) -> tuple[bool, torch.Tensor | None]:
     """
@@ -120,6 +131,8 @@ def _synchronize_dp_ranks(
         should_dp_pad=should_attempt_dp_padding,
         orig_num_tokens_per_ubatch=num_tokens_unpadded,
         padded_num_tokens_per_ubatch=num_tokens_padded,
+        disable_padding_extend=try_disable_padding_extend,
+        num_tokens_padded_extended=num_tokens_padded_extended,
         parallel_config=parallel_config,
     )
 
@@ -157,6 +170,8 @@ def coordinate_batch_across_dp(
     parallel_config: ParallelConfig,
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
+    try_disable_padding_extend: bool = True,
+    num_tokens_padded_extended: int | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
 ) -> tuple[UBatchSlices | None, torch.Tensor | None]:
     """
@@ -170,8 +185,11 @@ def coordinate_batch_across_dp(
         parallel_config: The parallel config
        num_tokens_padded: Number of tokens including any non-DP padding
             (CUDA graphs, TP, etc)
-        uniform_decode: Only used if allow_microbatching is True. True if the batch
-            only contains single token decodes
+        uniform_decode: Used when allow_microbatching is True and/or when the batch
+            is a uniform decode batch (e.g., for spec-decode).
+        try_disable_padding_extend: If True on all DP ranks, the padding is not
+            extended to the maximum of num_tokens_padded_extended across DP ranks.
+        num_tokens_padded_extended: The number of tokens after extending the padding.
         num_scheduled_tokens_per_request: Only used if allow_microbatching is True.
             The number of tokens per request.
@@ -203,11 +221,16 @@
     if num_tokens_padded is None:
         num_tokens_padded = num_tokens_unpadded
 
+    if num_tokens_padded_extended is None:
+        num_tokens_padded_extended = num_tokens_padded
+
     (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks(
         num_tokens_unpadded,
         num_tokens_padded,
         should_attempt_ubatching,
         allow_dp_padding,
+        try_disable_padding_extend,
+        num_tokens_padded_extended,
         parallel_config,
     )
 
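To make the extended DP padding sync concrete, here is a small sketch (not part of the patch) of how the two extra all-reduced rows are used; the row layout mirrors _run_ar above, but the tensor values and sizes are made up.

import torch

# Hypothetical post-all-reduce tensor for two DP ranks. Rank 0 has a
# uniform-decode batch padded to 16 tokens (extended candidate 24); rank 1 has
# a mixed prefill-decode batch of 48 tokens and disables the extension.
tensor = torch.tensor(
    [
        [12, 40],  # row 0: unpadded tokens per rank
        [16, 48],  # row 1: padded tokens (uniform-aligned on rank 0)
        [0, 0],    # row 2: should_ubatch flags
        [1, 1],    # row 3: should_dp_pad flags
        [0, 1],    # row 4: disable_padding_extend flags
        [24, 48],  # row 5: num_tokens_padded_extended per rank
    ],
    dtype=torch.int32,
)

# Same decision as _post_process_dp_padding: extension is disabled only if
# every rank asked to disable it; otherwise the extended row is used.
disable_padding_extend = bool(torch.all(tensor[4] == 1).item())  # False here
num_tokens_across_dp = tensor[1, :] if disable_padding_extend else tensor[5, :]
max_num_tokens = int(num_tokens_across_dp.max().item())  # 48

In this made-up case rank 0 pads to 48 and loses its uniform shape, but it can still dispatch to the piecewise cudagraph for 48 tokens, which is the fallback described in the compilation config docstring.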
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2db4235c89de..3cd71bce5fbe 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -8,8 +8,7 @@
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
-from functools import reduce
-from itertools import product
+from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
 
 import numpy as np
@@ -72,7 +71,7 @@
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.utils.jsontree import json_map_leaves
-from vllm.utils.math_utils import cdiv, round_up
+from vllm.utils.math_utils import cdiv
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.mem_utils import DeviceMemoryProfiler
 from vllm.utils.platform_utils import is_pin_memory_available
@@ -501,7 +500,9 @@ def __init__(
         )
 
         # Cudagraph dispatcher for runtime cudagraph dispatching.
-        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
+        self.cudagraph_dispatcher = CudagraphDispatcher(
+            self.vllm_config, is_drafter=False
+        )
 
         self.mm_budget = (
             MultiModalBudget(
@@ -1074,7 +1075,7 @@ def _prepare_inputs(
         SpecDecodeMetadata | None,
         np.ndarray,
         CommonAttentionMetadata | None,
-        int,
+        BatchDescriptor,
         UBatchSlices | None,
         torch.Tensor | None,
         bool,
@@ -1201,10 +1202,11 @@ def _prepare_inputs(
             query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
 
         num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
-        num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
-        uniform_decode = (
-            max_num_scheduled_tokens == self.uniform_decode_query_len
-        ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
+        num_tokens_padded, uniform_decode, uniform_query_len = (
+            self._get_local_batch_description(
+                num_tokens_unpadded, num_reqs, max_num_scheduled_tokens
+            )
+        )
 
         # Disable DP padding when running eager to avoid excessive padding when
         # running prefills. This lets us set enforce_eager on the prefiller in
         # decoder.
         allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
 
+        # For a uniform-decode batch with query length > 1, we may extend to the
+        # non-uniform padding if any DP rank has a non-uniform batch (i.e., one
+        # that can run the piecewise cudagraph). This resolves the conflict where
+        # we may not have a cudagraph for the uniform batch after dp-padding.
+ num_tokens_padded_extended = num_tokens_padded + disable_padding_extend = ( + self.compilation_config.disable_cudagraph_uniform_alignment + or not self.compilation_config.cudagraph_mode.separate_routine() + or not uniform_decode + or uniform_query_len <= 1 + ) + if ( + not disable_padding_extend + and num_tokens_padded < self.compilation_config.max_capture_size + ): + num_tokens_padded_extended = self.vllm_config.pad_for_cudagraph( + num_tokens_padded, uniform_aligned=False + ) + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, parallel_config=self.parallel_config, @@ -1219,6 +1240,8 @@ def _prepare_inputs( allow_dp_padding=allow_dp_padding, num_tokens_padded=num_tokens_padded, uniform_decode=uniform_decode, + try_disable_padding_extend=disable_padding_extend, + num_tokens_padded_extended=num_tokens_padded_extended, num_scheduled_tokens_per_request=num_scheduled_tokens, ) @@ -1454,13 +1477,30 @@ def _prepare_inputs( self.input_batch, num_scheduled_tokens, num_sampled_tokens ) + dp_rank = self.parallel_config.data_parallel_rank + if ubatch_slices: + assert num_tokens_across_dp is not None + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif num_tokens_across_dp is not None: + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) + else: + num_input_tokens = num_tokens_padded + + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, + has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, + ) + return ( attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, + batch_descriptor, ubatch_slices, num_tokens_across_dp, use_cascade_attn, @@ -2119,27 +2159,16 @@ def _pool( pooler_output=pooler_output, ) - def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] - ): - # Use CUDA graphs. - # Add padding to the batch size. - return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) - - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if ( - self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1 - ): - return round_up(num_scheduled_tokens, tp_size) - return num_scheduled_tokens + def _get_local_batch_description( + self, num_scheduled_tokens: int, num_reqs: int, max_query_len: int + ) -> tuple[int, bool, int]: + """ + Get local batch descriptions for before DP sync. 
+ returns (num_tokens_after_padding, uniform_decode, uniform_query_len) + """ + return self.cudagraph_dispatcher.get_local_batch_description( + num_scheduled_tokens, num_reqs, max_query_len + ) def _preprocess( self, @@ -2490,24 +2519,21 @@ def execute_model( spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, + batch_descriptor, ubatch_slices, num_tokens_across_dp, use_cascade_attn, ) = self._prepare_inputs(scheduler_output) - dp_rank = self.parallel_config.data_parallel_rank - if ubatch_slices: - assert num_tokens_across_dp is not None - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif num_tokens_across_dp is not None: - num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) - else: - num_input_tokens = self._get_num_input_tokens( - scheduler_output.total_num_scheduled_tokens - ) - + # cudagraph dispatching + ( + cudagraph_runtime_mode, + batch_descriptor, + ) = self.cudagraph_dispatcher.dispatch( + batch_descriptor, + use_cascade_attn, + ) + num_input_tokens = batch_descriptor.num_tokens ( input_ids, inputs_embeds, @@ -2518,18 +2544,6 @@ def execute_model( scheduler_output, num_input_tokens, intermediate_tensors ) - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_scheduled_tokens == self.input_batch.num_reqs * max_query_len - ) - batch_descriptor = BatchDescriptor( - num_tokens=num_input_tokens, - uniform_decode=uniform_decode, - has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, - ) - cudagraph_runtime_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) - ) - # Set cudagraph mode to none if calc_kv_scales is true. if attn_metadata is not None: metadata_list = ( @@ -3314,6 +3328,7 @@ def _dummy_run( cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, + uniform_query_len: int = 0, allow_microbatching: bool = True, skip_eplb: bool = False, is_profile: bool = False, @@ -3337,6 +3352,9 @@ def _dummy_run( force_attention: If True, always create attention metadata. Used to warm up attention backend when mode is NONE. uniform_decode: If True, the batch is a uniform decode batch. + uniform_query_len: The query length for uniform decode batch. + Should be 0 when uniform_decode is False. + allow_microbatching: If True, allow microbatching for DBO. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. create_mixed_batch: If True, create a mixed batch with both decode @@ -3363,6 +3381,9 @@ def _dummy_run( # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens + if uniform_decode and uniform_query_len: + # allow skip this assertion when it is on a dummy execution on DP setup + assert max_query_len == uniform_query_len # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -3402,6 +3423,17 @@ def _dummy_run( # Disable DP padding when running eager allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + # make sure uniform-decode batch can safely hit a cudagraph when it is + # on a dummy execution for DP size>1. 
+ num_tokens_padded_extended = total_num_scheduled_tokens + if ( + total_num_scheduled_tokens + < self.vllm_config.compilation_config.max_capture_size + ): + num_tokens_padded_extended = self.vllm_config.pad_for_cudagraph( + total_num_scheduled_tokens, uniform_aligned=False + ) + # We currently only microbatch if the number of tokens is # over a certain threshold. ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( @@ -3411,6 +3443,8 @@ def _dummy_run( allow_dp_padding=allow_dp_padding, num_tokens_padded=total_num_scheduled_tokens, uniform_decode=uniform_decode, + try_disable_padding_extend=True, + num_tokens_padded_extended=num_tokens_padded_extended, num_scheduled_tokens_per_request=num_scheduled_tokens, ) num_tokens_after_padding = num_tokens @@ -3544,6 +3578,7 @@ def _dummy_run( BatchDescriptor( num_tokens=num_tokens_after_padding, uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, has_lora=activate_lora and self.lora_config is not None, ) ) @@ -3596,24 +3631,17 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): + # Only trigger drafter's dummy run for profile run. Otherwise, the + # dummy run logic of drafter is separated from the main model's + # dummy run. + if ( + is_profile + and self.speculative_config + and self.speculative_config.use_eagle() + ): assert isinstance(self.drafter, EagleProposer) - use_cudagraphs = ( - cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE - and not self.speculative_config.enforce_eager - ) - - # Note(gnovack) - We need to disable cudagraphs for one of the two - # lora cases when cudagraph_specialize_lora is enabled. This is a - # short term mitigation for issue mentioned in - # https://github.com/vllm-project/vllm/issues/28334 - if self.compilation_config.cudagraph_specialize_lora and activate_lora: - use_cudagraphs = False - - self.drafter.dummy_run( - num_tokens, - use_cudagraphs=use_cudagraphs, - ) + # no cudagraph for profile run + self.drafter.dummy_run(num_tokens) # This is necessary to avoid blocking DP. 
# For dummy runs, we typically skip EPLB since we don't have any real @@ -3913,26 +3941,22 @@ def freeze_gc(): with freeze_gc(), graph_capture(device=self.device): start_free_gpu_memory = torch.cuda.mem_get_info()[0] cudagraph_mode = self.compilation_config.cudagraph_mode + logger.info("Start capturing cudagraphs for main model...") assert cudagraph_mode is not None - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] - if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - # make sure we capture the largest batch size first - compilation_cases = list( - product(reversed(self.cudagraph_batch_sizes), lora_cases) + capture_sizes, keys, runtime_mode = ( + self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=False, uniform_query_len=0 + ) ) - self._capture_cudagraphs( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, + self._capture_cudagraphs_with_callable( + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, + dummy_run_callable=partial( + self._dummy_run, skip_eplb=True, remove_lora=False + ), ) # Capture full cudagraph for uniform decode batches if we @@ -3941,23 +3965,69 @@ def freeze_gc(): cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and cudagraph_mode.separate_routine() ): - max_num_tokens = ( - self.scheduler_config.max_num_seqs * self.uniform_decode_query_len - ) - decode_cudagraph_batch_sizes = [ - x - for x in self.cudagraph_batch_sizes - if max_num_tokens >= x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - product(reversed(decode_cudagraph_batch_sizes), lora_cases) + capture_sizes, keys, runtime_mode = ( + self.cudagraph_dispatcher.get_capture_cases( + uniform_decode=True, + uniform_query_len=self.uniform_decode_query_len, + ) ) - self._capture_cudagraphs( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + + self._capture_cudagraphs_with_callable( + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, + dummy_run_callable=partial( + self._dummy_run, skip_eplb=True, remove_lora=False + ), ) + self.maybe_remove_all_loras(self.lora_config) + + # Capture drafter cudagraphs. + # Note: Currently only PIECEWISE mode is supported for eagle + # drafter. + # TODO: add full cudagraph support for drafter. + if ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.enforce_eager + ): + assert isinstance(self.drafter, EagleProposer) + logger.info("Start capturing cudagraphs for drafter...") + # when not enforce_eager, eagle drafter share the same cudagraph_mode + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + capture_sizes, keys, runtime_mode = ( + self.drafter.cudagraph_dispatcher.get_capture_cases( + uniform_decode=False, uniform_query_len=0 + ) + ) + self._capture_cudagraphs_with_callable( + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, + dummy_run_callable=self.drafter.dummy_run, + ) + # the following code would not be triggered at present since + # only PIECEWISE mode is supported. But it is kept and prepared + # for full cudagraphs. + # TODO: support multiple uniform_query_lens for drafter once + # Padded speculation is supported. i.e., + # [1, self.uniform_decode_query_len] for drafter. 
+ if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + capture_sizes, keys, runtime_mode = ( + self.drafter.cudagraph_dispatcher.get_capture_cases( + uniform_decode=True, uniform_query_len=1 + ) + ) + self._capture_cudagraphs_with_callable( + capture_sizes=capture_sizes, + keys=keys, + cudagraph_runtime_mode=runtime_mode, + dummy_run_callable=self.drafter.dummy_run, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -3980,30 +4050,45 @@ def freeze_gc(): ) return cuda_graph_size - def _capture_cudagraphs( + def _capture_cudagraphs_with_callable( self, - compilation_cases: list[tuple[int, bool]], + capture_sizes: list[int], + keys: list[BatchDescriptor], cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool, + dummy_run_callable: Any, ): assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode.valid_runtime_modes() ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + assert len(keys) > 0, "keys must be non-empty" + assert len(capture_sizes) == len(keys), ( + "capture_sizes and keys must have the same length" + ) + uniform_decode = keys[0].uniform_decode + uniform_query_len = keys[0].uniform_query_len # Only rank 0 should print progress bar during capture if is_global_first_rank(): - compilation_cases = tqdm( - compilation_cases, + capture_sizes = tqdm( + capture_sizes, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( - "decode" if uniform_decode else "mixed prefill-decode", + f"decode(query_len={uniform_query_len})" + if uniform_decode + else "mixed prefill-decode", cudagraph_runtime_mode.name, ), ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for num_tokens, key in zip(capture_sizes, keys): + assert ( + key.uniform_decode == uniform_decode + and key.uniform_query_len == uniform_query_len + and key.num_tokens == num_tokens + ), "Inconsistent batch descriptor during cudagraph capture." + # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -4026,26 +4111,23 @@ def _capture_cudagraphs( # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. 
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - self._dummy_run( + dummy_run_callable( num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, + activate_lora=key.has_lora, ) - self._dummy_run( + dummy_run_callable( num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, uniform_decode=uniform_decode, + uniform_query_len=uniform_query_len, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, + activate_lora=key.has_lora, ) - self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -4270,6 +4352,27 @@ def _check_and_update_cudagraph_mode( self.compilation_config.cudagraph_mode, self.uniform_decode_query_len ) + # At this moment, we assume the drafter and main model shares the + # same cudagraph_mode if not speculative_config.enforce_eager + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + assert not cudagraph_mode.has_full_cudagraphs(), ( + "Eagle drafter does not support full cudagraphs yet" + ) + cudagraph_mode = ( + self.compilation_config.cudagraph_mode + if not self.speculative_config.enforce_eager + else CUDAGraphMode.NONE + ) + # uniform_query_len is 1 for drafter + # TODO: let uniform_query_lens = [1, self.uniform_decode_query_len] + # for drafter once Padded speculation is supported. See: + # https://github.com/vllm-project/vllm/issues/21984 for details + # and an implementation in https://github.com/vllm-project/vllm/pull/24539 # noqa: E501 + self.drafter.cudagraph_dispatcher.initialize_cudagraph_keys( + cudagraph_mode, uniform_query_lens=1 + ) + def calculate_reorder_batch_threshold(self) -> None: """ Choose the minimum reorder batch threshold from all attention groups. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f13ff4e726bd..76f90ec08f99 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -535,7 +535,15 @@ def execute_model( intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) + num_reqs = len(scheduler_output.num_scheduled_tokens) + max_query_len = ( + max(scheduler_output.num_scheduled_tokens.values()) + if num_scheduled_tokens > 0 + else 0 + ) + num_input_tokens, _, _ = self.model_runner._get_local_batch_description( + num_scheduled_tokens, num_reqs, max_query_len + ) all_gather_tensors = { "residual": not is_residual_scattered_for_sp( self.vllm_config, num_input_tokens