Conversation

@WyldeCat WyldeCat commented Oct 23, 2025

Purpose

Changes

New model support for:
https://huggingface.co/Motif-Technologies/Motif-2-12.7B-Base
https://huggingface.co/Motif-Technologies/Motif-2.6b-v1.1-LC
https://huggingface.co/Motif-Technologies/Motif-2.6B

Co-author: @ca1207
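
For reference, a minimal example of running one of the checkpoints listed above with vLLM once this support is available (the prompt and sampling settings are illustrative, and trust_remote_code may still be needed depending on the checkpoint's config):

from vllm import LLM, SamplingParams

# Load one of the Motif checkpoints added by this PR.
llm = LLM(model="Motif-Technologies/Motif-2.6B")
outputs = llm.generate(["Hello, Motif!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)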

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

This reverts commit 3125d79.
Signed-off-by: WyldeCat <skan1543@gmail.com>

mergify bot commented Oct 23, 2025

Documentation preview: https://vllm--27396.org.readthedocs.build/en/27396/

@mergify mergify bot added the documentation, new-model, performance, and v1 labels Oct 23, 2025
@WyldeCat WyldeCat changed the title from Re-support Motif Model to Re-support MotifForCausalLM Oct 23, 2025

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request re-introduces support for the Motif model, which involves adding a new PolyNorm layer with its corresponding CUDA kernel, and a GroupedDifferentialAttention backend. The implementations for PolyNorm and the new model architecture appear to be correct and follow best practices. However, I've identified a critical issue in the caching logic within the new GroupedDifferentialAttentionBackend that could lead to incorrect behavior and needs to be addressed.
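
For context, a minimal PyTorch reference of what a PolyNorm layer computes, assuming the weighted-sum-of-RMS-normalized-powers form; the function name and epsilon below are illustrative, and the CUDA kernel in this PR is the optimized path:

import torch

def poly_norm_ref(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
                  eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize a tensor along its last (hidden) dimension.
    def _rms(t: torch.Tensor) -> torch.Tensor:
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    # Weighted sum of normalized powers of x, plus a learned bias:
    # w0 * rms(x^3) + w1 * rms(x^2) + w2 * rms(x) + b
    return weight[0] * _rms(x**3) + weight[1] * _rms(x**2) + weight[2] * _rms(x) + bias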

Comment on lines +680 to +701
if (
    self.kv_sharing_target_layer_name is None
    and key is not None
    and value is not None
):
    # Reshape the input keys and values and store them in the cache.
    # Skip this if sharing KV cache with an earlier attention layer.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens]
    # and value[:num_actual_tokens] because the reshape_and_cache_flash
    # op uses the slot_mapping's shape to determine the number of
    # actual tokens.
    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

critical

The forward_single_attention method includes a block for caching key and value tensors. This is problematic because the main forward method already performs the necessary caching for the Grouped Differential Attention (GDA) splits (k1, v1 and k2, v2) using the populate_kv_cache method before forward_single_attention is ever called.

This leads to two significant issues:

  1. Redundant Caching: The same key-value pairs are cached multiple times, which is inefficient.
  2. Incorrect Caching: For cross-split attention computations (e.g., Attn(q1, K1, V2)), forward_single_attention is invoked with mismatched key and value tensors (like k1 and v2). The reshape_and_cache_flash call within this method is not designed to handle such cases and will likely corrupt the cache state.

The caching logic should be centralized within the forward method. The forward_single_attention method should then only be responsible for the attention computation, using the already populated cache, without performing any caching itself.

To resolve this, the caching block within forward_single_attention should be removed. The key and value arguments are still necessary for other parts of the function (e.g., descale_shape calculation), so they should remain in the function signature.
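
A rough sketch of the suggested structure (method and helper names are assumed from the discussion above, not copied from the diff): forward() performs all cache writes once per split, and forward_single_attention() only computes attention against the already populated cache.

class GroupedDifferentialAttentionSketch:
    def forward(self, layer, q1, q2, k1, v1, k2, v2, kv_cache, attn_metadata):
        # Cache each GDA split exactly once, before any attention math.
        self.populate_kv_cache(layer, k1, v1, kv_cache, attn_metadata)
        self.populate_kv_cache(layer, k2, v2, kv_cache, attn_metadata)

        # Cross-split computations such as Attn(q1, K1, V2) may pass
        # mismatched key/value tensors; that is safe because the helper
        # no longer writes them to the cache.
        out_11 = self.forward_single_attention(q1, k1, v1, kv_cache, attn_metadata, layer)
        out_12 = self.forward_single_attention(q1, k1, v2, kv_cache, attn_metadata, layer)
        ...

    def forward_single_attention(self, q, key, value, kv_cache, attn_metadata, layer):
        # No reshape_and_cache_flash call here: key/value stay in the
        # signature only for shape-dependent logic such as descale_shape.
        ...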

@WyldeCat WyldeCat changed the title from Re-support MotifForCausalLM to [Model] Re-support MotifForCausalLM Oct 24, 2025
@WyldeCat

Hi @jeejeelee, I hope you're doing well!
I just wanted to check whether there's anything I should address in this PR.
Since you reviewed our previous PR, I thought I’d ask if everything looks okay or if there’s anything you’d like me to update.
Thanks a lot for your time!

@@ -0,0 +1,744 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator


@DarkLight1337 Who can review this attention backend?

