Conversation
Commits:
- …arallelism Restore Qwen2.5 model in external rollout test
- …nges Handle CP OPSM masks centrally and restore loss guardrails
- …source Simplify context-parallel seq KL helper
- …d-clarity-issues Fix OPSM CP reduction and clean up interfaces
- …stallation Prefer binary installs in build script
- …nto modify-build_conda.sh-for-direct-installation-wkwuyr
- …stallation-wkwuyr Adjust flash-attn wheel selection for torch pin
Branch force-pushed from 90fba9b to ca4b729.
Contributor (Author):
@PopSoda2002 Thanks a lot for helping with the review; otherwise, please let me know who I should reach out to for it.
Collaborator:
This PR only optimizes Megatron CP performance, and for FSDP it only makes the code compatible with the change. Is my understanding correct?
Contributor (Author):
Yes. Does FSDP also need optimization? I'll take a look.
Collaborator:
We can potentially do FSDP in the next PR. Otherwise this PR would be too large to review.
Summary
This PR optimizes the computation of context-parallel (CP) loss by avoiding full all-gathers when multiple CP ranks are used. It introduces sequence-level KL preparation and adapts the computation depending on whether `cp_size` equals 1 or exceeds it.
Key Changes
**Sequence-level KL preparation:** When `cp_size > 1`, each rank uses its local log-probability and mask segments to compute per-rank partial KL values, which are then aggregated. This avoids gathering the entire log-probability tensor across ranks (see the sketch below).

**Explicit CP metadata:** The code now explicitly determines `cp_size`, `cp_rank`, and `cp_group`. When multiple ranks are present, these coordinate the partial computations; when `cp_size` is 1, the logic simply uses the local tensors.
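A minimal sketch of how these two pieces fit together, assuming a PyTorch distributed setup with a Megatron-style CP process group; the function name and signature here are illustrative, not the PR's actual interface:

```python
import torch
import torch.distributed as dist

def sequence_level_kl(log_probs_new, log_probs_old, loss_mask, cp_group=None):
    """Per-sequence KL partial sums, aggregated across CP ranks if needed.

    log_probs_new / log_probs_old: [batch, local_seq_len] segments of the
    token log-probabilities held by this rank.
    loss_mask: [batch, local_seq_len], 1 where a token contributes to the KL.
    """
    # Reduce over this rank's token segment first: one scalar per sequence.
    partial_kl = ((log_probs_new - log_probs_old) * loss_mask).sum(dim=-1)

    cp_size = dist.get_world_size(cp_group) if cp_group is not None else 1
    if cp_size > 1:
        # Sum the per-rank partials; by linearity this matches the full
        # computation while moving one scalar per sequence instead of
        # token-level tensors.
        dist.all_reduce(partial_kl, op=dist.ReduceOp.SUM, group=cp_group)
    # With cp_size == 1 the local sum already covers the whole sequence.
    return partial_kl
```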
**KL divergence formula:** For a sequence $$j$$ with prompt length $$p_j$$ and total length $$T_j$$, the sequence-level KL is defined as

$$\mathrm{KL}_j = \sum_{t=p_j}^{T_j - 1} \mathrm{mask}_j(t)\,\bigl(\log p^{\text{new}}_j(t) - \log p^{\text{old}}_j(t)\bigr)$$

In this expression:
- $$p_j$$ denotes the prompt length of sequence $$j$$.
- $$T_j$$ is the total length of sequence $$j$$.
- $$\log p^{\text{new}}_j(t)$$ and $$\log p^{\text{old}}_j(t)$$ are the log-probabilities under the new and old policies, respectively, at position $$t$$ in sequence $$j$$.
- $$\mathrm{mask}_j(t)$$ is a mask value (typically 0 or 1) indicating whether the token at position $$t$$ contributes to the KL sum.

When `cp_size == 1`, this sum is computed directly over the entire sequence. When `cp_size > 1`, each rank computes the sum over its local segment (adjusted by its token offset), and the partial sums from all ranks are added together to recover the same result as the full computation; the check below illustrates this.
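That equivalence is just the linearity of the sum, which a quick single-process check can confirm; the split offset below stands in for a CP rank boundary:

```python
import torch

torch.manual_seed(0)
T = 16                                   # total sequence length
log_p_new, log_p_old = torch.randn(T), torch.randn(T)
mask = (torch.arange(T) >= 4).float()    # prompt of length 4 is masked out

full = ((log_p_new - log_p_old) * mask).sum()

offset = 7                               # arbitrary CP-style split point
halves = [slice(None, offset), slice(offset, None)]
parts = sum(((log_p_new[h] - log_p_old[h]) * mask[h]).sum() for h in halves)

assert torch.allclose(full, parts)       # partial sums recover the full KL
```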
**OPSM inputs:** Added logic to generate OPSM inputs from sequence-level KLs when OPSM is enabled, using the computed `seq_kls` and `chunked_loss_masks`.
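A hypothetical sketch of that wiring, on the assumption that OPSM consumes one KL value and one valid-token count per sequence; the function name and dict layout are illustrative, not the PR's actual interface:

```python
def build_opsm_inputs(seq_kls, chunked_loss_masks):
    """Pair each sequence's KL with its valid-token count.

    seq_kls: [batch] per-sequence KLs (already reduced across CP ranks).
    chunked_loss_masks: list of [batch, chunk_len] mask chunks.
    """
    # Total contributing tokens per sequence across this rank's chunks.
    token_counts = sum(mask.sum(dim=-1) for mask in chunked_loss_masks)
    return {"seq_kls": seq_kls, "token_counts": token_counts}
```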
**No unnecessary communication:** By distinguishing between `cp_size == 1` and `cp_size > 1`, the PR ensures that data is communicated across ranks only when needed, preserving correctness while reducing overhead.
Motivation
#1062 (comment)
The previous implementation gathered full log-probability tensors across all context-parallel ranks regardless of configuration, which was inefficient. By computing sequence-level KLs locally and aggregating only the necessary per-sequence values, this PR reduces communication overhead and makes loss computation scale more gracefully with the number of CP ranks. It maintains mathematical equivalence with the single-rank calculation thanks to the linearity of the sum in the KL formula.
Impact
These changes should improve training performance in multi-rank CP configurations and provide clearer, more explicit handling of context-parallel metadata. The inclusion of the KL formula clarifies the computation and shows that splitting the sum across ranks yields the same result as the single-pass computation used when `cp_size` equals 1.
From a theoretical perspective, the speedup comes directly from reducing the communication volume before synchronization. A standard communication cost model writes the latency as

$$T \approx \alpha + \beta \cdot \text{bytes}$$

where $$\alpha$$ is the fixed startup and synchronization overhead and $$\beta$$ reflects the inverse effective bandwidth. In the all-gather formulation, each rank must exchange token-level tensors whose size scales with the total sequence length $$L_{\text{total}}$$, leading to a cost on the order of $$\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}$$ (here $$s$$ is the per-element size in bytes and $$B$$ the number of sequences). In contrast, the sequence-level KL approach locally reduces over the token dimension first and only communicates one scalar per sequence, so the cost becomes $$\alpha + \beta \cdot s \cdot B$$. The resulting theoretical optimization factor can therefore be approximated as

$$\frac{T_{\text{all-gather}}}{T_{\text{seq-KL}}} \approx \frac{\alpha + \beta \cdot s \cdot B \cdot L_{\text{total}}}{\alpha + \beta \cdot s \cdot B}$$

When sequence lengths are sufficiently large and the bandwidth term dominates, this ratio grows roughly with $$L_{\text{total}}$$, explaining why eliminating token-level communication yields substantial speedups in practice.
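Plugging illustrative numbers into the model makes the scaling concrete; the $$\alpha$$ and $$\beta$$ values below are assumptions for the sketch, not measured hardware parameters:

```python
# Toy evaluation of the latency model T ≈ alpha + beta * bytes.
alpha = 10e-6       # assumed startup/sync overhead: 10 microseconds
beta = 1 / 50e9     # assumed effective bandwidth: 50 GB/s
s, B = 4, 8         # 4-byte floats, 8 sequences
L_total = 16384     # total sequence length in tokens

t_all_gather = alpha + beta * s * B * L_total   # token-level tensors
t_seq_kl = alpha + beta * s * B                 # one scalar per sequence

print(f"all-gather ~{t_all_gather * 1e6:.1f} us, "
      f"seq-KL ~{t_seq_kl * 1e6:.1f} us, "
      f"ratio ~{t_all_gather / t_seq_kl:.1f}x")
```

With these assumed values the startup term $$\alpha$$ still contributes at this scale, so the ratio comes out near 2x rather than the asymptotic bandwidth-dominated limit.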
Using `examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh` to run tests on 4 GPUs, the results show that the seq-KL approach is about twice as fast as all_gather, with the time reduced from 0.002 to 0.001.