
Conversation

@frank-wei (Contributor) commented Sep 18, 2025

Purpose

As a follow-up to #23734, this PR makes changes to support the Triton backend for DCP.
Specifically, it 1) returns the LSE from the Triton kernel and 2) fixes a bug in DeepSeek-V2 that could inadvertently modify the residual variable.
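
The description doesn't show the offending DeepSeek-V2 code; the sketch below is only a hypothetical illustration of the general hazard (assuming the bug was an in-place update aliasing the caller's tensor), not the actual diff:

import torch

# Hypothetical illustration of the bug class described above, not the real code.
def buggy_layer(hidden: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # `+=` mutates `residual` in place, so the caller's tensor is silently
    # changed and any later read of it (e.g. by the next layer) is wrong.
    residual += hidden
    return residual

def fixed_layer(hidden: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # An out-of-place add leaves the caller's tensor intact.
    return residual + hidden

r, h = torch.zeros(4), torch.ones(4)
buggy_layer(h, r)               # r is now all ones: caller's residual mutated
fixed_layer(h, torch.zeros(4))  # caller's tensor is untouched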

Test Plan

export CUDA_VISIBLE_DEVICES=4,5,6,7
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=TRITON_MLA
export VLLM_LOG_LEVEL=DEBUG
pytest tests/distributed/test_context_parallel.py -s

Test Result

=============================== warnings summary ===============================
:488
:488: DeprecationWarning: builtin type SwigPyPacked has no module attribute

:488
:488: DeprecationWarning: builtin type SwigPyObject has no module attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 2 passed, 2 warnings in 212.42s (0:03:32) ===================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request extends Distributed Context Parallelism (DCP) support to the Triton backend by enabling the return of Log-Sum-Exp (LSE) values from attention kernels. The changes are generally well-implemented, but I have identified a critical issue where the Multi-Head Attention (MHA) path appears to be broken due to an incomplete function signature update. Additionally, a temporary test script with user-specific configurations seems to have been included by mistake and should be removed.
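
To make the reviewer's point concrete: once a kernel entry point starts returning the LSE alongside the output, every call site has to be updated to unpack the pair. A toy, self-contained example of the hazard (illustrative only, not the actual vLLM code):

import torch

def attn_with_lse(q, k, v):
    # Toy stand-in for the Triton kernel after this PR: it returns both the
    # attention output and the log-sum-exp of the attention logits.
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse

q = k = v = torch.randn(2, 4, 8)
out, lse = attn_with_lse(q, k, v)  # updated call site: unpacks the pair
# out = attn_with_lse(q, k, v)     # stale call site: `out` silently becomes a
#                                  # tuple, breaking code that expects a tensor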

@houseroad (Collaborator) left a comment

Do we have a perf comparison?

@frank-wei (Contributor, Author) replied:

> Do we have a perf comparison?

I do not have a perf comparison. This is functional support, mainly adding the LSE return to the Triton kernel; it won't impact existing Triton kernel performance. With this, I think vLLM has completed FlashMLA, FA, and Triton backend support for CP.
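
For context on why the LSE is needed: under context parallelism each rank computes attention over only its shard of the KV cache, and the per-shard partial outputs can be combined exactly only if each kernel also reports the log-sum-exp of its attention logits. A minimal sketch of that merge (illustrative names, not vLLM's actual API):

import torch

def merge_partial_attn(outs: list, lses: list) -> torch.Tensor:
    # outs: per-shard partial outputs, each [tokens, heads, head_dim],
    #       computed over a disjoint slice of the KV cache.
    # lses: matching per-shard log-sum-exp values, each [tokens, heads].
    out_stack = torch.stack(outs)                 # [shards, tokens, heads, dim]
    lse_stack = torch.stack(lses)                 # [shards, tokens, heads]
    global_lse = torch.logsumexp(lse_stack, dim=0)
    # Each shard is weighted by exp(lse_shard - lse_global); the weights sum
    # to 1, so the result equals attention over the full KV cache.
    weights = torch.exp(lse_stack - global_lse)   # broadcasts over shards
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)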

@youkaichao (Member) left a comment

Thanks! This will help run MLA on Ampere or lower-end GPUs.

Please fix the pre-commit errors.

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @frank-wei.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Sep 23, 2025
mergify bot removed the tpu (Related to Google TPUs) and needs-rebase labels Sep 23, 2025
github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Sep 23, 2025
22quinn enabled auto-merge (squash) September 23, 2025 18:22
github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 23, 2025
auto-merge was automatically disabled September 24, 2025 17:01

Head branch was pushed to by a user without write access

22quinn merged commit 05c1948 into vllm-project:main Sep 25, 2025
50 checks passed
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

ci/build, deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), frontend, gpt-oss (Related to GPT-OSS models), llama (Related to Llama models), multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), performance (Performance-related issues), qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, structured-output, tool-calling, v1

Projects

Status: Done

4 participants