[Misc] Refactor get_kv_cache_spec into AttentionLayerBase
#26587
Conversation
    **kwargs,
)

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
Note that we also have EncoderOnlyAttentionSpec. We skip it in get_kv_cache_spec because these layers don't need KV cache, but add it back in may_add_encoder_only_layers_to_kv_cache_config since we still need to build attention metadata for these layers. Can you try to make a better abstraction? (Leaving it as it is now is also fine if there is no better idea.)
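For readers outside the thread, here is a minimal sketch of the two-step pattern described above. The class bodies and the helper implementation are illustrative stand-ins, not the actual vLLM code; only the names follow the comment.

```python
from typing import Optional


class KVCacheSpec:  # stand-in for vLLM's KVCacheSpec
    pass


class EncoderOnlyAttentionSpec(KVCacheSpec):  # stand-in spec for encoder-only layers
    pass


class EncoderOnlyAttention:
    def get_kv_cache_spec(self, vllm_config) -> Optional[KVCacheSpec]:
        # Encoder-only layers need no KV cache, so they report no spec here...
        return None


def may_add_encoder_only_layers_to_kv_cache_config(layers, kv_cache_specs):
    # ...but they are added back later, because attention metadata still has
    # to be built for them.
    for name, layer in layers.items():
        if isinstance(layer, EncoderOnlyAttention):
            kv_cache_specs[name] = EncoderOnlyAttentionSpec()
    return kv_cache_specs
```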
)

ds_indexer_layers = get_layers_from_vllm_config(
    self.vllm_config, DeepseekV32IndexerCache
do you need to implement get_kv_cache_spec for DeepseekV32IndexerCache?
it already had one, just changed its signature
@heheda12345 bridging discussion here
So I think the fact that we need different specs for the worker/scheduler side is a bit of a nuisance here, as I wouldn't want a very generic interface such as the Attention one to be aware of that. Same thing for having interface methods only called from the worker. Taking a step back, can we avoid having different worker<>scheduler specs for encoder-only in the first place?
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
    self.kv_cache_dtype, vllm_config.model_config.dtype
)
return MLAAttentionSpec(
I don't think this is the right spec; the MLA spec should be used for:
Line 567 in 96ad65b:
class MLAAttention(nn.Module, AttentionLayerBase):
I think this layer (MultiHeadAttention) is used by multimodal models, but tbh I'm not exactly sure where it is used.
@LucasWilkinson not sure what's wrong with github preview, this change is actually to the
class MLAAttention(nn.Module, AttentionLayerBase):
MHA is a simple nn.Module, it doesn't even implement the attn interface
like if you go to line 815 you can see it belongs to the mla class
oh sorry ya you are correct, my bad! that's really weird that the GitHub preview didn't show that 🤔
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
I noticed that here the class is imported directly from vLLM's models, so it is not registered by a plugin. Would it be possible to fetch it from the Model Registry?
nice catch, I think it's outside the scope of the PR but we have to change that too
@heheda12345 @LucasWilkinson gentle ping on this one
Can you please address: https://github.com/vllm-project/vllm/pull/26587/files#r2421239326 Otherwise LGTM
compilation_config.static_forward_context[prefix] = self

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
@heheda12345 deepseek change
@LucasWilkinson uh, somehow I didn't send the response to your comment last week... basically I am just saying MHA wasn't edited; I am not sure why GitHub shows it like that, but only MLAAttention got changed.

This pull request has merge conflicts that must be resolved before it can be merged.
LGTM
I've just found that this PR breaks the behaviour of passing attn_type. If you instantiate Attention with an encoder-only attn_type, the layer no longer gets the encoder-only handling. Is this expected? If yes:
Nice catch! I think we can always use EncoderOnlyAttention and start the deprecation of attn_type. As the Attention class is used by too many people, I prefer to add a deprecation warning first and remove it in a future release.
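A rough sketch of what such a deprecation warning could look like; the constructor signature and the default value here are assumptions for illustration, not the real vLLM Attention API:

```python
import warnings


class Attention:
    def __init__(self, *args, attn_type: str = "decoder", **kwargs):
        if attn_type != "decoder":
            # Hypothetical deprecation path: steer callers towards the
            # dedicated subclasses (e.g. EncoderOnlyAttention) instead of
            # passing attn_type to the generic Attention layer.
            warnings.warn(
                "Passing attn_type to Attention is deprecated; use the "
                "dedicated attention subclass instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.attn_type = attn_type
```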
This PR modifies the AttentionLayerBase interface to add a new get_kv_cache_spec method. This allows different attention layers to define their own KV cache spec, making the spec entirely transparent to the Model Runner.
As a consequence, the runner can now limit itself to collecting the specs, without having to handle different attention types and/or model-specific hacks such as the one for the DSv32 Indexer.
It also makes the code much simpler, as all ENCODER, ENCODER_ONLY and ENCODER_DECODER type management is moved to a method-dispatch system.
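For illustration, a simplified sketch of the shape of the new interface, based on the signatures visible in the diffs above. The collect_kv_cache_specs helper is a hypothetical example of how the runner side can consume it, not the exact vLLM code:

```python
from abc import ABC, abstractmethod
from typing import Optional


class AttentionLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, vllm_config: "VllmConfig") -> Optional["KVCacheSpec"]:
        """Each layer describes its own KV cache needs; None means no KV cache."""
        ...


def collect_kv_cache_specs(layers: dict, vllm_config: "VllmConfig") -> dict:
    # The runner just gathers whatever each layer reports, with no
    # special-casing of encoder/decoder/MLA/indexer layer types.
    return {
        name: spec
        for name, layer in layers.items()
        if (spec := layer.get_kv_cache_spec(vllm_config)) is not None
    }
```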
cc @heheda12345 who clearly defined the task
PS: this used to be a TODO in code from @LucasWilkinson: https://github.com/vllm-project/vllm/blob/releases/v0.11.0/vllm/v1/worker/gpu_model_runner.py#L4065