[V1] TPU - Remove self.kv_caches #14309
Changes from all commits:
```diff
@@ -30,7 +30,6 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
```
```diff
@@ -104,9 +103,6 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size

         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
-        self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
```
```diff
@@ -582,7 +578,6 @@ def execute_model(
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
```
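With the explicit `kv_caches` argument gone from `execute_model`, layers are expected to reach their cache through the attribute that `initialize_kv_cache` binds on the forward context (see the later hunk). Below is a toy sketch of that access pattern; everything here is hypothetical except the single-element `kv_cache` list, which mirrors the binding in the diff:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Populated externally, analogous to
        # forward_context[layer_name].kv_cache = [kv_cache] in the diff.
        self.kv_cache: list[torch.Tensor] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kv_cache = self.kv_cache[0]  # single entry in v1 (no PP virtual engines)
        # ... attention against kv_cache would go here ...
        return x

layer = ToyAttention()
layer.kv_cache = [torch.zeros(8, 512, 16, 128)]  # bound after KV cache init
_ = layer(torch.randn(4, 128))
```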
```diff
@@ -680,8 +675,8 @@ def load_model(self) -> None:

     def _dummy_run(
         self,
-        kv_caches,
         num_tokens: int,
+        is_profile_run: bool,
     ) -> None:
         if self.is_multimodal_model:
             input_ids = None
```
```diff
@@ -728,15 +723,28 @@ def _dummy_run(
             torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
             torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)

-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 0,
+                                 is_profile_run=is_profile_run):
             assert self.model is not None
             self.model(
                 input_ids=input_ids,
                 positions=position_ids,
-                kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )

+    # This is used before KV cache init
+    def profile_run(self, num_tokens) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=True)
+
+    # This is used after KV cache init
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        self._dummy_run(num_tokens=num_tokens, is_profile_run=False)
+
     def capture_model(self) -> None:
         """Compile the model."""

@@ -745,7 +753,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self._dummy_run(self.kv_caches, num_tokens)
+            self.dummy_run(num_tokens)
             logger.info("  -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
```
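To make the new split concrete, here is a minimal stub of the calling order implied by the two entry points (stub class, not vLLM's runner; only the method names and the `is_profile_run` flag mirror the diff):

```python
class StubRunner:
    def __init__(self):
        self.kv_cache_initialized = False

    def _dummy_run(self, num_tokens: int, is_profile_run: bool) -> None:
        # Profiling runs happen before initialize_kv_cache(), so they must
        # not rely on KV caches existing; non-profiling runs may.
        assert is_profile_run or self.kv_cache_initialized

    # This is used before KV cache init (mirrors profile_run in the diff).
    def profile_run(self, num_tokens) -> None:
        self._dummy_run(num_tokens=num_tokens, is_profile_run=True)

    # This is used after KV cache init (mirrors dummy_run in the diff).
    def dummy_run(self, num_tokens: int) -> None:
        self._dummy_run(num_tokens=num_tokens, is_profile_run=False)

    def initialize_kv_cache(self) -> None:
        self.kv_cache_initialized = True


runner = StubRunner()
runner.profile_run(num_tokens=16)   # OK before init: profiling path
runner.initialize_kv_cache()
runner.dummy_run(num_tokens=16)     # OK after init: normal warmup path
```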
```diff
@@ -769,6 +777,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

         kv_caches: dict[str, torch.Tensor] = {}

+        kv_cache_shape_prev = None
         for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
             tensor_config = kv_cache_config.tensors[layer_name]
             assert tensor_config.size % layer_spec.page_size_bytes == 0
@@ -779,6 +788,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                                                   layer_spec.head_size)
                 dtype = layer_spec.dtype

+                # Ensure all "kv_cache_shape" are the same across the model
+                if kv_cache_shape_prev is None:
+                    kv_cache_shape_prev = kv_cache_shape
+                else:
+                    assert kv_cache_shape == kv_cache_shape_prev
+
```
**Collaborator:** qq: is this for ruling out some model architecture?
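On that question: `ModelWrapperV1` stores a single `kv_cache_shape`, so the assert effectively rules out models whose layers disagree on any KV cache dimension (for example, a layer with a different KV head count). A toy restatement of the invariant, with made-up shapes and a hypothetical helper name:

```python
def all_shapes_equal(shapes: list[tuple[int, ...]]) -> bool:
    # Equivalent to the running kv_cache_shape_prev check in the diff.
    return len(set(shapes)) <= 1

# Uniform layers pass; a layer with fewer KV heads (4 vs 8) would not.
assert all_shapes_equal([(8, 512, 16, 128), (8, 512, 16, 128)])
assert not all_shapes_equal([(8, 512, 16, 128), (4, 512, 16, 128)])
```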
```diff
                 tpu_k_cache = torch.zeros(kv_cache_shape,
                                           dtype=dtype,
                                           device=self.device)
@@ -788,23 +803,32 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             else:
                 raise NotImplementedError
```
```diff
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        # ModelWrapperV1 needs to know the KV cache shape
+        self.model.set_kv_cache_shape(kv_cache_shape_prev)
+
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]
```
**Collaborator** (on lines +811 to +815): nit: do you see any use in having this bit factored as a util, similarly to bind_kv_cache? We could re-use at least in tpu_worker
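If it were factored out as suggested, the helper might look roughly like this (hypothetical name and signature, not part of this PR; the loop body mirrors the diff above):

```python
from typing import Any

import torch

def bind_kv_caches_to_forward_context(
        kv_caches: dict[str, torch.Tensor],
        static_forward_context: dict[str, Any]) -> None:
    """Attach each initialized KV cache to its attention layer entry."""
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine (as in the diff).
        static_forward_context[layer_name].kv_cache = [kv_cache]
```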
```diff
 class ModelWrapperV1(nn.Module):
```
**Collaborator:** Is it possible to implement ModelWrapperV1 like this?

**Collaborator (Author):** @heheda12345 this is not possible because num_blocks is not known until determine_num_available_blocks is done and initialize_kv_cache is executed.

**Collaborator:** Can we pass a fake value first and update it after?
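A rough sketch of the "pass a fake value first" alternative floated above (hypothetical stub and example numbers; not what this PR implements):

```python
import torch.nn as nn

class StubWrapper(nn.Module):
    """Stand-in for ModelWrapperV1, just to illustrate the idea."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Fake shape at construction time: num_blocks (dim 1) is unknown
        # before determine_num_available_blocks, so use a dummy of 1.
        self.kv_cache_shape = (8, 1, 16, 128)

wrapper = StubWrapper(nn.Identity())
# ... after determine_num_available_blocks() yields the real block count ...
wrapper.kv_cache_shape = (8, 512, 16, 128)  # example numbers only
```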
```diff
     def __init__(self, model: nn.Module):
         super().__init__()
         self.model = model
+        self.kv_cache_shape = None

+    def set_kv_cache_shape(self, kv_cache_shape):
```
**Collaborator:** nit: we can probably get away without setters as long as we keep the class and the logic lean
```diff
+        self.kv_cache_shape = kv_cache_shape

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
```
```diff
@@ -817,16 +841,20 @@ def forward(
             inputs_embeds: The input embeddings of shape [num_tokens,
                 hidden_size]. It is used for multimodal models.
         """
-        # Skip this in memory profiling at initialization.
-        if kv_caches[0][0].numel() > 0:
-            attn_metadata = get_forward_context().attn_metadata
-            # index_copy_(slot_mapping) only works when the inserted dimension
-            # is 0. However, the KV cache in the Pallas backend has the shape
-            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
-            # work, we need to flatten the first three dimensions and modify
-            # the slot_mapping accordingly.
-            # kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
-            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
+        forward_context = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        # index_copy_(slot_mapping) only works when the inserted dimension
+        # is 0. However, the KV cache in the Pallas backend has the shape
+        # [num_kv_heads, num_blocks, block_size, head_size]. To make it
+        # work, we need to flatten the first three dimensions and modify
+        # the slot_mapping accordingly.
+        #
+        # Note: We skip this step during first profiling run (before KV init)
+        if not forward_context.is_profile_run:
+            assert self.kv_cache_shape  # Ensure initialized
+            num_kv_heads, num_blocks, block_size, _ = self.kv_cache_shape

             slot_mapping = attn_metadata.slot_mapping
             slot_mapping = slot_mapping.flatten()
             head_indicies = torch.arange(0,
```
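The comment block above carries the key indexing argument; here is a small self-contained worked example of that flattening (toy sizes, standalone PyTorch, not the PR's code):

```python
import torch

num_kv_heads, num_blocks, block_size, head_size = 2, 4, 8, 16
kv_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

# One token to write, living in block 3 at offset 5 within that block.
slot_mapping = torch.tensor([3 * block_size + 5])

# index_copy_ inserts along dim 0, so flatten the first three dims and
# offset each head's slots into the flattened index space.
flat_cache = kv_cache.view(num_kv_heads * num_blocks * block_size, head_size)
head_offsets = torch.arange(num_kv_heads) * num_blocks * block_size
flat_indices = (head_offsets[:, None] + slot_mapping[None, :]).flatten()

new_keys = torch.ones(num_kv_heads, 1, head_size).reshape(-1, head_size)
flat_cache.index_copy_(0, flat_indices, new_keys)

# The write landed at [head, block 3, offset 5] for every head.
assert torch.equal(kv_cache[:, 3, 5], torch.ones(num_kv_heads, head_size))
```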
**Comment:** Please revert this - we should make an examples/offline_inference/tpu/ folder to keep this