diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 27741bd156be..5428f4cb5c40 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -9,8 +9,6 @@
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -24,27 +22,6 @@
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
@@ -251,7 +228,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -341,7 +318,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
@@ -350,7 +327,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
 
     reqs: list[CachedRequestState] = []
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 6ba6d1f6f131..96bbd9a8bb9e 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -49,7 +49,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
         device=runner.device,
         pin_memory=runner.pin_memory,
         vocab_size=runner.model_config.get_vocab_size(),
-        block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
+        block_sizes=[
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+        ],
     )
     runner.initialize_attn_backend(kv_cache_config)
 
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 576086ebeb7f..e21d6e1655b9 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -104,10 +104,11 @@ class MultiGroupBlockTable:
 
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
-                 device: torch.device, block_size: int) -> None:
+                 device: torch.device, block_sizes: list[int]) -> None:
         self.block_tables = [
             BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                        max_num_batched_tokens, pin_memory, device)
+            for block_size in block_sizes
         ]
 
     def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index b3e65917d3cc..726f9111bcae 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -55,14 +55,14 @@ def get_token_id(self, idx: int) -> int:
 class InputBatch:
 
     def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_size: int,
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        device: torch.device,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
@@ -104,7 +104,7 @@ def __init__(
             max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
-            block_size=block_size,
+            block_sizes=block_sizes,
         )
 
         # Sampling-related.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9f7c474c71cb..e10038af77d1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -142,7 +142,6 @@ def __init__(
         self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
         self.attn_backends: list[type[AttentionBackend]] = []
         # self.kv_cache_config: KVCacheConfig
-        # self.input_batch: InputBatch # Persistent batch.
 
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -172,6 +171,15 @@ def __init__(
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
 
+        # Input Batch
+        # NOTE(Chen): Ideally, we should initialize the input batch inside
+        # `initialize_kv_cache` based on the kv cache config. However, as
+        # reported in https://github.com/vllm-project/vllm/pull/18298, for
+        # unknown reasons we have to initialize the input batch before
+        # `load_model`; otherwise quantization + weight offloading will fail.
+        # As a temporary solution, we initialize the input batch here and
+        # re-initialize it in `initialize_kv_cache` if the block_sizes here
+        # differ from the block_sizes in the kv cache config.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
@@ -179,7 +187,7 @@ def __init__(
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.cache_config.block_size,
+            block_sizes=[self.cache_config.block_size],
         )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -2033,6 +2041,35 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
 
+    def may_reinitialize_input_batch(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Re-initialize the input batch if the block sizes are different from
+        `[self.cache_config.block_size]`. This usually happens when there
+        are multiple KV cache groups.
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+        """
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        if block_sizes != [self.cache_config.block_size]:
+            assert self.cache_config.cpu_offload_gb == 0, (
+                "Cannot re-initialize the input batch when CPU weight "
+                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                "for more details.")
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+            )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2040,11 +2077,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_cache_config: Configuration for the KV cache, including the
                 KV cache size of each layer
         """
-        if len(kv_cache_config.kv_cache_groups) > 1:
-            raise NotImplementedError(
-                "Hybrid models with more than one KV cache type are not "
-                "supported yet.")
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
 
         kv_caches: dict[str, torch.Tensor] = {}
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5de92351e24b..f27696213059 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -187,7 +187,7 @@ def __init__(
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=self.block_size,
+            block_sizes=[self.block_size],
        )
 
         # Cached torch/numpy tensor
@@ -1316,8 +1316,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
+            block_sizes=[
+                kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
+            ],
         )
         # Verify dtype compatibility between block_table_cpu and input_batch
         assert self.block_table_cpu.dtype == self.input_batch.block_table[
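
For reference, a minimal sketch of the new multi-group behavior (not part of the patch; the sizes below are illustrative): MultiGroupBlockTable now builds one BlockTable per entry in block_sizes, each with cdiv(max_model_len, block_size) blocks per request, and InputBatch simply forwards its block_sizes list to that constructor.

    import torch

    from vllm.v1.worker.block_table import MultiGroupBlockTable

    # Two hypothetical KV cache groups with different block sizes, e.g. a
    # full-attention group at 16 and a second group at 128 in a hybrid model.
    block_table = MultiGroupBlockTable(
        max_num_reqs=8,
        max_model_len=1024,
        max_num_batched_tokens=512,
        pin_memory=False,
        device=torch.device("cpu"),
        block_sizes=[16, 128],
    )

    # One per-group BlockTable: 1024/16 = 64 and 1024/128 = 8 blocks per row.
    assert len(block_table.block_tables) == 2

With a single KV cache group (the common case), passing block_sizes=[self.cache_config.block_size] reproduces the previous single-table behavior.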