[V1] TPU - Remove self.kv_caches #14309
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Please revert this - we should make an examples/offline_inference/tpu/ folder to keep this.
If we add is_profile_run to forward_context, we need to change other backends to pass the is_profile_run arg. Can we achieve it by passing the attributes to ModelWrapperV1, as in the pseudocode below?
forward_context[layer_name].kv_cache = [kv_cache]

class ModelWrapperV1(nn.Module):
Is it possible to implement ModelWrapperV1 like this?
class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module, num_kv_heads, num_blocks, block_size):
        super().__init__()
        self.model = model
        self.num_kv_heads = num_kv_heads
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        is_profile_run: bool = False,
    ) -> torch.Tensor:
        if not is_profile_run:
            num_kv_heads = self.num_kv_heads
            ...


class TPUModelRunner:

    def _dummy_run(
        self,
        num_tokens: int,
        is_profile_run: bool,
    ) -> None:
        self.model.forward(..., is_profile_run=is_profile_run)
@heheda12345 this is not possible because num_blocks is not known until determine_num_available_blocks is done and initialize_kv_cache is executed.
Can we pass a fake value first and update it after determine_num_available_blocks?
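For illustration, a minimal sketch of that "fake value first, update later" idea; PLACEHOLDER_NUM_BLOCKS and set_num_blocks are hypothetical names, not existing vLLM APIs:

import torch.nn as nn

# Hypothetical sketch: construct the wrapper with a placeholder block count
# and overwrite it once the real KV cache size is known.
PLACEHOLDER_NUM_BLOCKS = 1

class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module, num_kv_heads: int, block_size: int):
        super().__init__()
        self.model = model
        self.num_kv_heads = num_kv_heads
        self.block_size = block_size
        # Real value is unknown until determine_num_available_blocks runs.
        self.num_blocks = PLACEHOLDER_NUM_BLOCKS

    def set_num_blocks(self, num_blocks: int) -> None:
        # Called after initialize_kv_cache has sized the KV cache.
        self.num_blocks = num_blocks

The profile run would then see only the placeholder, and the runner would update num_blocks right after KV cache initialization.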
@heheda12345 the is_profile_run arg is set to False by default, so it should not be necessary to pass this parameter explicitly to the set_forward_context(..) function in other backends. Is there any specific code example where you would need to specify it explicitly?
If we add is_profile_run to forward_context, we need to change other backends to pass the is_profile_run arg.
@heheda12345 I see, makes sense. There is a new PR from Google that removes the problematic non-profile reshuffling (#14310), which will affect this PR. I would wait for that PR to land first and then rebase.
Looks good to me, thanks! Only had minor comments on my side. Let's wait for the rebase to deal with is_profile_run.
self._dummy_run(num_tokens=num_tokens, is_profile_run=True)

# This is used after KV cache init
def dummy_run(
_dummy_run and dummy_run look confusing; do we really need this overload?
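One possible shape (sketch only; the merged signature is hypothetical) would be a single entry point that takes the flag:

# Sketch only: one dummy-run entry point with a flag instead of two
# similarly named methods; the merged signature is hypothetical.
class TPUModelRunner:

    def dummy_run(self, num_tokens: int, is_profile_run: bool = False) -> None:
        # Profile runs (before KV cache init) and post-init warm-up runs
        # would both go through here, distinguished by the flag.
        ...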
self.model = model
self.kv_cache_shape = None

def set_kv_cache_shape(self, kv_cache_shape):
nit: we can probably get away without setters as long as we keep the class and the logic lean
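For instance, a sketch of the setter-free variant (assuming the wrapper keeps kv_cache_shape as a plain attribute; names follow the diff above):

import torch.nn as nn

# Sketch of the setter-free variant: kv_cache_shape is a plain attribute
# that callers assign directly once the KV cache has been created,
# instead of going through set_kv_cache_shape().
class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.kv_cache_shape = None  # filled in after KV cache init

# e.g. in the runner, after initialize_kv_cache:
#     self.model_wrapper.kv_cache_shape = kv_cache_shape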
if kv_cache_shape_prev is None:
    kv_cache_shape_prev = kv_cache_shape
else:
    assert kv_cache_shape == kv_cache_shape_prev
qq: is this for ruling out some model architecture?
forward_context = self.vllm_config.compilation_config \
    .static_forward_context
for layer_name, kv_cache in kv_caches.items():
    # NOTE: Use list because of v0 PP virtual engine.
    forward_context[layer_name].kv_cache = [kv_cache]
nit: do you see any use in having this bit factored out as a util, similar to bind_kv_cache? We could re-use it at least in tpu_worker.
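A rough sketch of such a helper (the function name is hypothetical; the body mirrors the snippet above):

from typing import Dict

import torch

# Hypothetical utility; the name is illustrative, not an existing vLLM API.
def bind_kv_caches_to_forward_context(kv_caches: Dict[str, torch.Tensor],
                                      vllm_config) -> None:
    """Attach each layer's KV cache to the static forward context."""
    forward_context = vllm_config.compilation_config.static_forward_context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use a list because of the v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]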
This pull request has merge conflicts that must be resolved before it can be merged.
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
This PR removes self.kv_caches from tpu_model_runner.py in V1, so that @heheda12345's #14098 can land cleanly.
@mgoin @NickLucche feel free to make a pass.