
Conversation

@sudhu2k (Contributor) commented Jan 15, 2026

Description

This PR enhances the GroupedLinear module with Triton kernel support for grouped GEMM operations, providing optimized performance. The implementation includes a complete Triton-based grouped matrix multiplication (GMM) backend that can be enabled via environment variables, along with pre-tuned kernel configurations. A usage sketch follows the list below.

  • Added support for using Triton kernels in GroupedLinear, selected at runtime via an environment variable, for optimized performance.
  • Updated setup.py to include the JSON configuration files for the Triton kernels in the package data.
  • Added a new test case for grouped GEMM functionality to the CI pipeline.
  • Refactored the handling of input tensors and gradients to accommodate the new Triton kernel path.
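
A minimal usage sketch of the new path (illustrative only: the GroupedLinear constructor and forward arguments shown reflect my reading of the existing transformer_engine.pytorch API and are not added by this PR; the environment variable must be set before the layer runs):

```python
import os

# Opt in to the Triton grouped GEMM backend (ROCm builds only; ignored elsewhere).
os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"

import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(
    num_gemms, in_features, out_features, bias=False, params_dtype=torch.bfloat16
)

m_splits = [128, 256, 64, 64]  # per-group row counts; must sum to the input's first dim
inp = torch.randn(sum(m_splits), in_features, device="cuda", dtype=torch.bfloat16)
out = layer(inp, m_splits)     # dispatches to the Triton grouped GEMM when the guard allows it
```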

Benchmark results:

https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3739558113
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3746418683

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added Triton kernel support for GroupedLinear: Implemented a complete Triton-based grouped GEMM backend with dynamic kernel selection based on the NVTE_USE_GROUPED_GEMM_TRITON environment variable (see the dispatch sketch after this list)
  • Added optional m_splits_tensor parameter to keep tensor data on GPU and avoid redundant CPU-GPU data transfers for improved performance
  • New GMM (Grouped Matrix Multiplication) module: Added a comprehensive Triton kernel implementation, ported from AITER, in transformer_engine/pytorch/triton_kernels/gmm/, including:
    • gmm_common.py: Common utilities and helper functions
    • gmm_kernels.py: Core Triton kernel implementations for grouped GEMM operations
    • gmm_wrapper.py: High-level wrapper functions
    • pid_preprocessing.py: Program ID (pid) preprocessing for efficient kernel scheduling
  • Pre-tuned configurations: Added JSON configuration files for AMD GPU architectures:
    • gfx942-GMM.json: pre-tuned configs for gfx942 arch
    • gfx950-GMM.json: pre-tuned configs for gfx950 arch
  • Updated setup.py: Modified package data to include JSON configuration files for Triton kernels
  • Enhanced GroupedLinear module: Refactored grouped_linear.py to support Triton kernel path with proper tensor handling.
  • Added grouped_gemm.py wrapper: Created high-level interface in TE for grouped GEMM operations
  • Extended common utilities: Added Triton kernel support flags in triton_kernels/common.py
  • New test suite: Added comprehensive test cases (From AITER) in tests/pytorch/triton_kernels/test_grouped_gemm.py (516 lines)
  • CI integration: Updated ci/pytorch.sh to include grouped GEMM tests in the CI pipeline
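
For reference, the backend selection inside grouped_linear.py, condensed from the diff hunks quoted in the review discussion below, amounts to the following. The helper wrapper here is hypothetical; the condition and the dispatch come from the diff:

```python
import os

from torch.utils.cpp_extension import IS_HIP_EXTENSION  # assumed source of the flag


def _use_grouped_gemm_triton(fp8: bool, fuse_wgrad_accumulation: bool) -> bool:
    """Hypothetical wrapper around the guard used in grouped_linear.py."""
    return (
        IS_HIP_EXTENSION
        and os.getenv("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1"
        and not fp8
        and not fuse_wgrad_accumulation
    )


# When the guard is true, the Triton wrapper replaces the default grouped GEMM:
#   general_grouped_gemm_func = (
#       general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm
#   )
```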

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ipanfilo (Collaborator) commented:
Update copyright date of modified files

ci/pytorch.sh (outdated):
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
Collaborator:
Please move it two lines higher; alphabetical sorting makes tests easier to find.

delay_wgrad_compute,
parallel_mode=None,
):
os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"
Collaborator:
This env var won't be cleared if the test is skipped or fails.
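
One robust pattern (illustrative; not the code in this PR, and the test name below is a placeholder) is to set the variable through pytest's monkeypatch fixture, which restores the environment at teardown even when the test fails or is skipped:

```python
import pytest


@pytest.mark.parametrize("use_triton", [True, False])
def test_grouped_linear_accuracy(monkeypatch, use_triton):
    if use_triton:
        # Restored automatically at teardown, even on failure or skip.
        monkeypatch.setenv("NVTE_USE_GROUPED_GEMM_TRITON", "1")
    else:
        monkeypatch.delenv("NVTE_USE_GROUPED_GEMM_TRITON", raising=False)
    ...  # existing accuracy-test body
```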

else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

if not use_grouped_gemm_triton:
Collaborator:
Make it an elif.

group_sizes_list=kwargs.get("m_splits_list", []),
)

grad_biases = [None] * len(m_splits) if bias is None else bias
Collaborator:
m_splits.shape[0] or len(m_splits_list)?

Contributor Author:
m_splits is a 1D tensor so len(m_splits) equals m_splits.shape[0].

```python
import torch

m_splits = torch.tensor([3, 5, 7, 9])  # 1D tensor
print("len(m_splits):", len(m_splits))          # len(m_splits): 4
print("m_splits.shape[0]:", m_splits.shape[0])  # m_splits.shape[0]: 4
```

package_data = {"": ["VERSION.txt"]}
package_data = {
"": ["VERSION.txt"],
"transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
Collaborator:
They should be part of the PyTorch extension installation, not TE core.

Contributor Author:
When I tried moving it to the PyTorch extension installation, the CI TE install couldn't place the config files in the TE installation directory and threw:

```
2026-01-19T07:36:00.3196138Z ), f"'{config_filename}' isn't an existent file."
2026-01-19T07:36:00.3196361Z E AssertionError: '/opt/venv/lib/python3.11/site-packages/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx950-GMM.json' isn't an existent file.
```

It seems the core installation and the PyTorch installation are separate setups, and core is the one installed by the CI. Let me know if I'm doing something wrong. Thanks!

_ = general_grouped_gemm(
general_grouped_gemm_func = general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm
# Prepare m_splits for each backend
m_splits_for_kernel = m_splits
Collaborator:
It may be more straightforward to keep m_splits as-is and add a mandatory parameter m_splits_tensor (or m_splits_for_kernel) to general_grouped_gemm_triton(), instead of swapping them here.
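
A rough sketch of the suggested shape (the signature below is hypothetical; only general_grouped_gemm_triton, inputmats, m_splits, and m_splits_tensor are names that appear in this PR):

```python
from typing import List

import torch


def general_grouped_gemm_triton(
    weights: List[torch.Tensor],      # hypothetical parameter
    inputmats: List[torch.Tensor],
    m_splits: List[int],              # host-side splits, unchanged from the existing path
    *,
    m_splits_tensor: torch.Tensor,    # mandatory device tensor of group sizes for the Triton kernel
    **kwargs,
):
    ...
```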

Contributor Author:
Modified it

sugovind added 9 commits January 19, 2026 06:29
…for JSON configs

- Added support for using Triton kernels in GroupedLinear, allowing for optimized performance based on environment variables.
- Updated the setup.py to include JSON configuration files for Triton kernels in the package data.
- Added a new test case for grouped GEMM functionality in the CI pipeline.
- Refactored the handling of input tensors and gradients to accommodate the new Triton kernel logic.
…pdate setup for Triton kernel support

- Removed JSON config inclusion from setup.py package data.
- Updated CI test script to fix the order of tests
- Enhanced grouped linear accuracy test with improved environment variable handling.
- Adjusted grouped GEMM function to streamline tensor handling and ensure compatibility with Triton kernels.
@sudhu2k force-pushed the sudhu/aiter_grouped_gemm_integration branch from e58d3aa to 4261bc0 on January 19, 2026 06:34
sugovind added 2 commits January 19, 2026 06:43
Reverts the core setup.py removal so configs are shipped by both
transformer_engine and transformer_engine_torch. Some users install
core only and still need gmm configs; keep torch packaging too.
@sudhu2k self-assigned this Jan 19, 2026

sugovind added 3 commits January 21, 2026 22:55
…tensor handling for Triton support. Removed unused kwargs parameter in general_grouped_gemm function signature.
@sudhu2k requested review from ipanfilo and wangye805 on January 22, 2026 17:37
sugovind added 2 commits January 23, 2026 06:19
…dtype_from_str with str_to_torch_dtype for improved consistency and clarity. Updated supported data types to include fp32, fp16, and bf16.
@sudhu2k (Contributor Author) commented Jan 23, 2026

Working on adding fp32 compatibility for the test_grouped_gemm.py tests from AITER. Currently fp32 runs out of shared memory because of the chosen configs.
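
For context: fp32 elements are twice the size of bf16, so a tile shape tuned for bf16 needs roughly twice the LDS, which is the likely cause of the overflow. Below is a hypothetical adjustment in the spirit of the later commit (get_config adapting block sizes by dtype); it is not the actual implementation:

```python
import torch


def adjust_block_sizes(config: dict, dtype: torch.dtype) -> dict:
    """Illustrative only: shrink the K tile for 4-byte dtypes so the per-block
    shared-memory footprint stays close to what the bf16/fp16-tuned configs assume."""
    if dtype == torch.float32 and config.get("BLOCK_SIZE_K", 0) > 32:
        config = dict(config, BLOCK_SIZE_K=config["BLOCK_SIZE_K"] // 2)
    return config
```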

sugovind added 2 commits January 23, 2026 22:39
…figuration handling. Adjusted tolerance levels in tests and modified get_config to adapt block sizes based on dtype.
…merical differences. Updated test logic to conditionally skip based on M size.
import functools
import torch

from transformer_engine.pytorch.triton_kernels.grouped_gemm import general_grouped_gemm_triton
Collaborator:
Guard by IS_HIP_EXTENSION?

Contributor Author:
The function call is already guarded:

use_grouped_gemm_triton = IS_HIP_EXTENSION and os.getenv("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1" and not fp8 and not fuse_wgrad_accumulation

Do we still want to guard the import? It shouldn't affect anything. Let me know!

Collaborator:
Hmm, technically we want all our ROCm-specific code changes guarded even at import time, not just at runtime. This avoids any ROCm-specific dependency when we import general_grouped_gemm_triton.
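
A sketch of the guarded import (the module path comes from the diff above; where IS_HIP_EXTENSION is imported from is an assumption):

```python
from torch.utils.cpp_extension import IS_HIP_EXTENSION  # assumed source of the flag

if IS_HIP_EXTENSION:
    from transformer_engine.pytorch.triton_kernels.grouped_gemm import (
        general_grouped_gemm_triton,
    )
else:
    general_grouped_gemm_triton = None  # never dispatched to on non-HIP builds
```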

# See LICENSE for license information.

"""GroupedLinear API"""
import os
Collaborator:
Does NV upstream need this package?

Contributor Author:
It doesn't. We need this since we want to access the env variable.

Collaborator:
Then it's better to put it under IS_HIP_EXTENSION.

Contributor Author:
Made the changes.

@sudhu2k requested a review from wangye805 on January 27, 2026 16:59
@wangye805 (Collaborator) left a comment:

LGTM
