
Conversation

qiruiyangmeta commented Oct 1, 2025

Purpose

Context parallelism improves performance as the context length grows by distributing both the computation and the KV cache across multiple GPUs. This lowers processing latency and can also reduce the memory required per GPU, especially for extremely large KV caches (e.g., sequence lengths on the order of 1 million tokens), as shown in the figure below. This PR adds the initial context-parallel configuration and communication groups.

[figure: performance of context parallelism as context length grows]
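
To make the memory claim concrete, the following back-of-the-envelope sketch (not code from this PR; the model dimensions are illustrative assumptions) shows how sharding the sequence across context-parallel ranks shrinks the per-GPU KV cache:

def kv_cache_bytes_per_gpu(seq_len: int, cp_size: int, num_layers: int = 32,
                           num_kv_heads: int = 8, head_dim: int = 128,
                           dtype_bytes: int = 2) -> int:
    # Each CP rank holds roughly seq_len / cp_size tokens of the KV cache.
    tokens_per_rank = seq_len // cp_size
    # 2x for keys and values.
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return per_token * tokens_per_rank

for cp_size in (1, 2, 4, 8):
    gib = kv_cache_bytes_per_gpu(1_000_000, cp_size) / 2**30
    print(f"cp_size={cp_size}: ~{gib:.1f} GiB of KV cache per GPU")

With these assumed dimensions, a 1M-token sequence needs roughly 122 GiB of KV cache on a single GPU but only about 15 GiB per GPU at cp_size=8.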

Test Plan

Unit tests and e2e tests will be submitted in follow-up PRs.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces context parallelism by adding configurations and communication groups. The changes are mostly correct and consistent with the existing parallelism structure. However, I've identified a critical bug in the creation of expert parallel groups when context parallelism is enabled. The logic for grouping ranks for expert parallelism is incorrect for the new 5D tensor layout, which will lead to incorrect behavior for MoE models. I've provided a fix for this issue.

Comment on lines 1265 to 1266
group_ranks = (all_ranks.transpose(1, 2).reshape(
    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))

critical

The logic for creating the expert parallel group (_EP) is incorrect with the introduction of the context parallelism dimension. The all_ranks tensor is now 5D with shape (ExtDP, DP, PP, CP, TP). The expert parallel group should group ranks that have the same (ExtDP, PP, CP) coordinates, which means it should span across DP and TP dimensions. The current transpose(1, 2) operation is incorrect for this 5D tensor and does not produce the correct grouping. This will lead to incorrect behavior for MoE models when context parallelism is enabled. It should be replaced with a permutation that brings the DP and TP dimensions to the end before reshaping.

group_ranks = (all_ranks.permute(0, 2, 3, 1, 4).reshape(
    -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
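
To see the difference concretely, here is a small standalone check (not part of this PR; the toy parallel sizes are arbitrary) that builds the 5D rank tensor described above and compares the two groupings:

import torch

# Toy layout (ExtDP, DP, PP, CP, TP) = (1, 2, 1, 2, 2) -> 8 ranks in total.
ext_dp, dp, pp, cp, tp = 1, 2, 1, 2, 2
all_ranks = torch.arange(ext_dp * dp * pp * cp * tp).reshape(ext_dp, dp, pp, cp, tp)

# Suggested fix: bring DP and TP to the trailing dimensions so every group
# keeps (ExtDP, PP, CP) fixed and spans the DP x TP ranks.
good = all_ranks.permute(0, 2, 3, 1, 4).reshape(-1, dp * tp).unbind(0)
print([g.tolist() for g in good])  # [[0, 1, 4, 5], [2, 3, 6, 7]]

# Current code: transpose(1, 2) only swaps DP and PP, so the reshape groups
# ranks across CP and TP instead.
bad = all_ranks.transpose(1, 2).reshape(-1, dp * tp).unbind(0)
print([g.tolist() for g in bad])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

In the incorrect grouping, ranks that differ only in their CP coordinate land in the same expert-parallel group, while ranks that differ only in DP are split into different groups.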

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @qiruiyangmeta.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Oct 7, 2025
hmellor (Member) commented Oct 8, 2025

These conflicts are caused by our migration to ruff. Please see https://vllm-dev.slack.com/archives/C07R5Q1Q2BB/p1759663228844749 which contains detailed instructions to make updating your branch as painless as possible.

