[TRTLLM-6445] feat: Enable AllReduce-associated fusion patterns in Llama3/4. #6205
base: main
Conversation
📝 Walkthrough

The changes refactor fusion and all-reduce logic in the Llama decoder layers, introducing environment-variable-based fusion enablement, quantization-aware fusion, and cross-layer fusion via new attributes. The code unifies fusion and all-reduce control, adds quantization support, links normalization and attention modules across layers for enhanced fusion capabilities, and adds a CUDA kernel launch-bounds attribute.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant LlamaForCausalLM
    participant DecoderLayer
    participant NextDecoderLayer
    User->>LlamaForCausalLM: load_weights(weights)
    LlamaForCausalLM->>DecoderLayer: set next_layer_layernorm, next_attn (from NextDecoderLayer)
    Note right of DecoderLayer: Enables cross-layer fusion
    User->>DecoderLayer: forward(input)
    DecoderLayer->>DecoderLayer: Check enable_fusion (from env)
    DecoderLayer->>DecoderLayer: Set fusion and all-reduce flags
    DecoderLayer->>DecoderLayer: If quantized, use quantization-aware fusion
    DecoderLayer->>DecoderLayer: Perform forward pass with fusion logic
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers
📜 Recent review details

Configuration used: .coderabbit.yaml

📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used

🧠 Learnings (1): tensorrt_llm/_torch/models/modeling_llama.py (1), learnt from: yechank-nvidia

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (22)
Actionable comments posted: 2
🔭 Outside diff range comments (1)
cpp/tensorrt_llm/thop/allreduceOp.cpp (1)
416-425: Remove dead code or clarify the forced oneshot override

The unconditional `allreduce_fusion_params.use_oneshot = true;` (line 425) makes the preceding branch (lines 417–423) and the TWOSHOT validation block (lines 427–432) unreachable. Please either:

- Remove the dead code (the conditional that sets `use_oneshot` based on `strategy`/`seq_len` and the subsequent TWOSHOT check), if this override is permanent.
- Or, add a clear comment explaining why oneshot is being forced, when/under what conditions it will be revisited, and disable or gate the override accordingly.

Locations to update:
- cpp/tensorrt_llm/thop/allreduceOp.cpp, around lines 417–432

Suggested diff:

```diff
-    // Determine if using oneshot or twoshot allreduce kernel
-    if (strategy == AllReduceStrategyType::MIN_LATENCY)
-    {
-        allreduce_fusion_params.use_oneshot = seq_len <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken;
-    }
-    else
-    {
-        allreduce_fusion_params.use_oneshot = strategy == AllReduceStrategyType::ONESHOT;
-    }
-    // Force use oneshot
-    allreduce_fusion_params.use_oneshot = true;
-
-    // Check for some kernel constraints if using TWOSHOT kernel
-    if (!allreduce_fusion_params.use_oneshot)
-    {
-        TORCH_CHECK(input.size(0) >= static_cast<int64_t>(tp_size),
-            "Sequence length must be greater than or equal to TP size");
-    }
+    // Force use oneshot kernel for all fusion patterns.
+    // TODO: Remove this override or restore conditional logic after benchmarking with fp4/fp8.
+    allreduce_fusion_params.use_oneshot = true;
```
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
626-629: Consider unifying fusion configuration between decoder classes

`LlamaDecoderLayer` uses instance attributes (`self.PRE_MLP_FUSION`, `self.POST_MLP_FUSION`) while `Llama4DecoderLayer` uses `self.fusion_config.PRE_MLP_FUSION`. Consider using a consistent approach across both classes for better maintainability.

```diff
+        self.fusion_config = EagerFusionConfig()
-        self.PRE_MLP_FUSION = self.mapping.has_tp(
+        self.fusion_config.PRE_MLP_FUSION = self.mapping.has_tp(
         ) and not self.enable_attention_dp and self.enable_fusion
-        self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
+        self.fusion_config.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
```

Then update the usage in the forward method accordingly.
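For illustration, a minimal self-contained sketch of the unified shape this could take; the `EagerFusionConfig` name comes from the diff above, but its fields and the surrounding class are simplified assumptions rather than the actual TensorRT-LLM definitions:

```python
from dataclasses import dataclass


@dataclass
class EagerFusionConfig:
    # Simplified stand-in for the config object referenced in the diff above.
    PRE_MLP_FUSION: bool = False
    POST_MLP_FUSION: bool = False


class DecoderLayerSketch:
    """Hypothetical decoder layer sharing one fusion-config pattern."""

    def __init__(self, has_tp: bool, enable_attention_dp: bool,
                 enable_fusion: bool):
        self.fusion_config = EagerFusionConfig()
        # Pre-MLP fusion only applies with TP and without attention DP.
        self.fusion_config.PRE_MLP_FUSION = (
            has_tp and not enable_attention_dp and enable_fusion)
        self.fusion_config.POST_MLP_FUSION = has_tp and enable_fusion


layer = DecoderLayerSketch(has_tp=True, enable_attention_dp=False,
                           enable_fusion=True)
print(layer.fusion_config)  # EagerFusionConfig(PRE_MLP_FUSION=True, POST_MLP_FUSION=True)
```

With both decoder classes reading flags from the same config object, the forward methods can stay identical in how they consult fusion state.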
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- cpp/tensorrt_llm/thop/allreduceOp.cpp (1 hunks)
- tensorrt_llm/_torch/models/modeling_llama.py (15 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
cpp/tensorrt_llm/thop/allreduceOp.cpp (1)
424-425: Verify performance impact of forcing oneshot kernel

Forcing the oneshot kernel for all cases might not be optimal, especially for longer sequences where the twoshot kernel could be more efficient. The PR aims to enable fusion patterns for fp4/fp8 quantization, but it's unclear if oneshot is the best choice for all scenarios.
Could you clarify:
- Is this change temporary for testing or permanent?
- Have you benchmarked the performance impact for various sequence lengths?
- Should this be configurable based on quantization type (fp4/fp8)?
tensorrt_llm/_torch/models/modeling_llama.py (4)
341-343: Ensure consistent AllReduce behavior with attention DP

The condition for performing AllReduce checks both `enable_attention_dp` and `has_tp()`. When attention data parallelism is enabled, the AllReduce is skipped. Please verify this is the intended behavior for the fusion patterns with fp4/fp8 quantization.
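As a rough sketch of the condition being discussed, the gate reduces to something like the helper below; the standalone function form is illustrative, not the actual code at lines 341-343:

```python
def should_all_reduce(mapping, enable_attention_dp: bool) -> bool:
    # Skip the tensor-parallel all-reduce when attention data parallelism
    # owns the sharding, or when there is no TP group at all.
    return mapping.has_tp() and not enable_attention_dp
```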
482-492: Consistent handling of quantization-aware fusion outputs

The quantization-aware fusion code properly handles NVFP4 outputs by unpacking them into `Fp4QuantizedTensor` objects. The implementation looks correct for both pre-fusion and post-fusion cases. Good job on maintaining consistency across different fusion points!

Also applies to: 544-555, 673-691, 703-721
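A minimal sketch of the unpacking pattern described here, assuming the fused all-reduce returns `(act_fp4, scale_factor, residual)` when NVFP4 fusion is active; the tuple layout and the simplified `Fp4QuantizedTensor` below are assumptions rather than the verified TensorRT-LLM API:

```python
from typing import Any, Tuple


class Fp4QuantizedTensor:
    # Simplified stand-in for the wrapper holding an FP4 payload plus its
    # per-block scaling factors.
    def __init__(self, fp4_tensor: Any, scaling_factor: Any):
        self.fp4_tensor = fp4_tensor
        self.scaling_factor = scaling_factor


def unpack_fusion_output(fusion_output: Tuple[Any, ...],
                         nvfp4_fusion: bool) -> Tuple[Any, Any]:
    """Turn a fused all-reduce result into (hidden_states, residual)."""
    if nvfp4_fusion:
        # Quantization-aware fusion emits the quantized activation, its
        # scale, and the updated residual in a single pass.
        act_fp4, scale_factor, residual = fusion_output
        hidden_states = Fp4QuantizedTensor(act_fp4, scale_factor)
    else:
        hidden_states, residual = fusion_output
    return hidden_states, residual
```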
913-924: Well-designed cross-layer fusion setup

The `load_weights` method properly sets up cross-layer references (`next_layer_layernorm` and `next_attn`) to enable fusion across decoder layers. This is a clean approach to enable cross-layer optimizations. The special handling for the last layer (using `model.norm`) is also correct.

Also applies to: 1051-1058
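A hedged sketch of the wiring this comment refers to; the attribute names (`next_layer_layernorm`, `next_attn`, `input_layernorm`, `self_attn`, `model.norm`) follow the review text, while the loop itself is a simplification of the real `load_weights`:

```python
def link_layers_for_cross_layer_fusion(model) -> None:
    """Give each decoder layer a handle on the modules it fuses into."""
    layers = list(model.layers)
    for idx, layer in enumerate(layers):
        if idx == len(layers) - 1:
            # The last decoder layer fuses its final all-reduce with the
            # model-level norm instead of a following decoder layer.
            layer.next_layer_layernorm = model.norm
            layer.next_attn = None
        else:
            layer.next_layer_layernorm = layers[idx + 1].input_layernorm
            layer.next_attn = layers[idx + 1].self_attn
```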
853-853: Mapping.has_tp() is a defined method; no change needed

The `Mapping` class in tensorrt_llm/mapping.py (around line 387) implements `def has_tp(self): …`, so invoking `model_config.mapping.has_tp()` is the intended API. You can safely ignore the suggestion to switch to a property-based check.

Likely an incorrect or invalid review comment.
4b557d0 to b73390b (Compare)

/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #12388 [ run ] triggered by Bot

PR_Github #12388 [ run ] completed with state
3ca263d to 74d7950 (Compare)

/bot run --add-multi-gpu-test --disable-fail-fast
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
688-697: Similar code duplication issue as in Llama4DecoderLayer

The fusion op assignment logic here has the same duplication pattern as in Llama4DecoderLayer. The same refactoring suggestion applies.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
699-700: Document the reason for disabling POST_MLP_FUSION

The TODO comment indicates POST_MLP_FUSION is disabled due to accuracy drop, but lacks specificity about the issue or timeline for resolution.
Consider expanding the comment to include:
- What kind of accuracy drop was observed
- Under what conditions it occurs
- Timeline or conditions for re-enabling
- Alternative approaches being considered
```diff
-        # TODO: Disable this to avoid large accuracy drop
+        # TODO: POST_MLP_FUSION disabled due to accuracy degradation in specific workloads
+        # Re-enable after investigating root cause and implementing fixes
         self.POST_MLP_FUSION = False
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/models/modeling_llama.py (14 hunks)
🔇 Additional comments (14)
tensorrt_llm/_torch/models/modeling_llama.py (14)
2-2: LGTM: Import addition is necessary for environment variable access.

The `os` import is correctly added to support the new environment-variable-based fusion control functionality.
341-341: LGTM: Improved abstraction for tensor parallel condition.

Using `self.mapping.has_tp()` instead of direct tensor parallel size comparison provides better encapsulation and consistency with the rest of the codebase.
388-390: LGTM: Environment variable control for fusion enablement.

The environment variable `TRTLLM_LLAMA_EAGER_FUSION_DISABLED` provides appropriate runtime control for fusion behavior. The naming is model-specific and clear in intent.
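Concretely, the gate amounts to a one-line check like the sketch below (the same snippet appears in a later review round); fusion stays on unless the variable is set to something other than "0":

```python
import os

# Fusion is enabled by default; exporting TRTLLM_LLAMA_EAGER_FUSION_DISABLED=1
# (or any value other than "0") turns eager fusion off for these layers.
enable_fusion = os.environ.get("TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
```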
473-481: LGTM: Consolidated allreduce disable logic.

The boolean flags `disable_attn_allreduce` and `disable_feed_forward_allreduce` properly consolidate the conditions for disabling allreduce operations, making the logic clearer and more maintainable.
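A sketch of how such flags might combine the individual conditions; the flag names come from the review, while the inputs are assumptions about the surrounding layer state:

```python
def build_allreduce_disable_flags(pre_mlp_fusion: bool, post_mlp_fusion: bool,
                                  tp_size: int, enable_attention_dp: bool):
    # The attention-side all-reduce is redundant when its output is already
    # consumed by pre-MLP fusion, when there is no TP, or when attention DP
    # owns the communication; the feed-forward side mirrors that logic.
    disable_attn_allreduce = (pre_mlp_fusion or tp_size == 1
                              or enable_attention_dp)
    disable_feed_forward_allreduce = (post_mlp_fusion or tp_size == 1
                                      or enable_attention_dp)
    return disable_attn_allreduce, disable_feed_forward_allreduce
```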
511-513: LGTM: Proper usage of consolidated allreduce disable flags.

The forward method correctly uses the `disable_attn_allreduce` and `disable_feed_forward_allreduce` flags to control AllReduce operations, maintaining consistency with the initialization logic.

Also applies to: 557-559
516-520: LGTM: Quantization-aware scale extraction.

The scale extraction logic properly handles both NVFP4 and FP8 quantization modes, providing the necessary scale information for fusion operations.
547-552: LGTM: Proper handling of NVFP4 quantized outputs.

The code correctly unpacks NVFP4 quantized outputs into `Fp4QuantizedTensor` objects, maintaining the quantization state through the computation pipeline.
570-628: LGTM: Cross-layer fusion implementation.

The cross-layer fusion logic using the `next_layer_layernorm` and `next_attn` attributes enables advanced fusion patterns. The implementation handles both normal and cutlass min-latency modes appropriately, with proper quantization support.
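Roughly, the post-MLP step would hand the next layer's norm weight (and, for quantized fusion, the next attention's input scale) to the fused all-reduce, as in this sketch; the callable and its parameter names are placeholders rather than the real `AllReduce` signature:

```python
def fused_post_mlp_allreduce(all_reduce, hidden_states, residual, layer,
                             fusion_op, scale=None, eps=1e-5):
    # Normalizing with the *next* layer's RMSNorm weight inside the all-reduce
    # lets the next decoder layer consume the output directly, skipping a
    # separate norm (and, when a scale is given, a separate quantize) step.
    return all_reduce(hidden_states,
                      fusion_op=fusion_op,
                      residual=residual,
                      norm_weight=layer.next_layer_layernorm.weight,
                      scale=scale,
                      eps=eps)
```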
642-650: LGTM: Proper initialization of new member variables.

The addition of mapping, quantization flags, and next layer references follows the same pattern as Llama4DecoderLayer and provides necessary context for fusion operations.

Also applies to: 672-676
682-687: LGTM: Consistent fusion control implementation.

The environment-variable-based fusion control and flag initialization follow the same pattern as Llama4DecoderLayer, maintaining consistency across model variants.
728-730: LGTM: Consistent allreduce disable flag usage.

The forward method correctly uses the consolidated disable flags for controlling AllReduce operations in both attention and MLP sections.

Also applies to: 758-760
733-755: LGTM: Quantization-aware PRE_MLP_FUSION implementation.

The PRE_MLP_FUSION logic properly handles scale extraction for NVFP4 quantization and correctly unpacks fusion outputs into appropriate tensor formats.
913-913: LGTM: Consistent tensor parallel condition abstraction.

Same improvement as seen elsewhere: using `has_tp()` instead of direct size comparison provides better encapsulation.
973-984: LGTM: Essential cross-layer fusion setup.

The `load_weights` method correctly establishes the cross-layer references needed for fusion patterns. The logic properly handles both intermediate layers (linking to the next layer's components) and the final layer (linking to the model's norm). This enables the cross-layer fusion capabilities referenced in the forward methods of the decoder layers.
PR_Github #12439 [ run ] triggered by Bot

PR_Github #12439 [ run ] completed with state
74d7950 to a4d3136 (Compare)

/bot run --add-multi-gpu-test --disable-fail-fast
Actionable comments posted: 0
♻️ Duplicate comments (3)
tensorrt_llm/_torch/models/modeling_llama.py (3)
391-400: Note: Fusion op assignment logic is still duplicated.

This code segment still has the same duplication issue mentioned in the past review comment. The fusion op assignment logic is repeated between this section and the LlamaDecoderLayer class (lines 649-658). The previous suggestion to consolidate this into shared variables still applies.
643-644: Consistent environment variable naming needed.

Same issue as in Llama4DecoderLayer: the environment variable name should be consistent across both classes.
649-658: Code duplication: Fusion op assignment logic repeated.

This fusion op assignment logic is duplicated from the Llama4DecoderLayer class (lines 391-400). The previous review suggestion to consolidate this into shared variables still applies to reduce code duplication.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/models/modeling_llama.py (3)
388-390: Consider renaming the environment variable for clarity.

The environment variable `TRTLLM_LLAMA_EAGER_FUSION_DISABLED` is used for both Llama and Llama4 models, which might be confusing. Consider using a more generic name like `TRTLLM_EAGER_FUSION_DISABLED` or model-specific names.

```diff
-        self.enable_fusion = os.environ.get(
-            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_EAGER_FUSION_DISABLED", "0") == "0"
```
461-461: Remove debug print statement.

The debug print statement should be removed before merging to production.

```diff
-        print(f"init Llama4DecoderLayer")
```
579-579: Remove debug print statements.

Debug print statements should be removed before production deployment.

```diff
-        print(f"{self.layer_idx}, {self.next_layer_layernorm}")
-        print(f"in forward")
```

Also applies to: 583-583
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tensorrt_llm/_torch/models/modeling_llama.py (14 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (9)
tensorrt_llm/_torch/models/modeling_llama.py (9)
2-2: LGTM - Import addition for environment variable support.

The `os` import is correctly added to support the environment-variable-based fusion control introduced later in the code.
341-341: LGTM - Consistent use of mapping API.

The change from `tp_size > 1` to `has_tp()` improves consistency with the mapping API usage pattern throughout the codebase.
452-459: LGTM - Well-structured all-reduce disable flags.

The consolidation of all-reduce disable logic into boolean flags (`disable_attn_allreduce` and `disable_feed_forward_allreduce`) improves code clarity and maintainability by centralizing the conditions.
517-521: LGTM - Proper quantization handling in fusion.

The quantization-aware fusion logic correctly unpacks the fusion output into separate components (fp4 tensor, scale factor, and residual) when NVFP4 quantization is enabled, and wraps them appropriately.

Also applies to: 584-588
660-661: Clarify the accuracy drop issue.

The TODO comment mentions disabling fusion "to avoid large accuracy drop" but doesn't provide details about the cause or planned resolution. This could impact performance benefits.

Can you provide more context about this accuracy drop issue? Is this a temporary workaround, and what's the timeline for fixing it?
663-668: LGTM - Consistent disable flags pattern.

The all-reduce disable flags follow the same well-structured pattern as the Llama4DecoderLayer, improving code consistency and maintainability.
694-722: LGTM - Comprehensive fusion logic with quantization support.

The pre-MLP and post-MLP fusion logic properly handles both regular and quantized (NVFP4) cases, correctly unpacking fusion outputs and creating appropriate tensor wrappers. The cross-layer fusion setup is also well-implemented.

Also applies to: 731-752
874-874: LGTM - Consistent mapping API usage.

The change to use `has_tp()` instead of `tp_size > 1` maintains consistency with the mapping API pattern used throughout the codebase.
934-944: LGTM - Well-implemented cross-layer fusion setup.

The `load_weights` method correctly establishes cross-layer references needed for AllReduce fusion patterns. The logic properly handles the last layer (linking to the final norm) and intermediate layers (linking to the next layer's input normalization and attention modules).
PR_Github #12528 [ run ] triggered by Bot

a4d3136 to d84c4ce (Compare)

PR_Github #12528 [ run ] completed with state

/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #12579 [ run ] triggered by Bot
PR_Github #12579 [ run ] completed with state
/bot run --only-multi-gpu-test --disable-fail-fast
PR_Github #12630 [ run ] triggered by Bot
PR_Github #12630 [ run ] completed with state
/bot run --only-multi-gpu-test --disable-fail-fast
PR_Github #12635 [ run ] triggered by Bot
PR_Github #12635 [ run ] completed with state
d84c4ce to 36c9c91 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
388-389: Environment variable name inconsistency.

The environment variable name `TRTLLM_LLAMA_EAGER_FUSION_DISABLED` differs from the past review comments, which referenced `TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED`. Ensure consistent naming across the codebase.
🧹 Nitpick comments (2)
tensorrt_llm/_torch/models/modeling_llama.py (2)
461-461: Remove debug print statement.

This debug print should be removed before merging to production.

```diff
-        print(f"init Llama4DecoderLayer")
```
586-586: Remove debug print statements.

These debug prints should be removed before merging to production.

```diff
-        print(f"{self.layer_idx}, {self.next_layer_layernorm}")
-        print(f"in forward")
```

Also applies to: 590-590
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (1 hunks)
- tensorrt_llm/_torch/models/modeling_llama.py (14 hunks)
✅ Files skipped from review due to trivial changes (1)
- cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/models/modeling_llama.py (1)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (16)
tensorrt_llm/_torch/models/modeling_llama.py (16)
2-2: LGTM: Import addition for environment variable support.

The `os` import is correctly added to support the environment-variable-based fusion control introduced in the decoder layers.
341-341: LGTM: Consistent tensor parallel check.

Using `has_tp()` instead of `tp_size > 1` maintains consistency with other tensor parallel checks throughout the codebase.
391-400: LGTM: Proper quantization-aware fusion operation setup.

The fusion operations are correctly configured based on quantization mode, with NVFP4 support implemented and FP8 support prepared for future enablement.
414-417: LGTM: Correct fusion flag configuration.

The fusion flags properly incorporate tensor parallelism presence, attention data parallelism state, and fusion enablement, ensuring fusion is only enabled when appropriate.

Also applies to: 431-434
452-459: LGTM: Well-designed consolidated all-reduce control.

The disable flags correctly consolidate fusion configuration, tensor parallel size, and attention DP state into clear boolean flags that control all-reduce behavior throughout the forward pass.
492-494: LGTM: Proper all-reduce control in attention.

The all-reduce parameters correctly use the consolidated disable flag to control all-reduce behavior during the attention phase.
497-510: LGTM: Correct quantization-aware fusion implementation.

The scale extraction for quantized modes and NVFP4 tensor unpacking are implemented correctly, properly handling the quantized tensor format returned by fusion operations.

Also applies to: 517-521
524-528: LGTM: Proper speculative metadata handling.

The fusion disabling logic for layers captured by speculative metadata is correctly implemented to avoid interference with speculative execution.
547-582: LGTM: Well-implemented cross-layer fusion.

The cross-layer fusion logic correctly references the next layer's normalization and attention modules, with proper quantization-aware scale handling for both normal and min-latency modes.
610-617: LGTM: Necessary attributes for fusion logic.

The additional attributes for mapping, attention DP, and quantization flags are correctly added to support the fusion implementation.
640-643: LGTM: Proper initialization of cross-layer fusion attributes.

The AllReduce initialization and cross-layer reference attributes are correctly set up to enable fusion between adjacent layers.
650-676: LGTM: Consistent fusion control implementation.

The environment-variable-based fusion control and consolidated disable flags follow the same correct pattern as Llama4DecoderLayer, ensuring consistent behavior across model variants.
697-698: LGTM: Consistent fusion implementation for LlamaDecoderLayer.

The fusion logic correctly mirrors the Llama4DecoderLayer implementation, with proper quantization handling, speculative metadata support, and cross-layer fusion capabilities.

Also applies to: 702-723, 731-736, 745-766
888-888: LGTM: Consistent tensor parallel check.

Using `has_tp()` instead of `tp_size > 1` maintains consistency with the tensor parallel checks used throughout the fusion logic.
948-958: LGTM: Proper cross-layer fusion setup.

The `load_weights` method correctly establishes cross-layer references by linking each decoder layer to the next layer's normalization and attention modules, with proper handling of the final layer boundary condition.
1086-1093: LGTM: Consistent cross-layer fusion setup for Llama4.

The cross-layer fusion setup in the Llama4 conditional generation model correctly mirrors the implementation in LlamaForCausalLM, ensuring consistent fusion behavior across model variants.
36c9c91 to 3c9e0ed (Compare)

/bot run --add-multi-gpu-test --disable-fail-fast
3c9e0ed to d19035f (Compare)

PR_Github #12818 [ run ] triggered by Bot
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #12822 [ run ] triggered by Bot
PR_Github #12818 [ run ] completed with state
PR_Github #12822 [ run ] completed with state
…ama3/4. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
d19035f to 1289a0e (Compare)
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #12863 [ run ] triggered by Bot
PR_Github #12863 [ run ] completed with state
Enable AllReduce-associated fusion patterns with fp4 and fp8 quantization in Llama3/4.
Summary by CodeRabbit
New Features
Improvements