[AutoDeploy] merge feat/ad-2025-07-07 #6196


Merged
merged 26 commits on Jul 22, 2025

Conversation

lucaslie
Member

@lucaslie lucaslie commented Jul 18, 2025

Summary by CodeRabbit

  • New Features

    • Introduced a modular graph transformation and export pipeline with configurable stages and patch systems.
    • Added backend-specific custom operators and transformations for attention, RMSNorm, and Mixture-of-Experts (MoE), including quantized FP8/FP4 and sliding window attention support.
    • Enabled dynamic layered YAML configuration loading with deep merging and strict validation.
    • Enhanced CLI and YAML configuration for expert usage with detailed documentation.
    • Added comprehensive testing for configuration, quantization, sharding, and custom operator correctness.
    • Added a pure PyTorch backend for cached multi-head attention with advanced features and a torch-based reference implementation for attention tests.
    • Introduced new AutoDeployConfig and refined LlmArgs for improved configuration management and validation.
    • Added support for max beam width and new tokens handling in the AutoDeploy executor.
  • Bug Fixes

    • Improved device and parameter deduplication during export.
    • Fixed parameter and buffer loading, caching, and quantization issues in graph transformations.
    • Addressed issues with tensor creation on meta device and patched PyTorch functions for export compatibility.
  • Refactor

    • Consolidated pattern matching, quantization, and sharding logic into in-place graph transformations.
    • Modularized patch and transform registration for easier maintenance.
    • Updated tests to align with new transformation APIs and in-place graph modifications.
    • Refactored sharding transforms to a declarative, two-phase detection and application model.
    • Simplified import structures and unified export function usage.
    • Replaced legacy export pipeline with a new robust export module and patch system.
    • Replaced manual attention reference computations in tests with centralized torch-based reference implementations.
  • Chores

    • Updated dependencies to support YAML-based configuration.
    • Added new tests covering configuration loading, sharding detection, and operator correctness.
    • Added auto-import modules for export patches, model patches, and transforms.
    • Improved logging and code comments for clarity and maintainability.
  • Documentation

    • Expanded README and in-code docs with detailed expert usage, configuration options, and best practices.
    • Added detailed docstrings for new configuration, transform, and export interfaces.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from the specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline, ensuring that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force-run the multi-GPU tests in addition to running the L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
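
For example, a typical invocation that reruns a single test stage with fail-fast disabled looks like this:

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"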

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since skipping validation without due care can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action also kills all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since reusing results without due care and validation can break the top of tree.


coderabbitai bot commented Jul 18, 2025

Walkthrough

This update introduces a modular, extensible transformation and export pipeline for PyTorch models in the AutoDeploy system. It replaces legacy monolithic transformation code with a registry-based system for export patches and graph transforms, adds deep YAML config merging, and expands quantization, sharding, and MoE support. The update includes new backend-specific custom ops, extensive test coverage, and improved configuration management.
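
As an illustration of the deep-merge behavior, here is a minimal OmegaConf sketch (the keys and values are hypothetical; the actual loader lives in the AutoDeploy config utilities and may differ):

from omegaconf import OmegaConf

# Later configs win on conflicting leaf keys; nested mappings merge recursively.
base = OmegaConf.create({"transforms": {"fuse_rmsnorm": {"stage": "post_load", "backend": "triton"}}})
override = OmegaConf.create({"transforms": {"fuse_rmsnorm": {"backend": "flashinfer"}}})

merged = OmegaConf.merge(base, override)
assert merged.transforms.fuse_rmsnorm.stage == "post_load"     # preserved from base
assert merged.transforms.fuse_rmsnorm.backend == "flashinfer"  # overridden
print(OmegaConf.to_container(merged, resolve=True))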

Changes

File(s) / Area Change Summary
tensorrt_llm/_torch/auto_deploy/export/, tensorrt_llm/_torch/auto_deploy/transform/, tensorrt_llm/_torch/auto_deploy/transformations/ New modular export and transform systems: registry-based export patch framework, transform registry, and library auto-importers. Legacy transformation modules deprecated or removed.
tensorrt_llm/_torch/auto_deploy/export/library/, tensorrt_llm/_torch/auto_deploy/transform/library/ New and refactored export patches (e.g., autocast_noop, sdpa, tensor_meta_device, linear, modelopt_context, sdpa_kernel_noop, torch_modulelist_getitem, torch_where, transformers_sdpa_mask) and transform passes (e.g., build_model, export_to_gm, cleanup, quantize_moe, fuse_rmsnorm). Library __init__.py files auto-discover and import all patches/transforms.
tensorrt_llm/_torch/auto_deploy/llm_args.py, tensorrt_llm/_torch/auto_deploy/utils/_config.py Split LLM config into AutoDeployConfig and LlmArgs with stricter validation and YAML-based defaults. Added dynamic YAML config loading and deep merging via OmegaConf.
tensorrt_llm/_torch/auto_deploy/models/, tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py Centralized model and tokenizer defaults, improved checkpoint loading, and support for new config fields (e.g., max_beam_width). Model patching modularized.
tensorrt_llm/_torch/auto_deploy/custom_ops/, tensorrt_llm/_torch/auto_deploy/transformations/library/ Added/extended custom ops for attention (torch/triton/flashinfer), RMSNorm, MoE (FP8/FP4), and support for new features like logit capping, sliding window, and sinks. Pattern matching and fusion logic refactored to support quantized and backend-specific variants.
tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py Refactored sharding detection and application into a two-phase, config-driven system supporting TP, BMM, and EP sharding, with pattern detection and transform execution separated.
tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py, tests/unittest/.../test_attention_matcher.py Unified attention pattern matching using a registry-based matcher, replacing multiple ad-hoc matchers and tests.
tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py, tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py MoE pattern matching and quantization extended to support FP8 and FP4 quantization, with new test coverage.
tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py, tests/unittest/.../test_fuse_rmsnorm.py Added RMSNorm fusion transform supporting FlashInfer, Triton, and Torch backends, with new tests for correctness and operator presence.
tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py RoPE pattern matching and optimization functions now operate in-place and return match counts, not modified graphs.
tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py, fusion.py, etc. All graph transforms now operate in-place and return None, reflecting a shift to side-effect-based modification.
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py, node_utils.py Added quantization skipping, scale extraction, and quantization type detection utilities. Refined node filtering and pattern matching helpers.
examples/auto_deploy/, requirements.txt, setup.py Updated CLI, README, and launch configs to reflect new config fields and usage. Added omegaconf and YAML support to requirements. Included YAML files in package data.
tests/unittest/_torch/auto_deploy/ Extensive new and refactored tests for all new features: config merging, export/transform passes, quantization, MoE, sharding, attention, RMSNorm, and backend ops. Tests updated for in-place graph modification and new APIs.
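
The registry-based design referenced above typically boils down to a decorator that records classes by name plus a library package whose __init__.py imports every module for its registration side effect. A minimal sketch (class and method names are illustrative, not the exact AutoDeploy API):

from typing import Callable, Dict, Type

class TransformRegistry:
    """Decorator-based registry sketch."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def decorator(transform_cls: Type) -> Type:
            cls._registry[name] = transform_cls
            return transform_cls
        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        if name not in cls._registry:
            raise KeyError(f"Unknown transform: {name}")
        return cls._registry[name]

# Importing a library module registers its transform as a side effect, which is
# why the library __init__.py files eagerly auto-import all submodules.
@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm:
    def apply(self, gm):  # placeholder signature
        return gm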

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CLI/Script
    participant ConfigLoader
    participant ModelFactory
    participant InferenceOptimizer
    participant TransformRegistry
    participant ExportPatchRegistry
    participant GraphModule

    User->>CLI/Script: Launch with CLI args/YAML configs
    CLI/Script->>ConfigLoader: Load and deep-merge YAML configs
    ConfigLoader->>ModelFactory: Create model factory with config
    CLI/Script->>InferenceOptimizer: Initialize with factory and transform config
    InferenceOptimizer->>TransformRegistry: For each transform stage, get transform class
    InferenceOptimizer->>ExportPatchRegistry: Apply export patches as needed
    InferenceOptimizer->>GraphModule: Apply transforms in order (build, export, cleanup, quantize, fuse, etc.)
    GraphModule-->>InferenceOptimizer: Modified graph after all transforms
    InferenceOptimizer-->>CLI/Script: Return optimized inference model
    CLI/Script-->>User: Ready-to-use inference model

Estimated code review effort

4 (~90 minutes)

Possibly related PRs

Suggested reviewers

  • chzblych
  • nv-guomingz
  • juney-nvidia

Poem

🐇✨
In the burrows of code, we rabbits delight,
With YAML and patches, our configs take flight.
Transforms and exports, now modular and neat,
Quantized and sharded, our models compete!
Attention and MoE, fused with great care,
Hop forward, dear models—deploy anywhere!
🥕🚀

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@lucaslie lucaslie moved this from Backlog to In review in AutoDeploy Board Jul 18, 2025
@lucaslie lucaslie self-assigned this Jul 18, 2025
@lucaslie lucaslie linked an issue Jul 18, 2025 that may be closed by this pull request
@lucaslie lucaslie enabled auto-merge (squash) July 18, 2025 23:54
@lucaslie
Member Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12356 [ run ] triggered by Bot


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 24

🔭 Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (1)

98-104: Fix self-reference bug in graph node replacement.

Line 103 has a critical bug where new_contiguous_node.replace_input_with(new_contiguous_node, original_input) is trying to replace the node's input with itself, which doesn't make sense logically.

-            new_contiguous_node.replace_input_with(new_contiguous_node, original_input)
+            # The new_contiguous_node already uses original_input as its argument,
+            # so no additional input replacement is needed here

Additionally, there's a potential issue with the logic flow: original_input.replace_all_uses_with(new_contiguous_node) on line 102 will replace all uses of original_input, but then we immediately try to modify new_contiguous_node's inputs, which may cause inconsistencies.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_attention_with_kv_cache.py (1)

286-307: Add sliding window mask to reference computation.

The test passes sliding_window parameter to the kernel but doesn't apply the corresponding mask in the reference computation.

The reference computation should apply sliding window masking when sliding_window > 0 to properly validate the kernel behavior. Consider adding sliding window mask similar to what's done in test_context_attention_kv_flattened.
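
A minimal sketch of such a reference-side mask, assuming queries occupy the last q_len positions of the flattened KV sequence (variable names are illustrative):

import torch

def sliding_window_mask(q_len: int, kv_len: int, window: int, device) -> torch.Tensor:
    """Boolean mask of shape [q_len, kv_len]; True means the query may attend to the key."""
    q_pos = torch.arange(kv_len - q_len, kv_len, device=device)  # absolute query positions
    k_pos = torch.arange(kv_len, device=device)
    diff = q_pos.unsqueeze(1) - k_pos.unsqueeze(0)               # query_pos - key_pos
    return (diff >= 0) & (diff < window)                         # causal + sliding window

# In the reference computation: scores.masked_fill_(~mask, float("-inf")) before softmax.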

♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)

79-83: Apply same docstring improvement for consistency.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)

60-68: Same code duplication issue as in TP sharding test.

🧹 Nitpick comments (46)
tensorrt_llm/_torch/auto_deploy/export/library/__init__.py (1)

1-16: Functional but duplicates code from models/patches init.

This auto-import implementation is correct and well-documented, but it's identical to the pattern in tensorrt_llm/_torch/auto_deploy/models/patches/__init__.py. Consider extracting this common pattern into a reusable utility function to eliminate code duplication.

Example refactor to create a shared utility:

# Create shared utility in a common location
def auto_import_submodules(package_path, package_name):
    """Auto-import all non-private submodules in a package."""
    import importlib
    import pkgutil
    
    all_modules = []
    for _, module_name, is_pkg in pkgutil.iter_modules(package_path):
        if module_name.startswith("_"):
            continue
        all_modules.append(module_name)
        importlib.import_module(f"{package_name}.{module_name}")
    return all_modules

# Then use in both files:
__all__ = auto_import_submodules(__path__, __name__)
tensorrt_llm/_torch/auto_deploy/export/library/sdpa.py (1)

23-23: Consider adding error handling for missing custom operator.

The patch assumes torch.ops.auto_deploy.torch_attention_sdpa exists. Consider adding validation to ensure the custom operator is available before applying the patch.

    def _apply_patch(self):
        """Apply the SDPA patch."""
+        # Validate custom operator exists
+        if not hasattr(torch.ops.auto_deploy, 'torch_attention_sdpa'):
+            raise RuntimeError("Custom operator torch.ops.auto_deploy.torch_attention_sdpa not found")
+        
        # Store original function
        self.original_values["F.scaled_dot_product_attention"] = F.scaled_dot_product_attention

        # Apply patch
        F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

16-21: Consider adding comments for cleanup stages.

The cleanup stages (cleanup_noop_slice, cleanup_noop_add, cleanup_input_constraints) lack explanatory comments unlike the other stages. Adding brief descriptions would improve maintainability.

  cleanup_noop_slice:
    stage: post_export
+    # Remove unnecessary slice operations
  cleanup_noop_add:
    stage: post_export
+    # Remove unnecessary add operations  
  cleanup_input_constraints:
    stage: post_export
+    # Clean up input constraint nodes
tensorrt_llm/_torch/auto_deploy/transform/library/__init__.py (1)

12-16: Consider performance implications of eager loading.

All transform modules are imported immediately when the package is imported, which could impact startup time. Consider if lazy loading would be more appropriate for this use case.

You could implement lazy loading by only populating __all__ during discovery and deferring actual imports until the transforms are needed, or document that eager loading is intentional for registration purposes.
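
If lazy loading were ever preferred, a PEP 562 module-level __getattr__ sketch could defer the imports until first access (note this would also defer registration side effects, which is presumably why eager loading is used):

import importlib
import pkgutil

# Discover submodule names without importing them yet.
__all__ = [name for _, name, _ in pkgutil.iter_modules(__path__) if not name.startswith("_")]

def __getattr__(name: str):
    # Import the submodule only on first attribute access (PEP 562).
    if name in __all__:
        return importlib.import_module(f"{__name__}.{name}")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")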

tensorrt_llm/_torch/auto_deploy/export/library/linear.py (1)

31-31: Consider thread safety implications.

The patch modifies the global F.linear function, which could potentially affect concurrent exports if they occur in the same process. Consider documenting this limitation or implementing thread-local patching if needed.
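
If thread-local behavior were desired, one option is a dispatching wrapper whose patched path is opt-in per thread (purely illustrative; the real patch swaps in a custom op, stubbed here by forwarding to the original):

import threading
import torch.nn.functional as F

_patch_state = threading.local()
_original_linear = F.linear

def _patched_linear(input, weight, bias=None):
    # Stand-in for the export-friendly implementation.
    return _original_linear(input, weight, bias)

def _dispatching_linear(input, weight, bias=None):
    # Global swap, but the patched behavior is enabled per thread.
    if getattr(_patch_state, "use_patch", False):
        return _patched_linear(input, weight, bias)
    return _original_linear(input, weight, bias)

F.linear = _dispatching_linear
_patch_state.use_patch = True  # enable only in the exporting thread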

tensorrt_llm/_torch/auto_deploy/export/library/tensor_meta_device.py (2)

24-26: Consider expanding the condition to handle more cases.

The current condition only handles the exact case of data == 0.0. Consider whether other numeric values or data types might also have issues on the meta device.

-            if data == 0.0 and device is not None and torch.device(device) == torch.device("meta"):
-                return torch.zeros((), **kwargs)
+            if (
+                isinstance(data, (int, float)) 
+                and data == 0.0 
+                and device is not None 
+                and torch.device(device) == torch.device("meta")
+            ):
+                return torch.zeros((), **kwargs)

24-24: Add error handling for invalid device specifications.

The torch.device(device) call could raise an exception for invalid device strings. Consider adding error handling or validation.

-            if data == 0.0 and device is not None and torch.device(device) == torch.device("meta"):
+            if data == 0.0 and device is not None:
+                try:
+                    device_obj = torch.device(device)
+                    if device_obj == torch.device("meta"):
+                        return torch.zeros((), **kwargs)
+                except (TypeError, RuntimeError):
+                    pass  # Fall back to original function for invalid device specs
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py (1)

444-460: Address inconsistent reference implementation usage.

This test function still uses torch.nn.functional.scaled_dot_product_attention directly instead of TorchAttentionReference. For consistency and maintainability, consider updating this test to use the centralized reference implementation as well.

Do you want me to update this test to use TorchAttentionReference for consistency with the other test functions?

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (1)

37-37: Consider using a more portable max value check.

Using torch.iinfo(torch.long).max assumes that the slice end uses the same maximum value. Consider also checking for other common "unbounded" representations.

-            if node.args[2] != 0 or node.args[3] != torch.iinfo(torch.long).max:
+            # Check for no-op slice: start=0 and end is effectively unbounded
+            max_vals = [torch.iinfo(torch.long).max, torch.iinfo(torch.int64).max, float('inf')]
+            if node.args[2] != 0 or node.args[3] not in max_vals:
tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py (1)

354-356: Consistent parameter addition but needs documentation.

Same sliding window and sinks parameters as other calls, maintaining consistency.

Add inline comments for the new parameters for clarity:

-        False,
-        None,
+        False,  # sliding_window_enabled
+        None,   # sinks
tensorrt_llm/_torch/auto_deploy/export/library/transformers_sdpa_mask.py (1)

49-51: Consider logging skipped patches for debugging.

Silent skipping of patches due to ImportError might make debugging difficult in development environments.

         except ImportError:
             # If transformers is not available or doesn't have required modules, skip patch
-            pass
+            # Consider adding debug logging here for development
+            pass
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py (1)

425-426: TODO comment indicates incomplete refactoring.

The TODO comment correctly identifies the intention to replace the manual reference computation with the torch backend reference, consistent with the pattern applied to other tests. However, the refactoring for this test remains incomplete.

Consider completing the refactoring for test_paged_gqa_op to maintain consistency with the other test functions, or create a follow-up task to address this.

tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)

245-245: Significant behavioral change in BMMDynamicModel.

The model now uses fixed learnable weights instead of input-dependent dynamic weight generation. While this simplifies the model and may improve test determinism, the name "BMMDynamicModel" might be misleading since weights are no longer dynamically generated from inputs.

Consider updating the class name or docstring to reflect that weights are now fixed learnable parameters rather than input-dependent.

 class BMMDynamicModel(nn.Module):
-    """BMM model with dynamic tensor weights for testing."""
+    """BMM model with fixed learnable weights for testing."""

Also applies to: 252-252

examples/auto_deploy/README.md (1)

262-262: Minor markdown style inconsistency.

The linter flags inconsistent emphasis style (underscores vs asterisks). For consistency with markdown best practices, consider updating:

-*exclusively* exposed in the [`AutoDeployConfig` class]
+**exclusively** exposed in the [`AutoDeployConfig` class]
-_ignored_ in AutoDeploy
+**ignored** in AutoDeploy

Also applies to: 267-267

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1)

50-50: Consider improving assertion message for better debugging.

The assertion correctly validates the precondition, but could provide more context for debugging:

-assert len(gm.graph.nodes) == 0, "Expected empty graph module."
+assert len(gm.graph.nodes) == 0, f"Expected empty graph module for export, but found {len(gm.graph.nodes)} nodes. This transform should be applied to a dummy graph module containing only the factory_model submodule."
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)

47-56: Complete the docstring to include all parameters.

The docstring is missing documentation for the gm parameter.

Add the missing parameter:

 Args:
     cm: The cached sequence interface defining the sequence interface.
+    gm: Optional GraphModule to transform. If None, creates an empty one.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

172-198: Consider extracting common model initialization logic.

There's significant code duplication between _run_pattern_detection_job and _run_job for model and input initialization (lines 178-198 duplicate lines 89-109).

Extract common initialization:

def _init_model_and_input(model_cls, bias=False):
    """Initialize model and input tensor for testing."""
    batch_size = 4
    sequence_len = 8
    num_features = 32
    num_heads = 4
    num_key_value_heads = 1
    
    if model_cls == GQA_Block:
        model = model_cls(
            num_attention_heads=num_heads,
            hidden_size=num_features,
            num_key_value_heads=num_key_value_heads,
        ).to(device="cuda", dtype=torch.float16)
    else:
        model = model_cls(num_features, num_features, bias=bias).to(
            device="cuda", dtype=torch.float16
        )
    x = torch.randn(batch_size, sequence_len, num_features, device="cuda", dtype=torch.float16)
    return model, x, num_features, num_heads, num_key_value_heads
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)

42-61: Consider improving docstring consistency.

The implementation is correct, but the fake implementation's docstring is less detailed than the FlashInfer version.

Update the docstring for consistency:

 @triton_rmsnorm.register_fake
 def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
-    """Fake implementation for the custom operator during tracing."""
+    """Fake implementation for the custom operator during tracing.
+
+    Args:
+        input: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Empty tensor with same shape as input.
+    """
     return torch.empty_like(input)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1)

107-107: Consider tracking the TODO for future improvements

The TODO suggests this test could benefit from a custom InferenceOptimizer config. Consider creating an issue to track this enhancement.

Would you like me to open an issue to track this TODO for implementing a custom InferenceOptimizer config?

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py (1)

398-402: Sinks mechanism correctly integrated

The implementation properly adds the sinks contribution to the softmax denominator. Consider adding a comment explaining what "sinks" represent in the attention mechanism for better code documentation.
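
For readers unfamiliar with the mechanism, attention "sinks" can be viewed as extra logits that participate in the softmax normalization but contribute no value output. A simplified sketch of that view (not the Triton kernel's exact math):

import torch

def softmax_with_sinks(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    """scores: [..., n_keys] attention logits; sinks: [..., n_sinks] sink logits."""
    all_logits = torch.cat([scores, sinks], dim=-1)
    probs = torch.softmax(all_logits, dim=-1)
    # Only the probabilities over real keys weight the values; the sink share
    # simply enlarges the softmax denominator.
    return probs[..., : scores.shape[-1]]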

tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py (1)

34-60: Well-structured quantization logic

The nested function quantize_param_list cleanly handles weight quantization, scale registration, and hook setup. Consider documenting the memory implications of quantizing large MoE models with many experts.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

505-513: Consider making the quantization type detection more extensible.

The function currently hardcodes the list of quantization implementations to check. This could be made more maintainable.

Consider using a registry pattern or class attribute to make this more extensible:

 def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
-    for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
+    # Consider adding a class method to QuantizationImpl to get all implementations
+    for qtype in QuantizationImpl.get_all_implementations():
         if is_op(node, qtype.target_op()):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"
tensorrt_llm/_torch/auto_deploy/utils/_config.py (1)

58-62: Enforce MRO requirement programmatically.

The documentation states this class must come first in the MRO, but this isn't enforced in code.

Consider adding a check in __init_subclass__ to enforce the MRO requirement:

def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Ensure DynamicYamlMixInForSettings is first in MRO (after the class itself)
    if len(cls.__mro__) > 2 and cls.__mro__[1] != DynamicYamlMixInForSettings:
        raise TypeError(
            f"{cls.__name__} must inherit from DynamicYamlMixInForSettings "
            "as the first base class to ensure proper yaml_configs processing"
        )
tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (2)

188-189: Address TODO for quantized model fusion.

The TODO indicates that trtllm_moe_fused doesn't currently support quantized models. This could be a performance limitation.

Would you like me to help implement the quantized version of trtllm_moe_fused or create an issue to track this enhancement?


139-143: Update docstring to reflect in-place modification.

The function now performs in-place modifications but the docstring doesn't indicate this change.

Update the docstring to clarify that the function modifies the graph module in-place:

 def fuse_moe(gm: torch.fx.GraphModule) -> None:
     """
     Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
     torch.ops.auto_deploy.trtllm_moe_fused.
+    
+    This function modifies the graph module in-place.
     """
tensorrt_llm/_torch/auto_deploy/export/export.py (2)

29-36: Consider checking buffer devices as well

The function only checks parameter devices but applies the device to all nodes in the graph. Buffers might be on different devices than parameters.

def _clean_up_device_info(gm: fx.GraphModule) -> None:
    """Correct device information in the graph."""
-    devices = {t.device for _, t in gm.named_parameters()}
+    devices = {t.device for _, t in gm.named_parameters()}
+    devices.update({t.device for _, t in gm.named_buffers()})
    if len(devices) == 0:
        return
    elif len(devices) > 1:
-        raise AssertionError("All parameters should be on the same device.")
+        raise AssertionError("All parameters and buffers should be on the same device.")

269-272: Address TODO about overlap between deduplication functions

The TODO comment suggests there's overlap between _add_load_hook_for_aliased_params and _deduplicate_params_and_buffers. This should be clarified to avoid maintenance issues.

Would you like me to analyze the overlap between these functions and suggest a refactoring approach to consolidate them?

tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_config.py (2)

843-862: Test doesn't actually verify relative path handling

The test is named test_relative_and_absolute_paths and has a comment about testing relative paths, but it only uses absolute paths throughout. Consider either renaming the test or implementing actual relative path testing.


866-866: Add newline at end of file

-        assert settings.simple.name == "config2"  # from config2
+        assert settings.simple.name == "config2"  # from config2
+
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

171-173: Consider optimizing GPU->CPU transfer for new_tokens

The comment indicates a performance concern about the GPU->CPU transfer. This could be a bottleneck in the inference pipeline.

Consider tracking this as a performance optimization opportunity. Would you like me to open an issue to investigate avoiding this transfer by keeping the tensor operations on the GPU?

tensorrt_llm/_torch/auto_deploy/export/interface.py (1)

250-250: Add newline at end of file

-    yield from _apply_patches(patches)
+    yield from _apply_patches(patches)
+
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py (2)

332-339: Document known backend issues with sinks

The code handles backend failures when sinks are enabled, suggesting known issues. Consider adding a comment explaining the specific backend limitations or linking to a tracking issue.

         # For sinks: test that backend runs without crashing (backend has bugs)
         # and validate correct sinks behavior with numpy reference
+        # TODO: Backend currently has issues with sinks implementation - see issue #XXX
         try:

488-488: Add newline at end of file

-        assert result[0].shape[0] == batch_size, "First tensor should have batch_size elements"
+        assert result[0].shape[0] == batch_size, "First tensor should have batch_size elements"
+
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (5)

44-44: Document the magic numbers for weight scale factors.

The magic numbers 448 and 432 lack explanation. Please add a comment explaining why these specific values are chosen for bfloat16 vs other dtypes.

-        wt_factor = 448 if dtype == torch.bfloat16 else 432
+        # FP8 E4M3 max value is 448 for bfloat16 and 432 for float16/float32
+        wt_factor = 448 if dtype == torch.bfloat16 else 432

46-48: Consider using deterministic weight initialization for tests.

Using torch.randn without a seed in test code can lead to non-deterministic behavior. Consider setting a seed or accepting initialized weights as parameters.

+        torch.manual_seed(42)  # Ensure deterministic initialization
         w1_fp32 = torch.randn(ffn_dim, hidden_dim, device=device)
         w3_fp32 = torch.randn(ffn_dim, hidden_dim, device=device)
         w2_fp32 = torch.randn(hidden_dim, ffn_dim, device=device)

99-101: Document the weight initialization scale factor.

The weights are initialized with a 0.01 scale factor. Please add a comment explaining why this specific value is chosen for FP4 quantization.

         # Prepare full-precision weights
+        # Small initialization scale (0.01) helps with FP4 quantization stability
         w1_fp32 = torch.randn(ffn_dim, hidden_dim, device=device, dtype=dtype) * 0.01

244-245: Optimize input tensor creation.

Creating the input tensor on CPU and then moving it to the device in the forward pass is inefficient. Consider creating it directly on the target device.

-        input_ids = self.get_input(device="cpu")  # or pass as constructor arg
-        input_sample = self.embedding(input_ids)
+        with torch.device(device):
+            input_ids = self.get_input(device=device)
+            input_sample = self.embedding(input_ids)

262-262: Remove duplicate seed setting.

The random seed is set both in get_input (line 262) and in the test function (line 294). Consider removing one to avoid confusion.

Also applies to: 294-294

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (1)

274-275: Consider moving empty input check to the template.

The empty input check here is good, but for consistency, consider moving this check to the _template_moe function since it could benefit all MoE variants.
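
A sketch of what hoisting that guard might look like (the _template_moe signature shown here is an assumption, not the actual one):

import torch

def _template_moe(x: torch.Tensor, route_and_compute) -> torch.Tensor:
    # Shared early exit for all MoE variants: nothing to route on empty input.
    if x.numel() == 0:
        return torch.zeros_like(x)
    return route_and_compute(x)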

tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

232-241: Consider more specific exception handling.

The broad exception catch might hide specific errors. Consider catching more specific exceptions or at least logging the full traceback for debugging.

             try:
                 gm, info = self._apply(gm, cm, factory)
-            except Exception as e:
+            except (TransformError, ValueError, TypeError) as e:
                 error_msg = f"Transform {t_name} failed"
                 if self.config.skip_on_error:
                     ad_logger.warning(f"{error_msg}: {e}")
+                    ad_logger.debug(f"Full traceback:", exc_info=True)
                     info = TransformInfo(skipped=True, num_matches=0)
                 else:
                     raise TransformError(error_msg) from e
+            except Exception as e:
+                # Unexpected errors should always be raised
+                ad_logger.error(f"Unexpected error in transform {t_name}: {e}")
+                raise
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)

205-206: Address the TODO for Path support.

The TODO comment indicates that Path objects should be supported in the model factory. Consider implementing this support or creating an issue to track it.

Would you like me to help implement Path support in the model factory or create an issue to track this enhancement?

tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py (1)

498-500: Consider validating ShardingConfig is not None.

Add a validation check to ensure sharding_config is provided when world_size >= 2.

     if world_size < 2:
         ad_logger.info("Skipping sharding for single device")
         return
+    
+    assert sharding_config is not None, "sharding_config must be provided for multi-device setup"
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (2)

188-213: Complex but correct sliding window implementation.

The sliding window mask implementation correctly handles position differences between query and key positions. However, consider adding more inline documentation to explain the logic, especially around lines 194-207 where position differences are calculated.

+            # Calculate position differences for sliding window mask
+            # Query positions are offset by input_pos_i, key positions start from 0
             query_positions = torch.arange(
                 input_pos_i, input_pos_i + seq_len_i, device=q.device
             )  # [seq_len_i]
             key_positions = torch.arange(0, kv_seq_len, device=q.device)  # [kv_seq_len]

             # Create position difference matrix: query_pos - key_pos
+            # This represents how far back each query position is looking
             pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(
                 0
             )  # [seq_len_i, kv_seq_len]

475-484: Consider using approximate equality check for scale validation.

When validating the scale parameter, consider allowing for floating-point precision differences if you want to verify it matches the expected 1/sqrt(head_dim) value.

         # Validate scale
         if not isinstance(scale, float):
             ad_logger.warning("Provided scale is not a float. Using default scale instead.")
             scale = None
+        elif scale is not None:
+            # Optionally validate scale is approximately 1/sqrt(head_dim)
+            k_fake: FakeTensor = source_attn_node.args[1].meta["val"]
+            expected_scale = 1.0 / math.sqrt(k_fake.shape[3])
+            if abs(scale - expected_scale) > 1e-6:
+                ad_logger.debug(f"Scale {scale} differs from expected {expected_scale}")
tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py (1)

226-231: Consider using more meaningful scalar values

The scalar values for scale and dropout (e.g., 0.1234743, 0.85849734) appear arbitrary. Consider using more intuitive values or adding a comment explaining why these specific values were chosen.

    configs = [
-        (_sfdp_pattern_1, _sfdp_replacement_1, True, 0.1234743, 0.85849734),
-        (_sfdp_pattern_2, _sfdp_replacement_2, False, 0.234743, 0.5849734),
+        (_sfdp_pattern_1, _sfdp_replacement_1, True, 0.125, 0.1),  # 1/8 scale, 10% dropout
+        (_sfdp_pattern_2, _sfdp_replacement_2, False, 0.25, 0.2),  # 1/4 scale, 20% dropout
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (1)

1174-1176: Document transpose node count expectations

The test expects "at least 7" transposes for bsnd but exactly 4 for bnsd. Consider adding a comment explaining why this asymmetry exists (e.g., some transposes may be optimized away in the bsnd case).

             # - 3 for the new input transposes
             # - 1 for the new output transpose
             # Note: Some nodes may be optimized away, so we check for at least 7
+            # The graph optimizer may fuse or eliminate some transposes, hence "at least"
             if len(transpose_nodes) < 7:

Also applies to: 1201-1203

@tensorrt-cicd
Collaborator

PR_Github #12356 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9178 completed with status: 'FAILURE'

@lucaslie
Member Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12444 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12444 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9254 completed with status: 'FAILURE'

galagam and others added 10 commits July 21, 2025 07:25
Refactor the signatures of AD graph transformations to return None (#71)

* Refactor the signatures of AD graph transformations to return None (NVIDIA#5249)

Refactor signatures of AD graph transformations from
  gm = transformation(gm)
to
  transformation(gm)

Since the AD graph transformations modify the input GraphModule
in-place. Previous signature style was misleading.

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Fix trtllm-bench test and enable trtllm-bench integration (#76)

* Fix trtllm-bench test and enable trtllm-bench integration

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

* Remove unneeded __init__.py

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>

---------

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
) (#73)

* yaml config loader for dynamic settings

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

* updates for yaml mixin

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

* addressing reviewer feedback

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

---------

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* [AutoDeploy] Refining AD configurability

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

* addressed reviewer feedback

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>

---------

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
* Add the Torch backend and update the test to use the torch backend.

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* Add the sinks and fix the test failures

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* address reviewer's comments

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* use custom op convention

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* move the ref to the utils_test

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* Add torch backend tests in ad_build_small_single.py

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

* Address hidden comments...

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>

---------

Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
* add torch_fp8_moe and fp8 linear support in pattern matcher, update unit tests

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add torch-fp4-moe and fp4 support in pattern matcher, unit test has acc issue and e2e mixtral fp4 has kernel error wo moe matching

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add pre-commit hook

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* hacky fix for e2e run of mixtral FP4 and fp4 op unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* EP support for torch_fp4_moe and torch_fp8_moe

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix rebase: op rename, shard_load_hook bug in FP4

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix pre-commit

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix weight loading-load_hook issue for FP4, update function to handle exclude_modules in hf_quant_config

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* addressing feedback, add moe op template, update op names,other minor refinements

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* move common functionality to utility

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix FP4QuantizationImpl register from rebase

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add quantize_moe pass for patched torch_moe op

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add transformation unit tests for FP8 and FP4

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* update should_skip_quantization to fix bmm unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* update BMMDynamicModel and utils to extract weight for dynamic BMM case

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* update BMMDynamicModel to drop linear op

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
* remove assert, add qwen small to tests

* lint

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>

---------

Signed-off-by: Suyog Gupta <suyogg@nvidia.com>
@lucaslie
Member Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #12450 [ run ] triggered by Bot


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (10)
tensorrt_llm/_torch/auto_deploy/export/library/modelopt_context.py (1)

16-23: Update import path for torch_export_context

The import path needs to be corrected as identified in the previous review.

tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)

72-74: quantize_moe must not return the graph – enforce in-place only

The bulk of your transforms correctly modify gm in place, but quantize_moe still has code paths that return gm, violating the in-place-only API as identified in previous reviews.


155-158: Backend selection should be configurable.

The RMSNorm backend is hardcoded to "flashinfer" despite the TODO comment and the availability of self.ad_config.attn_backend. This was previously identified in code review.

Replace the hardcoded backend:

-# TODO (lucaslie): add backend selection as part of configurable inference optimizers
-# check if we can fuse rmsnorm
-fuse_rmsnorm(egm, "flashinfer")
+# check if we can fuse rmsnorm using the configured backend
+fuse_rmsnorm(egm, self.ad_config.attn_backend)
tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py (2)

136-168: Add parameter existence validation for robustness.

As noted in the TODO and previous review, this function should validate that the extracted parameter names actually exist in the graph module's parameters.

Add validation after extracting names:

 w1_names = _unwrap_list(w1_list)
 w2_names = _unwrap_list(w2_list)
 w3_names = _unwrap_list(w3_list)
+
+# Validate that all extracted names are actual parameters
+for name in w1_names + w2_names + w3_names:
+    if not any(name == param_name for param_name, _ in moe_node.graph.owning_module.named_parameters()):
+        raise RuntimeError(f"Extracted parameter name '{name}' not found in module parameters")

 return w1_names, w2_names, w3_names

100-134: Remove inconsistent return statements for in-place modification.

The function returns gm on lines 108 and 129 but implicitly returns None on line 133. Since this follows the in-place modification pattern, all paths should consistently return None.

Apply this fix:

 if not quant_algo:
     ad_logger.info("No quantization to do.")
-    return gm
+    return

And:

 if count == 0:
-    return gm
+    return
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (1)

98-100: BMM sharding should validate expert divisibility.

While the test ensures num_experts is divisible by world_size through its construction, the actual BMM sharding implementation should validate this constraint as identified in previous reviews.

Verify that the BMM sharding implementation includes proper divisibility validation:

#!/bin/bash
# Check if BMMShardingInfo.validate includes a num_experts divisibility check
ast-grep --lang python --pattern 'class BMMShardingInfo:
  $$$
  def validate($$$):
    $$$'
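
A validation along these lines might look like the following (a sketch only; the real BMMShardingInfo fields and validate() signature may differ):

def validate_expert_divisibility(num_experts: int, world_size: int) -> None:
    # Guard against silently uneven expert partitions during BMM sharding.
    if world_size > 1 and num_experts % world_size != 0:
        raise ValueError(
            f"num_experts ({num_experts}) must be divisible by world_size ({world_size}) "
            "for BMM expert sharding."
        )
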
tensorrt_llm/_torch/auto_deploy/utils/_config.py (1)

36-48: Add error handling for file loading.

The _read_files method should handle potential errors when loading YAML files.

Add try-except block to handle file loading errors gracefully:

 def _read_files(self, files: PathType | None) -> dict[str, Any]:
     if files is None:
         return {}
     if isinstance(files, (str, os.PathLike)):
         files = [files]
 
     confs = []
     for file in files:
         file_path = Path(file).expanduser()
         if file_path.is_file():
-            confs.append(OmegaConf.load(file_path))
+            try:
+                confs.append(OmegaConf.load(file_path))
+            except Exception as e:
+                # Log error or raise with more context
+                raise ValueError(f"Failed to load YAML file {file_path}: {e}")
 
     return deep_merge_dicts(*confs)
tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (1)

287-290: Fix incorrect use of set in conditional check.

The condition uses a set literal which always evaluates to True. This should be an or condition.

Replace the set with proper conditional logic:

-    elif {
-        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear),
-        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear),
-    }:
+    elif (
+        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) or
+        is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear)
+    ):
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)

35-119: Critical performance issue: Sequential processing with item() calls.

The current implementation has the same significant performance bottlenecks identified in previous reviews:

  1. Using .item() forces CPU-GPU synchronization and prevents parallelization
  2. Sequential for-loop processing of batch items is inefficient
  3. Manual attention computation could leverage PyTorch's optimized operations

This implementation should be refactored to use vectorized operations and torch.nn.functional.scaled_dot_product_attention with appropriate masks instead of manual computation.
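
A vectorized sketch of that direction, building a boolean padding-plus-causal mask on-device and calling SDPA once for the whole batch (tensor layouts here are assumptions, not the backend's actual signature):

import torch
import torch.nn.functional as F

def batched_causal_attention(q, k, v, seq_lens):
    """q, k, v: [batch, n_heads, max_seq, head_dim]; seq_lens: [batch] int tensor on the same device."""
    max_seq = q.shape[2]
    pos = torch.arange(max_seq, device=q.device)
    valid = pos.unsqueeze(0) < seq_lens.unsqueeze(1)   # [batch, max_seq] key validity (padding mask)
    causal = pos.unsqueeze(1) >= pos.unsqueeze(0)      # [max_seq, max_seq] lower-triangular causal mask
    mask = causal.unsqueeze(0) & valid.unsqueeze(1)    # [batch, max_seq, max_seq]
    mask = mask.unsqueeze(1)                           # broadcast over heads
    # No .item() calls and no per-sequence Python loop.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)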

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (1)

167-181: Fix type annotation for input_sample parameter

The type annotation for input_sample is incorrect - it should be Optional[torch.Tensor] instead of None.

-    input_sample: None,
+    input_sample: Optional[torch.Tensor],
🧹 Nitpick comments (13)
tensorrt_llm/_torch/auto_deploy/export/interface.py (3)

38-131: Excellent abstract base class implementation with robust error handling.

The context manager pattern is well-implemented with proper error handling and logging. A few observations:

  1. The design correctly uses @final to prevent overriding critical methods
  2. Good separation between patch application and reversion logic
  3. Proper handling of the skip_on_error configuration

Consider adding validation in from_kwargs to catch configuration errors early:

 @classmethod
 def from_kwargs(cls, **kwargs) -> "BaseExportPatch":
     """Create a patch from kwargs."""
+    # Validate kwargs against the config class to catch errors early
+    try:
+        config = cls.get_config_class()(**kwargs)
+    except Exception as e:
+        raise ValueError(f"Invalid configuration for {cls.__name__}: {e}")
-    config = cls.get_config_class()(**kwargs)
     return cls(config=config)

166-212: Consider enhancing error handling in the registry.

The registry implementation is clean, but could benefit from more descriptive error messages:

 @classmethod
 def get(cls, name: str) -> Type[BaseExportPatch]:
     """Get a patch class by name."""
+    if name not in cls._registry:
+        available = ', '.join(cls._registry.keys())
+        raise KeyError(f"Unknown patch '{name}'. Available patches: {available}")
     return cls._registry[name]

Also consider thread safety if patches might be registered dynamically at runtime.


214-249: Consider using contextlib.ExitStack for better readability.

While the recursive generator pattern works correctly, contextlib.ExitStack would be more idiomatic and easier to debug:

+from contextlib import ExitStack

 @contextmanager
 def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]):
     """Context manager to apply multiple patches.
 
     Args:
         patch_configs: Dict mapping patch names to their configurations.
     """
-    patches = []
-
-    # Create patch instances
-    for name, config in patch_configs.items():
-        if not ExportPatchRegistry.has(name):
-            raise ValueError(f"Unknown patch: {name}")
-        patch = ExportPatchRegistry.create_patch(name, config)
-        patches.append(patch)
-
-    # Apply patches using nested context managers
-    if not patches:
-        yield
-        return
-
-    def _apply_patches(remaining_patches):
-        if not remaining_patches:
-            yield
-            return
-
-        patch = remaining_patches[0]
-        with patch:
-            yield from _apply_patches(remaining_patches[1:])
-
-    # log applied patches
-    ad_logger.debug(
-        f"applying export patches: {', '.join([patch.get_patch_key() for patch in patches])}"
-    )
-
-    yield from _apply_patches(patches)
+    with ExitStack() as stack:
+        patches = []
+        
+        # Create and enter patch instances
+        for name, config in patch_configs.items():
+            if not ExportPatchRegistry.has(name):
+                raise ValueError(f"Unknown patch: {name}")
+            patch = ExportPatchRegistry.create_patch(name, config)
+            patches.append(patch)
+            stack.enter_context(patch)
+        
+        # log applied patches
+        if patches:
+            ad_logger.debug(
+                f"applying export patches: {', '.join([patch.get_patch_key() for patch in patches])}"
+            )
+        
+        yield

This approach is more straightforward and provides better exception handling and cleanup guarantees.

tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (1)

60-113: Well-structured transformation with comprehensive dtype support.

The implementation correctly handles pattern matching for multiple data types. Consider adding a comment explaining why the scalar_workaround is needed for the eps parameter on line 107:

             dummy_args=dummy_args(input_dtype, weight_dtype),
             op_ignore_types={},
+            # Workaround: eps is a scalar that needs special handling in pattern matching
             scalar_workaround={"eps": 1e-6},
         )
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (1)

95-138: Consider extracting tensor initialization logic to a helper function.

The test implementation is thorough and well-structured. The repeated tensor initialization logic (lines 116-120 and 128-132) could be extracted to a helper function for better maintainability:

def initialize_meta_tensors(module, device="cuda"):
    """Initialize meta tensors with random values on the specified device."""
    module._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )

This would make the test more readable and reusable.

examples/auto_deploy/README.md (1)

249-275: Fix emphasis style for consistency.

The expert configuration section is well-written, but please update the emphasis style to use asterisks for consistency:

-  _exclusively_ exposed in the [`AutoDeployConfig` class](../../tensorrt_llm/_torch/auto_deploy/llm_args.py).
+  *exclusively* exposed in the [`AutoDeployConfig` class](../../tensorrt_llm/_torch/auto_deploy/llm_args.py).

-  object are overlapping, duplicated, and/or _ignored_ in AutoDeploy, particularly arguments
+  object are overlapping, duplicated, and/or *ignored* in AutoDeploy, particularly arguments
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (1)

46-67: Clean integration with modular optimizer pipeline.

The use of ModularInferenceOptimizer with configurable transforms is a significant improvement. The TODO indicates ongoing migration work.

Would you like me to help identify which legacy transforms should be prioritized for migration to the new optimizer?

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (1)

170-171: Address type annotation for optimizer configuration.

The # type: ignore comment suggests a mismatch between the dict literal returned by _get_optimizer_config() and its declared InferenceOptimizerConfig return type. Consider making the returned value actually satisfy the annotation so the suppression can be removed, for example by using the transform registry to create properly typed configs:

from tensorrt_llm._torch.auto_deploy.transform.interface import TransformRegistry

def _get_optimizer_config() -> InferenceOptimizerConfig:
    return {
        "build_model": TransformRegistry.get_config_class("build_model")(
            stage="factory",
            device="cuda",
            run_graph_cleanup=False,
            requires_clean_graph=False,
        ),
        # ... similar for other transforms
    }
tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (1)

188-189: Track the TODO for quantized MoE fusion support.

The TODO indicates that fused MoE support for quantized models is not yet implemented. Consider creating an issue to track this feature.

Would you like me to create an issue to track the implementation of trtllm_moe_fused support for quantized models?

tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_config.py (1)

428-428: Use specific exception type for YAML parsing errors.

Catch the specific YAML parsing exception instead of generic Exception for better error handling.

-    with pytest.raises(Exception):  # Should raise yaml parsing error
+    with pytest.raises(yaml.YAMLError):  # Should raise yaml parsing error

Note: You may need to import the appropriate YAML error class from OmegaConf or the underlying YAML library.
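A hedged sketch of what that could look like; the exact exception type depends on whether the installed OmegaConf re-raises PyYAML errors or wraps them, so both are caught here:

import pytest
import yaml
from omegaconf import OmegaConf
from omegaconf.errors import OmegaConfBaseException  # assumption: available in the installed omegaconf

def test_invalid_yaml_raises():
    # Malformed YAML should surface as a parse error from one of the two layers.
    with pytest.raises((yaml.YAMLError, OmegaConfBaseException)):
        OmegaConf.create("key: [unclosed")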

tensorrt_llm/_torch/auto_deploy/transform/interface.py (3)

191-270: Consider breaking down this orchestration method for improved readability.

While the implementation is correct, the __call__ method handles many responsibilities. Consider extracting some logic into helper methods.

For example, you could extract the history management and logging into separate methods:

+    def _update_history(self, gm: GraphModule, t_name: str, info: TransformInfo) -> None:
+        """Update the transform history in the graph metadata."""
+        autodeploy_meta = self._get_autodeploy_meta(gm)
+        history: TransformHistory = autodeploy_meta.get(self._history_key, {})
+        history[t_name] = info
+        autodeploy_meta[self._history_key] = history
+        self._set_autodeploy_meta(gm, autodeploy_meta)
+
+    def _log_transform_result(self, gm: GraphModule, t_name: str, info: TransformInfo) -> None:
+        """Log the result of the transform."""
+        log_msgs = [
+            f"stage={self.config.stage.value}",
+            f"transform={t_name}",
+            "skipped=True" if info.skipped else f"num_matches={info.num_matches}",
+            f"is_clean={info.is_clean}",
+            f"has_valid_shapes={info.has_valid_shapes}",
+        ]
+        ad_logger.info(", ".join(log_msgs))
+        ad_logger.debug(f"Graph after {t_name}: {gm}")

This would make the main __call__ method more concise and easier to follow.


281-322: Consider extracting duplicated cleanup logic.

The cleanup logic is duplicated between _run_pre_cleanup and _run_post_cleanup.

+    def _should_run_shape_prop(self, info: TransformInfo, config_flag: bool) -> bool:
+        """Check if shape propagation should be run."""
+        return config_flag and not (info.is_clean and info.has_valid_shapes)
+
+    def _should_run_cleanup(self, info: TransformInfo, config_flag: bool) -> bool:
+        """Check if cleanup should be run."""
+        return config_flag and not info.is_clean
+
+    def _execute_cleanup(self, gm: GraphModule, run_shape_prop: bool) -> None:
+        """Execute the appropriate cleanup based on requirements."""
+        if run_shape_prop:
+            with lift_to_meta(gm):
+                canonicalize_graph(gm, shape_prop=True)
+        else:
+            canonicalize_graph(gm)

This would simplify both cleanup methods and make the logic more maintainable.


333-362: Add missing newline at end of file.

The registry implementation is excellent, but the file should end with a newline character per Python conventions.

     def has(cls, name: str) -> bool:
         """Check if a transform is registered."""
         return name in cls._registry
+
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f5e88c and 810df7d.

📒 Files selected for processing (105)
  • examples/auto_deploy/.vscode/launch.json (1 hunks)
  • examples/auto_deploy/README.md (6 hunks)
  • examples/auto_deploy/build_and_run_ad.py (5 hunks)
  • examples/auto_deploy/build_and_run_flux.py (2 hunks)
  • requirements.txt (1 hunks)
  • setup.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py (11 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py (9 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/export.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/interface.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/autocast_noop.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/linear.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/modelopt_context.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/sdpa.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/sdpa_kernel_noop.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/tensor_meta_device.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/torch_modulelist_getitem.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/torch_where.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/export/library/transformers_sdpa_mask.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/llm_args.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/factory.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/hf.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/decilm.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/phi.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (8 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/_graph.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/export.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py (0 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py (9 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py (11 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transformations/transform.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/_config.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (4 hunks)
  • tensorrt_llm/bench/benchmark/throughput.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_captured_graph.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_compiler.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py (5 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_attention_with_kv_cache.py (9 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_rms_norm.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (11 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (3 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/test_export.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_config.py (1 hunks)
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/auto_deploy/transformations/library/ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/transformations/export.py
✅ Files skipped from review due to trivial changes (8)
  • requirements.txt
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_compiler.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py
  • tensorrt_llm/_torch/auto_deploy/export/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/__init__.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🚧 Files skipped from review as they are similar to previous changes (72)
  • tensorrt_llm/_torch/auto_deploy/models/patches/deepseek.py
  • tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/decilm.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/phi.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_captured_graph.py
  • tensorrt_llm/_torch/auto_deploy/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/visualization.py
  • examples/auto_deploy/build_and_run_flux.py
  • tensorrt_llm/_torch/auto_deploy/export/__init__.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py
  • tensorrt_llm/_torch/auto_deploy/models/__init__.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/transformations/__init__.py
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tensorrt_llm/bench/benchmark/throughput.py
  • tensorrt_llm/_torch/auto_deploy/transform/__init__.py
  • setup.py
  • tensorrt_llm/_torch/auto_deploy/models/factory.py
  • tensorrt_llm/_torch/auto_deploy/export/library/torch_where.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/test_export.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/__init__.py
  • tensorrt_llm/_torch/auto_deploy/export/library/linear.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/export/library/torch_modulelist_getitem.py
  • tensorrt_llm/_torch/auto_deploy/export/library/sdpa.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • examples/auto_deploy/.vscode/launch.json
  • tensorrt_llm/_torch/auto_deploy/models/patches/mixtral.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/eliminate_redundant_transposes.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_add.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_noop_slice.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/export/library/tensor_meta_device.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/qwen3.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py
  • tensorrt_llm/_torch/auto_deploy/export/library/autocast_noop.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py
  • tensorrt_llm/_torch/auto_deploy/export/library/transformers_sdpa_mask.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
  • tensorrt_llm/_torch/auto_deploy/export/library/sdpa_kernel_noop.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
  • tensorrt_llm/_torch/auto_deploy/transformations/_graph.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
  • tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_attention_with_kv_cache.py
  • tensorrt_llm/_torch/auto_deploy/export/export.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
  • tensorrt_llm/_torch/auto_deploy/models/hf.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • examples/auto_deploy/build_and_run_ad.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (1)

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🧬 Code Graph Analysis (5)
tensorrt_llm/_torch/auto_deploy/export/library/modelopt_context.py (1)
tensorrt_llm/_torch/auto_deploy/export/interface.py (4)
  • ContextManagerPatch (133-163)
  • ExportPatchRegistry (166-211)
  • register (172-181)
  • init_context_manager (146-152)
tensorrt_llm/_torch/auto_deploy/transformations/transform.py (10)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (3)
  • AttentionRegistry (627-650)
  • get (644-646)
  • get_attention_layout (513-514)
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
  • AutoDeployConfig (41-221)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (19-76)
tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py (5)
  • ShardingConfig (248-253)
  • detect_column_row_shard (473-648)
  • detect_dp_bmm_shard (651-728)
  • detect_ep_shard (731-761)
  • sharding_transform_executor (256-284)
tensorrt_llm/_torch/auto_deploy/transformations/library/collectives.py (2)
  • fuse_allreduce_residual_rmsnorm (62-167)
  • fuse_collectives (18-59)
tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py (1)
  • update_in_out_nodes (17-45)
tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py (1)
  • match_attention_layout (292-379)
tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py (1)
  • quantize_moe (100-133)
tensorrt_llm/_torch/auto_deploy/distributed/common.py (2)
  • get (27-28)
  • get_rank_world_size (86-87)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp4_compatible (33-34)
  • fp8_compatible (29-30)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (58-60)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3)
  • torch_moe (44-78)
  • torch_quant_fp8_moe (159-217)
  • torch_quant_fp4_moe (239-305)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)
  • repeat_kv (26-40)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (19-76)
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (1)
  • attention (768-880)
tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py (1)
  • match_attention_pattern (34-77)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (5)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (10)
  • MoEOpModel (170-206)
  • forward (74-94)
  • forward (106-108)
  • forward (130-135)
  • forward (154-159)
  • forward (184-203)
  • forward (220-222)
  • forward (232-234)
  • forward (247-253)
  • get_input (205-206)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp4_compatible (33-34)
  • fp8_compatible (29-30)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (58-60)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (3)
  • torch_moe (44-78)
  • torch_quant_fp8_moe (159-217)
  • torch_quant_fp4_moe (239-305)
🪛 markdownlint-cli2 (0.17.2)
examples/auto_deploy/README.md

262-262: Emphasis style
Expected: asterisk; Actual: underscore

(MD049, emphasis-style)


262-262: Emphasis style
Expected: asterisk; Actual: underscore

(MD049, emphasis-style)


267-267: Emphasis style
Expected: asterisk; Actual: underscore

(MD049, emphasis-style)


267-267: Emphasis style
Expected: asterisk; Actual: underscore

(MD049, emphasis-style)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (85)
tensorrt_llm/_torch/auto_deploy/export/interface.py (3)

15-18: LGTM!

Clean implementation of a custom exception class for export patch errors.


21-36: Well-designed configuration class using Pydantic.

The use of Pydantic with extra="allow" provides good extensibility for subclasses while maintaining type safety for base fields.


133-164: Clean adapter pattern for context managers.

This implementation elegantly wraps existing context managers into the patch interface. The lifecycle management is correct with proper cleanup in _revert_patch.

tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py (2)

1-18: Clean module setup with extensible backend mapping.

Good organization of imports and backend operations mapping.


21-36: Correct RMSNorm pattern implementation.

The pattern accurately captures the RMSNorm computation with proper handling of numerical stability through float32 conversion.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py (1)

46-48: Clean refactoring to in-place transformation pattern.

The simplified transform function correctly follows the new in-place modification pattern.

examples/auto_deploy/README.md (2)

9-9: Excellent documentation improvements.

The beta stage clarification and world-size description correction improve accuracy. The logging level section provides helpful debugging guidance.

Also applies to: 149-149, 175-189


277-445: Outstanding documentation of advanced configuration capabilities.

This section provides comprehensive coverage of the configuration system with clear examples and explanations. The documentation of CLI arguments with dot notation, YAML configuration files, deep merging behavior, and default configuration handling will be extremely helpful for expert users.
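For readers who want to see the deep-merge semantics in isolation, here is a tiny OmegaConf example; the keys below are made up for illustration and are not actual AutoDeploy options:

from omegaconf import OmegaConf

defaults = OmegaConf.create({"runtime": {"world_size": 1, "compile_backend": "torch-opt"}})
overrides = OmegaConf.create({"runtime": {"world_size": 4}})
merged = OmegaConf.merge(defaults, overrides)  # nested dicts merge key-by-key
assert merged.runtime.world_size == 4                   # overridden
assert merged.runtime.compile_backend == "torch-opt"    # preserved from defaults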

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (3)

11-22: Import changes align with the new architecture.

The updated imports correctly reflect the modular transformation pipeline with separate detection and execution phases, and the migration to the new export module.


149-153: Clean separation of detection and execution phases.

The refactored transform_func properly uses the new two-phase approach with ShardingConfig, improving modularity and aligning with the architectural changes across the codebase.


172-309: Excellent test coverage for pattern detection logic.

The new _run_pattern_detection_job and test_sharding_pattern_detection provide comprehensive validation of the sharding pattern detection without requiring distributed execution. The expected transformations are correctly constructed for each model type (GQA_Block, MLP, nn.Linear) with appropriate split dimensions, distribution operations, and min_local_shape constraints.

tensorrt_llm/_torch/auto_deploy/transformations/transform.py (2)

42-45: Configuration update aligns with AutoDeploy refactoring.

The change from LlmArgs to AutoDeployConfig provides a cleaner, more focused configuration interface specifically for AutoDeploy.


105-120: Well-structured two-phase sharding transformation.

The detection and execution phases are cleanly separated using ShardingConfig, improving modularity and enabling independent testing of pattern detection logic. The approach correctly handles TP, EP, and BMM sharding patterns.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py (2)

71-90: Clean configuration for modular optimizer pipeline.

The _get_optimizer_config helper properly defines the transformation stages (factory, export, post_export) with appropriate settings for the test scenario.


92-106: Well-designed extension for embedding inputs.

The SequenceEmbeddingInfo class cleanly extends SequenceInfo to support 3D embedding tensors, enabling tests with models that expect embeddings rather than token IDs as input.

tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py (1)

19-98: Well-structured MoE quantization implementation.

The _quantize_moe_node function properly handles weight quantization, scale registration, and graph node replacement. The use of QuantizationImpl abstraction provides clean support for multiple quantization algorithms.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_sharding.py (2)

72-76: Consistent refactoring with two-phase sharding approach.

The transform function properly uses ShardingConfig with separate detection and execution phases, maintaining consistency with the architectural changes across the codebase.


138-151: Good test coverage for BMM pattern detection.

The parameterized test properly validates BMM sharding detection logic across different world sizes and expert multipliers without requiring distributed execution.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py (6)

166-173: LGTM! Causal mask creation simplified nicely.

The refactored mask creation is more concise and readable while maintaining the same functionality.


248-255: Good consistency with EagerAttentionModel mask creation.

The causal mask creation follows the same simplified pattern, maintaining consistency across model implementations.


417-430: Well-structured transformation pipeline.

The function properly combines the post-export cleanup with attention pattern matching in a modular way.


453-457: Good practice: Models set to eval mode for testing.

Adding .eval() ensures deterministic behavior during testing, especially important for attention mechanisms with dropout.

Also applies to: 467-471


487-492: Correct update to expected transformed node type.

The change from torch_attention_sdpa to torch_attention_grouped_sdpa correctly reflects the new unified attention pattern matching behavior.


632-633: Clear and specific test expectation comment.

The updated comment precisely states what node type is expected after transformation.

tensorrt_llm/_torch/auto_deploy/llm_args.py (7)

1-16: LGTM: Clean imports and type definitions.

The new imports support the modular configuration system and the PathLike type alias provides good flexibility for handling both string and Path objects.


18-38: LGTM: Well-structured configuration helpers.

The configuration dictionary setup and validation helper follow good patterns for Pydantic settings with YAML support.


41-136: LGTM: Well-designed configuration class.

The AutoDeployConfig class provides a clean, validated interface for AutoDeploy-specific configurations. The use of PathLike for model/tokenizer paths adds flexibility, and the constraints on max_beam_width appropriately reflect AutoDeploy's limitations.


138-199: LGTM: Appropriate backend expansion and validation.

The addition of the "torch" attention backend aligns with the new PyTorch attention implementation. The model validator correctly enforces that attn_page_size equals max_seq_len for triton and torch backends, which is a necessary constraint.


202-222: LGTM: Well-implemented utility methods.

The utility methods provide clean abstractions for factory creation and type conversion. The explicit conversion of Path objects to strings in create_factory() handles the current ModelFactory API correctly.


224-284: LGTM: Proper inheritance and validation.

The LlmArgs class appropriately extends AutoDeployConfig while maintaining BaseLlmArgs compatibility. The field validators correctly enforce AutoDeploy's parallelism model by ensuring parallel configuration fields remain at their defaults.


285-317: LGTM: Appropriate model validation and utilities.

The model validators correctly set up the parallel configuration for AutoDeploy's world_size-based model while maintaining compatibility with other runtime components. The utility methods provide necessary configuration access patterns.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py (6)

1-33: LGTM: Clean imports and correct logit softcapping implementation.

The logit softcapping function correctly implements the mathematical formula and handles the optional parameter appropriately.


252-338: LGTM: Well-structured dispatch function.

The main attention function properly handles input/output reshaping and correctly dispatches between generate and context phases based on sequence length. The dimension handling and output shape construction look correct.


340-357: LGTM: Correct fake tensor implementation.

The fake tensor registration correctly computes output shapes for PyTorch's dispatch system without performing actual computation.


359-392: LGTM: Proper metadata preparation.

The metadata preparation function correctly computes sequence start indices and uses appropriate sanitization. Both real and fake implementations are consistent.


394-456: LGTM: Well-implemented attention backend registration.

The TorchBackendAttention class properly integrates with the attention registry system. The cache initializers correctly handle shapes, devices, and data types.


458-496: LGTM: Robust parameter extraction and validation.

The constants extraction properly validates attention arguments and extracts necessary parameters. The warning for unsupported arguments helps with debugging and ensures correct usage.

tensorrt_llm/_torch/auto_deploy/transformations/library/sharding.py (10)

21-82: LGTM: Excellent architectural design.

The introduction of SplitDimension enum and the abstract ShardingTransformInfo base class provides a clean, type-safe foundation. The use of frozen Pydantic models ensures immutability, which is crucial for transformation specifications.


83-119: LGTM: Sound tensor parallelism implementation.

The TPShardingInfo class correctly validates the compatibility between split dimensions and distribution operations. The validation ensures row splits use all_gather and column splits use all_reduce, which is mathematically correct for tensor parallelism.


121-221: LGTM: Comprehensive BMM sharding implementation.

The BMMShardingInfo class handles the complexity of BMM sharding well. The validation properly checks batch size divisibility, and the apply method correctly handles both parameter and dynamic tensors while maintaining distributed computation semantics.


223-246: LGTM: Clean expert parallelism implementation.

The EPShardingInfo class provides appropriate validation for MoE nodes and cleanly delegates to the existing MoE sharding implementation.


248-285: LGTM: Well-architected execution framework.

The ShardingConfig container and sharding_transform_executor provide a clean separation between transformation specification and execution. The executor includes proper error handling and performance optimizations.
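A compressed sketch of that specification/executor split, using frozen dataclasses in place of the module's Pydantic models (names simplified, not the actual API):

from dataclasses import dataclass, field
from typing import List

@dataclass(frozen=True)
class _TransformSpec:
    target_node: str

    def apply(self, gm) -> None:
        """Mutate gm in place for this single node (stub for illustration)."""

@dataclass
class _ShardingPlan:
    specs: List[_TransformSpec] = field(default_factory=list)

def execute(gm, plan: _ShardingPlan) -> None:
    for spec in plan.specs:
        try:
            spec.apply(gm)            # each spec knows how to apply itself
        except Exception as exc:      # keep going; report the failing spec
            print(f"skipping {spec.target_node}: {exc}")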


449-470: LGTM: Appropriate refactor for deferred execution.

The _append_simple_shard function correctly maintains the same sharding logic while adapting to the new deferred execution model by collecting TPShardingInfo instances instead of immediately applying transformations.


473-649: LGTM: Excellent separation of concerns.

The detect_column_row_shard function successfully separates detection from execution. The same detection logic is preserved while transformation specifications are collected for later execution, which is a significant architectural improvement.


651-729: LGTM: Consistent pattern application.

The detect_dp_bmm_shard function correctly applies the same deferred execution pattern while preserving the BMM sharding detection logic and validation.


731-762: LGTM: Consistent with architectural pattern.

The detect_ep_shard function correctly follows the deferred execution pattern established by the other detection functions while preserving MoE detection logic.


764-849: LGTM: Comprehensive MoE sharding implementation.

The _insert_sharded_moe function provides thorough expert parallelism support with proper expert partitioning, routing logic updates, and quantization handling. The distributed computation semantics are correctly maintained.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (4)

12-62: LGTM: Well-structured test setup helper function

The setup_moe_test function effectively modularizes the common test setup logic. The controlled random seeding, proper tensor initialization with scaling factor 0.1, and comprehensive return tuple provide a solid foundation for the test variations.


65-78: LGTM: Clean test refactoring with proper setup usage

The refactored test_moe_op_run function cleanly uses the new setup helper and maintains the same test logic. The parameterization across torch.float16 and torch.bfloat16 ensures good dtype coverage.


112-180: LGTM: Comprehensive FP8 quantization test implementation

The FP8 test properly:

  • Uses compatibility checks with appropriate skip conditions
  • Implements correct FP8 quantization with proper scaling factors (448 for bfloat16, 432 for float16)
  • Casts weights to torch.float8_e4m3fn format
  • Uses relaxed tolerances appropriate for quantized operations
  • Compares against both unquantized reference and torch_moe output

183-273: LGTM: Thorough FP4 quantization test with proper setup

The FP4 test demonstrates excellent implementation:

  • Proper conditional skipping for both FP4 compatibility and TensorRT LLM ops availability
  • Correct usage of fp4_global_scale utility for scale computation
  • Proper FP4 quantization using torch.ops.trtllm.fp4_quantize with appropriate parameters
  • Accurate alpha computation as 1 / (input_scale * weight_scale)
  • Appropriate relaxed tolerances for the quantized comparison

The test thoroughly validates the FP4 MoE operator implementation.

tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py (6)

17-32: LGTM: Well-designed utility function for pattern application

The _apply_pattern function provides a clean, reusable pattern for registering and applying pattern transformations. The conditional shape propagation and proper logging enhance maintainability and debuggability.


34-78: LGTM: Clean consolidation of attention pattern matching

The unified match_attention_pattern function effectively consolidates the previous separate pattern matchers into a cohesive approach. The separation into eager and grouped attention pattern registration provides clear organization while maintaining the same functionality.


80-213: LGTM: Comprehensive SFDP pattern definitions

The six SFDP pattern variants comprehensively cover different attention computation patterns:

  • With/without causal masks
  • With/without scaling division
  • With/without explicit dtype casting

Each pattern has a corresponding replacement that properly maps parameters to the torch_attention_sdpa operator. The scaling inversion for division patterns (e.g., scaling = 1.0 / scaling) is correctly implemented.
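To make the scaling inversion concrete, one of the "division" variants roughly corresponds to the following pair; these are plain tensor-level functions for illustration, not the actual pattern-registration API or the custom torch_attention_sdpa op:

import torch
import torch.nn.functional as F

def sfdp_div_pattern(q, k, v, scaling: float):
    # Eager attention that divides the raw scores by the scaling factor.
    scores = (q @ k.transpose(-2, -1)) / scaling
    return torch.softmax(scores, dim=-1) @ v

def sfdp_div_replacement(q, k, v, scaling: float):
    # SDPA multiplies by its scale argument, so the division is inverted.
    return F.scaled_dot_product_attention(q, k, v, scale=1.0 / scaling)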


215-251: LGTM: Well-structured pattern configuration factory

The _get_sfdp_patterns function provides a clean factory pattern for generating pattern configurations. The tensor creation utilities and comprehensive parameter coverage ensure robust pattern matching across different scenarios.


254-289: LGTM: Proper grouped attention pattern definitions

The grouped attention patterns correctly implement the repeat K/V logic and provide appropriate replacements with torch_attention_grouped_sdpa. The pattern definitions maintain consistency with the expected tensor operations.


376-376: LGTM: Simplified in-place canonicalization

The removal of the return statement correctly reflects that match_attention_layout now performs in-place graph modifications. This aligns with the broader refactoring toward more consistent transformation patterns.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py (7)

12-201: Excellent comprehensive numpy reference implementation

The numpy_attention_reference function is exceptionally thorough, handling:

  • KV cache updates for both context and generate phases
  • Proper tensor reshaping for batch vs flattened cases
  • GQA support with correct key/value repetition
  • Causal masking with proper broadcasting
  • Sliding window attention with position-based masking
  • Logit softcapping with tanh activation
  • Sinks implementation with attention score concatenation
  • Proper softmax computation with numerical stability

The implementation correctly handles complex edge cases like empty sequences and different tensor layouts. This provides an excellent ground truth for testing the PyTorch backend.
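Two of the listed features in isolation, as a small numpy sketch with assumed shapes (not the reference function's actual code):

import numpy as np

def softcap(scores: np.ndarray, cap: float) -> np.ndarray:
    # Logit softcapping: smoothly squash raw scores into (-cap, cap).
    return cap * np.tanh(scores / cap)

def sliding_window_mask(q_pos: np.ndarray, k_pos: np.ndarray, window: int) -> np.ndarray:
    # True where a key is visible: causal and within the last `window` positions.
    diff = q_pos[:, None] - k_pos[None, :]
    return (diff >= 0) & (diff < window)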


203-270: LGTM: Well-structured test class with proper setup

The TestTorchBackendAttention class provides:

  • Proper CUDA setup and reproducible seeding
  • Appropriate tolerance values for fp16 vs fp32 comparisons (5e-2)
  • Clean test data creation with flexible parameter support
  • Proper tensor flattening for different phases (context vs generate)

The _create_test_data method correctly handles the different tensor layouts required for context and generate phases.


292-308: LGTM: Solid basic functionality test

The basic functionality test properly verifies:

  • Correct output tensor shape
  • Finite output values (no NaN/Inf)
  • Basic operation without crashes

This provides a good foundation before testing more complex features.


309-382: Well-designed comprehensive feature testing

The parameterized test for combined features effectively:

  • Tests logit capping, sliding window, and sinks in combination
  • Properly handles backend failures gracefully
  • Uses appropriate tolerances for different feature combinations
  • Validates against the numpy reference implementation

The acknowledgment that the backend has bugs with sinks while still testing the reference implementation demonstrates good testing practices.


383-411: LGTM: Thorough GQA functionality testing

The GQA test comprehensively validates different head ratios (8:4, 12:3, 16:1) and properly compares outputs against the numpy reference. This ensures the grouped query attention implementation works correctly across various configurations.


413-466: LGTM: Comprehensive phase testing

The context vs generate phase test properly validates both multi-token (context) and single-token (generate) scenarios. This ensures the backend correctly handles the different computation patterns and cache update logic for both phases.


467-487: LGTM: Metadata preparation test

The metadata preparation test validates the auxiliary operation used by the attention backend. The verification of result structure and tensor properties provides good coverage for this supporting operation.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py (5)

36-90: LGTM: Well-implemented FP8 quantized MLP block

The BlockSparseTop2MLPFP8 class correctly implements:

  • Proper FP8 quantization with dtype-dependent scaling factors (448 for bfloat16, 432 for float16)
  • Correct weight quantization and scale computation
  • Proper buffer and parameter registration
  • Accurate forward pass using torch_quant_fp8_linear operations

The implementation follows established FP8 quantization patterns and properly handles the activation function and tensor operations.


92-164: LGTM: Comprehensive FP4 quantized MLP block

The BlockSparseTop2MLPFP4 class demonstrates excellent FP4 quantization implementation:

  • Uses fp4_global_scale utility for proper scale computation
  • Correctly applies torch.ops.trtllm.fp4_quantize for weight quantization
  • Properly computes alpha values as 1 / (input_scale * weight_scale)
  • Registers all necessary tensors and metadata as buffers/parameters
  • Implements forward pass with torch_quant_fp4_linear operations

The implementation is thorough and follows FP4 quantization best practices.


184-206: LGTM: Enhanced BlockSparseMoE with quantization support

The updated BlockSparseMoE constructor properly:

  • Accepts quantization parameters (quant_type, input_sample, dtype, device)
  • Uses the factory function to create appropriate MLP blocks
  • Ensures proper dtype and device casting for the gate layer
  • Maintains backward compatibility with default parameters

241-253: LGTM: Updated MoEPatternModel with quantization support

The MoEPatternModel enhancements properly:

  • Accept quantization type parameter
  • Generate input sample for FP4 quantization scale computation
  • Pass quantization parameters to BlockSparseMoE
  • Increase embedding size to 64 for better test coverage

266-315: Excellent parameterized test design

The new parameterized test_moe_matching function provides comprehensive coverage:

  • Tests all three quantization variants (none, FP8, FP4)
  • Proper conditional skipping based on hardware/software compatibility
  • Appropriate tolerance values for each quantization type
  • Correct expected operator validation
  • Proper dtype handling for different quantization modes

The test design ensures robust validation across all supported MoE variants.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (6)

7-41: Excellent generic MoE template implementation

The _template_moe function provides a robust, reusable foundation for MoE operations:

  • Proper input handling: Correctly reshapes input tensors and preserves original shape
  • Expert validation: Uses masking to handle out-of-range expert indices gracefully
  • Empty input protection: Skips experts with zero tokens to prevent kernel issues (especially important for FP4)
  • Efficient dispatch: Uses one-hot encoding and expert masks for token routing
  • Consistent aggregation: Proper weighted aggregation with dtype preservation

This template eliminates code duplication across MoE variants while maintaining correctness and performance.
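The dispatch/aggregate flow it describes boils down to roughly the following simplified sketch; argument names and the routing-weight layout are assumptions, not the operator's real signature:

import torch

def template_moe(x, selected_experts, routing_weights, mlps):
    # x: [tokens, hidden]; selected_experts, routing_weights: [tokens, top_k]
    out = torch.zeros_like(x)
    for expert_id, mlp in enumerate(mlps):
        token_idx, k_idx = (selected_experts == expert_id).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # skip experts with no tokens (important for FP4 kernels)
        expert_out = mlp(x[token_idx])
        weights = routing_weights[token_idx, k_idx].unsqueeze(-1).to(x.dtype)
        out[token_idx] += weights * expert_out  # weighted aggregation, dtype preserved
    return out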


72-78: LGTM: Clean refactoring to use template

The refactored torch_moe function cleanly leverages the new template:

  • Constructs appropriate MLP lambdas with linear operations and SiLU activation
  • Delegates core MoE logic to the proven template function
  • Maintains the same API and behavior as before

82-90: LGTM: Consistent fake implementation naming

The renaming of fake implementations from torch_moe to torch_moe_fake and torch_fused_moe to torch_fused_moe_fake improves clarity and follows consistent naming conventions.

Also applies to: 148-155


158-217: LGTM: Well-implemented FP8 quantized MoE operator

The torch_quant_fp8_moe operator correctly:

  • Uses the established template function for consistent MoE logic
  • Constructs per-expert MLPs with torch_quant_fp8_linear operations
  • Properly handles all required FP8 quantization parameters (input_scale, weight_scale)
  • Implements the standard gate-up-down MLP pattern with SiLU activation
  • Maintains the same API structure as the reference implementation

238-305: LGTM: Comprehensive FP4 quantized MoE operator

The torch_quant_fp4_moe operator provides excellent FP4 quantization support:

  • Empty input handling: Explicit check for zero-sized inputs to prevent FP4 kernel issues
  • Complete parameter support: Handles all FP4 quantization parameters including alpha scales
  • Consistent structure: Uses the same MLP construction pattern as other variants
  • Template integration: Leverages the shared _template_moe function for reliable behavior

The implementation properly addresses FP4-specific requirements while maintaining consistency with other MoE variants.


220-235: LGTM: Appropriate fake implementations

Both quantized MoE operators have proper fake implementations that return torch.empty_like(x), maintaining consistency with PyTorch's fake tensor system for graph tracing and shape inference.

Also applies to: 308-326

tensorrt_llm/_torch/auto_deploy/transform/interface.py (10)

1-18: Well-structured imports and clear module documentation.

The imports are properly organized and the module docstring clearly describes the purpose.


20-24: Standard exception class implementation.

The custom exception follows Python conventions.


50-91: Well-designed configuration model with clear field documentation.

The Pydantic model provides good extensibility and clear configuration options for transforms.


92-96: Clear type aliases for configuration management.

Good use of type aliases to improve code readability.


98-125: Well-designed immutable result model.

The frozen configuration ensures transform results remain immutable, which is appropriate for tracking transformation history.


129-149: Good abstract base class design with registry integration.

The pattern of having the registry set _transform_key and the error handling in get_transform_key ensures proper registration.


150-190: Excellent initialization pattern with proper hooks for subclasses.

The use of @final on __init__ with a _post_init hook ensures consistent initialization while allowing customization.


271-280: Clean metadata access methods.

Good encapsulation of graph metadata access.


323-331: Clear abstract method definition.

The abstract method provides a clear contract for subclasses.


43-47: Revisit the __lt__ suggestion - the original indexing already works.

list(self.__class__) on an enum yields its members in definition order, not the class itself, so indexing other in it is valid whenever both operands share the same class. The change below is therefore a readability refactor rather than a bug fix:

 def __lt__(self, other):
     """Enable sorting by definition order."""
     if self.__class__ is other.__class__:
-        return list(self.__class__).index(self) < list(other.__class__).index(other)
+        members = list(self.__class__)
+        return members.index(self) < members.index(other)
     return NotImplemented

Likely an incorrect or invalid review comment.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12450 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9258 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@lucaslie lucaslie disabled auto-merge July 21, 2025 17:13
@lucaslie lucaslie enabled auto-merge (squash) July 21, 2025 17:14
@tensorrt-cicd
Copy link
Collaborator

PR_Github #12460 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12460 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9268 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12475 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12475 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9280 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@lucaslie
Copy link
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12485 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12484 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12485 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12484 [ run ] completed with state FAILURE

@lucaslie
Copy link
Member Author

/bot run --disable-fail-fast

1 similar comment
@lucaslie
Copy link
Member Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12490 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12502 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12490 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12502 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9293 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12577 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12577 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9355 completed with status: 'SUCCESS'

@lucaslie lucaslie merged commit 41fb8aa into NVIDIA:main Jul 22, 2025
3 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in AutoDeploy Board Jul 22, 2025

Successfully merging this pull request may close these issues.

[AutoDeploy] Create export patch registry
9 participants