Conversation

@NickLucche (Collaborator) commented Oct 10, 2025

This PR modifies the AttentionLayerBase interface to add a new get_kv_cache_spec method.
This allows each attention layer to define its own KV cache spec, keeping the spec's construction entirely transparent to the model runner.

As a consequence, the runner can now limit itself to collecting the specs, without having to handle different attention types or model-specific hacks such as the one for the DSv32 Indexer.
It also simplifies the code considerably, as all ENCODER, ENCODER_ONLY, and ENCODER_DECODER type handling is moved to method dispatch.
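
To illustrate the shape of the change, here is a minimal sketch; apart from the get_kv_cache_spec method name, the class stubs and the collection helper are simplified assumptions, not the exact vLLM code:

```python
from abc import ABC, abstractmethod
from typing import Optional


# Stand-ins for vllm.config.VllmConfig and vllm.v1.kv_cache_interface.KVCacheSpec.
class VllmConfig: ...
class KVCacheSpec: ...


class AttentionLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
        """Each layer describes its own KV cache needs, or returns None if it has none."""


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase], vllm_config: VllmConfig
) -> dict[str, KVCacheSpec]:
    # The model runner no longer branches on attention type or model; it just collects.
    return {
        name: spec
        for name, layer in layers.items()
        if (spec := layer.get_kv_cache_spec(vllm_config)) is not None
    }
```

This keeps all attention-type branching inside the layers themselves, which is what lets the runner drop the ENCODER/ENCODER_ONLY/ENCODER_DECODER special cases mentioned above.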

cc @heheda12345 who clearly defined the task

PS: this used to be a TODO in the code from @LucasWilkinson: https://github.com/vllm-project/vllm/blob/releases/v0.11.0/vllm/v1/worker/gpu_model_runner.py#L4065

@NickLucche marked this pull request as ready for review on October 10, 2025 13:23
@mergify bot added the deepseek and v1 labels on Oct 10, 2025
@chatgpt-codex-connector (bot) left a comment:

💡 Codex Review: Here are some automated review suggestions for this pull request.

**kwargs,
)

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
@heheda12345 (Collaborator) commented:
Note that we also have EncoderOnlyAttentionSpec. We skip it in get_kv_cache_spec because these layers don't need KV cache, but add it back in may_add_encoder_only_layers_to_kv_cache_config as we need to build attention metadata for these layers. Can you try to find a better abstraction? (Leaving it as it is now may also be fine if there's no better idea.)

)

ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
A collaborator commented:
do you need to implement get_kv_cache_spec for DeepseekV32IndexerCache?

@NickLucche (Author) replied:
it already had one, just changed its signature
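
For reference, a sketch of the signature change being discussed; the two forms appear verbatim in the diff context further down this page, while the import paths and the class name used here are assumptions:

```python
from typing import Optional

from vllm.config import VllmConfig  # assumed import path
from vllm.v1.kv_cache_interface import KVCacheSpec  # assumed import path


class IndexerCacheSketch:
    # Old signature (before this PR): no arguments, always returns a spec.
    #     def get_kv_cache_spec(self) -> KVCacheSpec: ...

    # New signature (this PR): the vLLM config is passed in, and layers that
    # need no KV cache may return None.
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
        ...
```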

@NickLucche (Author) commented:

@heheda12345 bridging the discussion here:

Note that we also have EncoderOnlyAttentionSpec. We skip it in get_kv_cache_spec because these layers don't need KV cache, but add it back in may_add_encoder_only_layers_to_kv_cache_config as we need to build attention metadata for these layers. Can you try to find a better abstraction? (Leaving it as it is now may also be fine if there's no better idea.)

So I think the fact that we need different specs on the worker and scheduler sides is a bit of a nuisance here, as I wouldn't want a very generic interface such as the Attention one to be aware of that. The same goes for having interface methods that are only called from the worker.
Hence, as long as specific branching is needed for encoder-only layers, I think the current may_add_encoder_only_layers_to_kv_cache_config setup is clearer.

Taking a step back, can we avoid having different worker<>scheduler specs for encoder-only in the first place?
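
To make the current split concrete, here is a rough sketch of the encoder-only case under the new interface; this is an illustrative assumption, not the actual vLLM implementation, and the import paths are assumed:

```python
from typing import Optional

from vllm.config import VllmConfig  # assumed import path
from vllm.v1.kv_cache_interface import KVCacheSpec  # assumed import path


class EncoderOnlyAttentionSketch:  # hypothetical stand-in for the real EncoderOnlyAttention
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
        # Encoder-only layers allocate no KV cache, so the worker-side collection
        # skips them; may_add_encoder_only_layers_to_kv_cache_config then re-adds
        # them on the scheduler side purely so attention metadata can be built.
        return None
```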

kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config.dtype
)
return MLAAttentionSpec(
@LucasWilkinson (Collaborator) commented:
I don't think this is the right spec; the MLA spec should be used for:

class MLAAttention(nn.Module, AttentionLayerBase):

I think this layer (MultiHeadAttention) is used by multimodal models, but tbh I'm not exactly sure where it is used.

@NickLucche (Author) replied:
@LucasWilkinson not sure what's wrong with the GitHub preview; this change is actually to the
class MLAAttention(nn.Module, AttentionLayerBase):

MHA is a simple nn.Module; it doesn't even implement the attn interface.

@NickLucche (Author) added:
Like, if you go to line 815 you can see it belongs to the MLA class.

@LucasWilkinson (Collaborator) replied:
Oh sorry, yeah, you're correct, my bad! That's really weird that the GitHub preview didn't show that 🤔
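
To close out the resolved thread: MLAAttention implements AttentionLayerBase and therefore carries the new get_kv_cache_spec, while MultiHeadAttention is a plain nn.Module that never takes part in spec collection. A toy sketch of that distinction follows; the class bodies are placeholders, not vLLM code:

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch.nn as nn


class AttentionLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, vllm_config) -> Optional["KVCacheSpec"]: ...


class MLAAttention(nn.Module, AttentionLayerBase):
    # Participates in spec collection: overrides get_kv_cache_spec and returns
    # an MLAAttentionSpec built from the vLLM config (see the diff context above).
    def get_kv_cache_spec(self, vllm_config) -> Optional["KVCacheSpec"]:
        ...


class MultiHeadAttention(nn.Module):
    # Plain nn.Module: it does not inherit AttentionLayerBase, so the model
    # runner never asks it for a KV cache spec.
    pass
```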

from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
A contributor commented:
I noticed that here the class is loaded directly from vLLM's models, which is not registered by plugins. Would it be possible to fetch it from the Model Registry instead?

@NickLucche (Author) replied:
Nice catch. I think it's outside the scope of this PR, but we do have to change that too.

@NickLucche (Author) commented:

@heheda12345 @LucasWilkinson gentle ping on this one

@LucasWilkinson (Collaborator) replied:

@heheda12345 @LucasWilkinson gentle ping on this one

Can you please address: https://github.com/vllm-project/vllm/pull/26587/files#r2421239326

Otherwise LGTM

compilation_config.static_forward_context[prefix] = self

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> Optional[KVCacheSpec]:
@NickLucche (Author) commented:
@heheda12345 this is the DeepSeek change.


@NickLucche (Author) commented:

@LucasWilkinson uh, somehow I didn't send the response to your comment last week... Basically, I am just saying MHA wasn't edited; I am not sure why GitHub shows it like that, but only MLAAttention got changed.

@mergify bot commented Oct 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Oct 15, 2025
@NickLucche force-pushed the get-kvcache-spec-refactor branch from e9cbbb9 to d15cc44 on October 16, 2025 13:20
@mergify bot removed the needs-rebase label Oct 16, 2025
@LucasWilkinson (Collaborator) left a comment:
LGTM

@LucasWilkinson enabled auto-merge (squash) October 16, 2025 13:59
@github-actions bot added the ready label Oct 16, 2025
@NickLucche force-pushed the get-kvcache-spec-refactor branch from 3045ede to 973da05 on October 17, 2025 16:28
@NickLucche force-pushed the get-kvcache-spec-refactor branch from 395442e to b615111 on October 18, 2025 11:52
@LucasWilkinson merged commit b26b70b into vllm-project:main on Oct 18, 2025 (58 checks passed)
lywa1998, adabeyta, albertoperdomo2, xuebwang-amd, and 0xrushi later pushed commits referencing this pull request to their forks (Oct 20-26, 2025).
@hmellor (Member) commented Nov 4, 2025

I've just found that this PR breaks the behaviour of passing attn_type to Attention to select the attention type.

If you instantiate Attention(..., attn_type=AttentionType.ENCODER_ONLY), the new get_kv_cache_spec of Attention will be called, which asserts that attn_type == AttentionType.DECODER.

Is this expected?

If yes:

  • Does this mean that we should explicitly use the Attention and EncoderOnlyAttention classes?
  • Should attn_type be removed from __init__, since it can no longer effectively change the attention type used?

@heheda12345 (Collaborator) replied:

Nice catch! I think we can always use EncoderOnlyAttention and start the deprecation of attn_type. As the Attention class is used by so many people, I'd prefer to add a deprecation warning first and remove it in a future release.
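
A minimal sketch of the breakage and the suggested direction; the EncoderOnlyAttention import path and the constructor arguments are assumptions, and this is an untested illustration rather than vLLM's actual migration:

```python
from vllm.attention import Attention, AttentionType
from vllm.attention.layer import EncoderOnlyAttention  # assumed location

# What broke: selecting the type via attn_type no longer affects spec collection,
# because Attention.get_kv_cache_spec now asserts attn_type == AttentionType.DECODER.
attn = Attention(num_heads=8, head_size=64, scale=1.0,
                 attn_type=AttentionType.ENCODER_ONLY)  # trips the assertion later

# Suggested direction from this thread: instantiate the dedicated class explicitly
# and deprecate attn_type on Attention.__init__ before removing it.
attn = EncoderOnlyAttention(num_heads=8, head_size=64, scale=1.0)
```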

Labels: deepseek, ready, speculative-decoding, v1

5 participants