Conversation

@ca1207
Contributor

@ca1207 ca1207 commented Aug 22, 2025

Purpose

New model for:
https://huggingface.co/Motif-Technologies/Motif-2.6B
https://huggingface.co/Motif-Technologies/Motif-2.6b-v1.1-LC
Tech report (https://arxiv.org/pdf/2508.09148)

Implemented the PolyNorm custom kernel based on https://arxiv.org/abs/2411.03884.
Co-author: @WyldeCat
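
For reviewers unfamiliar with PolyNorm, here is a minimal unfused PyTorch sketch of the layer, assuming the RMS-normalized-powers form described in the paper (the parameter initialization is illustrative, not necessarily what the Motif checkpoint uses):

```python
import torch
from torch import nn


class PolyNorm(nn.Module):
    """Unfused reference PolyNorm: a learned mix of RMS-normalized
    powers of the input (https://arxiv.org/abs/2411.03884)."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3) / 3)  # illustrative init
        self.bias = nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # RMS normalization over the hidden dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The custom kernel fuses these four terms into a single pass;
        # the reference computes them term by term.
        return (self.weight[0] * self._norm(x ** 3)
                + self.weight[1] * self._norm(x ** 2)
                + self.weight[2] * self._norm(x)
                + self.bias)
```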

Test Plan

Benchmark for the PolyNorm kernel (benchmarks/kernels/benchmark_polynorm.py)
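
Not the actual script, but a hedged sketch of how such a microbenchmark is typically structured with CUDA events; the function names and the naive/fused entry points are placeholders:

```python
import torch


def bench_us(fn, x: torch.Tensor, warmup: int = 10, iters: int = 100) -> float:
    """Return the mean latency of fn(x) in microseconds."""
    for _ in range(warmup):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000 / iters  # elapsed_time is in ms


# Illustrative usage for one (batch_size, seq_len, dim) point from the
# tables below; poly_norm_naive / poly_norm_fused are placeholders.
# x = torch.randn(4, 512, 2048, device="cuda", dtype=torch.float16)
# naive_us = bench_us(poly_norm_naive, x)
# fused_us = bench_us(poly_norm_fused, x)
# print(f"speedup: {naive_us / fused_us:.2f}x")
```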

Test Result

Benchmark Results

Results for dim = 2048 (the Naive and vLLM columns are kernel times; Speedup = Naive / vLLM)

| dim | batch_size | seq_len | Naive | vLLM | Speedup |
| --- | --- | --- | --- | --- | --- |
| 2048 | 1 | 64 | 197.76 | 8.61 | 22.97 |
| 2048 | 1 | 128 | 197.31 | 9.18 | 21.48 |
| 2048 | 1 | 256 | 194.34 | 11.36 | 17.11 |
| 2048 | 1 | 512 | 274.91 | 13.66 | 20.12 |
| 2048 | 1 | 1024 | 552.77 | 20.51 | 26.95 |
| 2048 | 4 | 64 | 194.46 | 11.62 | 16.74 |
| 2048 | 4 | 128 | 275.07 | 13.63 | 20.18 |
| 2048 | 4 | 256 | 551.94 | 20.54 | 26.87 |
| 2048 | 4 | 512 | 1044.08 | 34.53 | 30.24 |
| 2048 | 4 | 1024 | 1971.10 | 61.57 | 32.02 |
| 2048 | 16 | 64 | 552.70 | 20.58 | 26.86 |
| 2048 | 16 | 128 | 1045.07 | 34.56 | 30.24 |
| 2048 | 16 | 256 | 1971.33 | 61.15 | 32.24 |
| 2048 | 16 | 512 | 3811.10 | 129.89 | 29.34 |
| 2048 | 16 | 1024 | 7492.06 | 237.12 | 31.60 |
| 2048 | 64 | 64 | 1971.89 | 61.54 | 32.04 |
| 2048 | 64 | 128 | 3811.01 | 129.63 | 29.40 |
| 2048 | 64 | 256 | 7491.81 | 236.13 | 31.73 |
| 2048 | 64 | 512 | 14854.58 | 479.70 | 30.97 |
| 2048 | 64 | 1024 | 29569.15 | 975.44 | 30.31 |
Results for dim = 4096

| dim | batch_size | seq_len | Naive | vLLM | Speedup |
| --- | --- | --- | --- | --- | --- |
| 4096 | 1 | 64 | 196.59 | 9.86 | 19.95 |
| 4096 | 1 | 128 | 196.13 | 10.78 | 18.19 |
| 4096 | 1 | 256 | 269.94 | 14.40 | 18.75 |
| 4096 | 1 | 512 | 545.79 | 21.22 | 25.73 |
| 4096 | 1 | 1024 | 1034.14 | 38.98 | 26.53 |
| 4096 | 4 | 64 | 269.34 | 14.50 | 18.58 |
| 4096 | 4 | 128 | 545.15 | 21.09 | 25.85 |
| 4096 | 4 | 256 | 1033.57 | 39.01 | 26.50 |
| 4096 | 4 | 512 | 1963.90 | 71.52 | 27.46 |
| 4096 | 4 | 1024 | 3804.32 | 135.81 | 28.01 |
| 4096 | 16 | 64 | 1034.14 | 39.09 | 26.46 |
| 4096 | 16 | 128 | 1964.13 | 71.49 | 27.47 |
| 4096 | 16 | 256 | 3804.16 | 135.81 | 28.01 |
| 4096 | 16 | 512 | 7476.80 | 265.06 | 28.21 |
| 4096 | 16 | 1024 | 14831.06 | 523.97 | 28.31 |
| 4096 | 64 | 64 | 3803.46 | 135.74 | 28.02 |
| 4096 | 64 | 128 | 7477.31 | 265.09 | 28.21 |
| 4096 | 64 | 256 | 14832.98 | 524.03 | 28.31 |
| 4096 | 64 | 512 | 29530.88 | 1041.42 | 28.36 |
| 4096 | 64 | 1024 | 58956.83 | 2080.24 | 28.34 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@mergify mergify bot added documentation Improvements or additions to documentation new-model Requests to new models performance Performance-related issues labels Aug 22, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the Motif model, including a new PolyNorm custom kernel and a DifferentialFlashAttention backend. The changes are extensive and well-structured. However, I've identified a critical issue in the MotifDecoderLayer implementation where the forward pass logic deviates from the reference model's pre-normalization architecture, which will likely result in incorrect outputs. Additionally, there is a minor typo in the new benchmark script for PolyNorm.
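
For context on the pre-normalization point: in a pre-norm decoder layer, each sublayer sees a normalized input and the residual is added afterwards. A generic sketch of that pattern (illustrative only, not the Motif code):

```python
import torch
from torch import nn


class PreNormDecoderLayer(nn.Module):
    """Generic pre-norm transformer decoder layer (illustrative only)."""

    def __init__(self, hidden_size: int, self_attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.self_attn = self_attn
        self.mlp = mlp
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm: normalize before the sublayer, add the residual after.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = residual + self.self_attn(hidden_states)

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states
```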

@ca1207 ca1207 marked this pull request as draft August 22, 2025 06:50
WyldeCat and others added 4 commits August 22, 2025 07:11
Signed-off-by: WyldeCat <skan1543@gmail.com>

Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Sync with https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py#L366

Signed-off-by: WyldeCat <skan1543@gmail.com>

Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@ca1207 ca1207 marked this pull request as ready for review August 22, 2025 07:19
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
"MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B",
trust_remote_code=True,
v0_only=True),
Collaborator


QQ: Why does this model only support V0?

Contributor Author


Our model uses the differential attention introduced in this PR.
As you can see here, the current differential attention backend does not support chunked prefill, so the model is only available on V0.

Collaborator


@DarkLight1337 If this model only supports V0, how should we handle this now?

Member


You can have the model inherit from SupportsV0Only explicitly

Member


And state the reason why via code comment for future reference
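
Putting both suggestions together, a minimal sketch (the import path is assumed from vLLM's model interfaces module, and the comment text is illustrative):

```python
from torch import nn

# Assumed import path; SupportsV0Only lives in vLLM's model interfaces.
from vllm.model_executor.models.interfaces import SupportsV0Only


class MotifForCausalLM(nn.Module, SupportsV0Only):
    # V0 only: the DifferentialFlashAttention backend this model relies on
    # does not support chunked prefill, so it cannot run on the V1 engine.
    ...
```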

Member


@WoosukKwon is it possible to turn off chunked prefill for generative models in V1?

@WyldeCat
Contributor

Hi @jeejeelee, just a gentle ping on this PR 🙂.
Would love your feedback when you get a chance.

ca1207 and others added 2 commits September 1, 2025 11:36
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@ca1207
Contributor Author

ca1207 commented Sep 1, 2025

@jeejeelee I have addressed your PR review comments. Please feel free to take a look whenever you have some time.

@hmellor
Member

hmellor commented Sep 1, 2025

  • Please merge from main to hopefully fix the failing test
  • Please fix pre-commit

ca1207 and others added 3 commits September 2, 2025 02:40
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@jeejeelee jeejeelee removed the ready ONLY add when PR is ready to merge/full CI is needed label Sep 9, 2025
@jeejeelee
Collaborator

Please resolve the branch conflicts. Given that the changes in this PR won't significantly affect other features, we can then go ahead and merge.

Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@mergify mergify bot removed the needs-rebase label Sep 9, 2025
Member

@yewentao256 yewentao256 left a comment


Thanks for the work! Overall looks good to me, just a few comments

}
};

using BlockReduce = cub::BlockReduce<float3, 1024>;
Member


Here we instantiate BlockReduce with a compile-time block size of 1024, but blockDim.x is decided at launch time in:

const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));

I'm worried this inconsistency could cause trouble. Could you test this?

Contributor Author


Thanks for the comment. In test_layernorm.py, we already have a test case for num_tokens < 256, and it works fine in that case.

NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
                8199]  # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
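
For reference, a hedged sketch of the shape of such a test, comparing a fused op against an unfused reference (the `fused_poly_norm` entry point is a placeholder, and `PolyNorm` refers to the reference sketch earlier in this thread):

```python
import pytest
import torch


@pytest.mark.parametrize("num_tokens", [7, 83, 4096])
@pytest.mark.parametrize("hidden_size", [8, 768, 5120, 8199])
def test_poly_norm(num_tokens: int, hidden_size: int) -> None:
    torch.manual_seed(0)
    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    layer = PolyNorm().to("cuda")  # unfused reference (sketched above)
    ref = layer(x.float()).to(x.dtype)
    out = fused_poly_norm(x, layer.weight, layer.bias)  # placeholder entry point
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```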

Member


Looks good, thanks!

ca1207 and others added 2 commits September 10, 2025 11:53
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@mergify

mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ca1207.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 10, 2025
@DarkLight1337
Member

Please fix the merge conflicts

Signed-off-by: ca1207 <ca1207zzz@gmail.com>
@mergify mergify bot removed the needs-rebase label Sep 11, 2025
@ca1207
Contributor Author

ca1207 commented Sep 11, 2025

@DarkLight1337
Hi! I have a quick question.
If the main branch has been updated and there are no merge conflicts, do I still need to keep updating my branch?
Currently, while my GitHub Actions run is in progress, main keeps receiving updates, and each time I merge them the workflow restarts from the beginning.

@DarkLight1337
Member

No need to update as long as there are no merge conflicts.

@vllm-bot vllm-bot merged commit 9bd831f into vllm-project:main Sep 11, 2025
69 of 72 checks passed
@DarkLight1337 DarkLight1337 added this to the v0.10.2 milestone Sep 11, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>