Enhance GroupedLinear by integrating AITER Triton kernels #413
base: dev
Conversation
Update copyright date of modified files
ci/pytorch.sh
Outdated
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
run_default_fa 1 triton_kernels/test_grouped_gemm.py
Please move it two lines higher; alphabetical sorting makes tests easier to find.
tests/pytorch/test_numerics.py
Outdated
    delay_wgrad_compute,
    parallel_mode=None,
):
    os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"
This env var won't be cleared if the test is skipped or fails.
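One way to make the cleanup automatic is pytest's built-in monkeypatch fixture, which restores the environment during teardown even when the test fails or is skipped. A minimal sketch (the test name and body are illustrative, not the actual test in test_numerics.py):

def test_grouped_linear_accuracy_triton(monkeypatch):
    # monkeypatch.setenv records the previous value and restores it on teardown,
    # so the variable cannot leak into later tests after a failure or skip.
    monkeypatch.setenv("NVTE_USE_GROUPED_GEMM_TRITON", "1")
    # ... run the existing grouped-linear accuracy check here ...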
else:
    inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)

if not use_grouped_gemm_triton:
make it elif
    group_sizes_list=kwargs.get("m_splits_list", []),
)

grad_biases = [None] * len(m_splits) if bias is None else bias
m_splits.shape[0] or len(m_splits_list)?
m_splits is a 1D tensor so len(m_splits) equals m_splits.shape[0].
import torch

m_splits = torch.tensor([3, 5, 7, 9])  # 1D tensor
print("len(m_splits):", len(m_splits))          # len(m_splits): 4
print("m_splits.shape[0]:", m_splits.shape[0])  # m_splits.shape[0]: 4
package_data = {"": ["VERSION.txt"]}
package_data = {
    "": ["VERSION.txt"],
    "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
They should be part of the PyTorch extension installation, not TE core.
@ipanfilo, does Sudharshan also need changes in https://github.com/ROCm/TransformerEngine/blob/dev/build_tools/wheel_utils/build_wheels.sh for TE wheel building?
When I tried moving it to the PyTorch extension installation, the CI TE installation couldn't move the config files to the TE installation directory and was throwing:
2026-01-19T07:36:00.3196138Z     ), f"'{config_filename}' isn't an existent file."
2026-01-19T07:36:00.3196361Z E   AssertionError: '/opt/venv/lib/python3.11/site-packages/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx950-GMM.json' isn't an existent file.
It seems the core installation and the PyTorch installation are separate setups, and core is the one being installed by the CI. Let me know if I'm doing something wrong. Thanks!
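For reference, a minimal sketch of declaring the configs in the PyTorch extension's own setup via setuptools package_data (the distribution name and surrounding arguments are assumptions; only the package/glob pair is taken from the diff above):

from setuptools import setup, find_packages

setup(
    name="transformer_engine_torch",  # assumed name for the PyTorch extension distribution
    packages=find_packages(),
    package_data={
        # ship the pre-tuned Triton GMM configs alongside the kernels
        "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
    },
)

This only takes effect if the setup that actually installs the transformer_engine.pytorch packages is the one declaring the package_data, which is the mismatch described above.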
_ = general_grouped_gemm(
general_grouped_gemm_func = general_grouped_gemm_triton if use_grouped_gemm_triton else general_grouped_gemm
# Prepare m_splits for each backend
m_splits_for_kernel = m_splits
It may be more straightforward to keep m_splits as-is and add a mandatory parameter m_splits_tensor or m_splits_for_kernel to general_grouped_gemm_triton(), instead of swapping them here.
Modified it
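To make the two representations concrete, a small illustration (variable names only; this is not the actual TE call site):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
m_splits = [3, 5, 7, 9]  # host-side Python ints, what the default backend consumes
m_splits_tensor = torch.tensor(m_splits, dtype=torch.int32, device=device)  # stays on the device for the Triton kernels
# The suggestion above: keep passing m_splits to general_grouped_gemm and give
# general_grouped_gemm_triton its own m_splits_tensor parameter, rather than
# rebinding m_splits to a tensor before the call.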
…for JSON configs
- Added support for using Triton kernels in GroupedLinear, allowing for optimized performance based on environment variables.
- Updated the setup.py to include JSON configuration files for Triton kernels in the package data.
- Added a new test case for grouped GEMM functionality in the CI pipeline.
- Refactored the handling of input tensors and gradients to accommodate the new Triton kernel logic.

…pdate setup for Triton kernel support
- Removed JSON config inclusion from setup.py package data.
- Updated CI test script to fix the order of tests.
- Enhanced grouped linear accuracy test with improved environment variable handling.
- Adjusted grouped GEMM function to streamline tensor handling and ensure compatibility with Triton kernels.
Force-pushed from e58d3aa to 4261bc0
Reverts the core setup.py removal so configs are shipped by both transformer_engine and transformer_engine_torch. Some users install core only and still need gmm configs; keep torch packaging too.
….py in pytorch dir
…tensor handling for Triton support. Removed unused kwargs parameter in general_grouped_gemm function signature.
… ensure consistency in parameterization.
…dtype_from_str with str_to_torch_dtype for improved consistency and clarity. Updated supported data types to include fp32, fp16, and bf16.
Working on adding fp32 compatibility for the test_grouped_gemm.py function from AITER. Currently fp32 goes "out of shared memory" because of the chosen configs.
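For intuition, a rough back-of-the-envelope check of why a config tuned for 16-bit inputs can exceed the shared-memory (LDS) budget once elements are 4 bytes; the tile shapes below are illustrative, not the actual tuned configs:

def lds_bytes(block_m, block_n, block_k, elem_bytes):
    # A tile (block_m x block_k) plus B tile (block_k x block_n) staged in LDS per iteration
    return (block_m * block_k + block_k * block_n) * elem_bytes

print(lds_bytes(256, 256, 64, 2))  # 65536  -> a bf16/fp16 config that just fits a 64 KiB LDS budget
print(lds_bytes(256, 256, 64, 4))  # 131072 -> the same config in fp32 needs twice the LDS, hence smaller blocks

This matches the follow-up commit that makes get_config adapt block sizes based on dtype.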
…figuration handling. Adjusted tolerance levels in tests and modified get_config to adapt block sizes based on dtype.
…merical differences. Updated test logic to conditionally skip based on M size.
import functools
import torch

from transformer_engine.pytorch.triton_kernels.grouped_gemm import general_grouped_gemm_triton
Guard by IS_HIP_EXTENSION?
The function call is already guarded:
use_grouped_gemm_triton = IS_HIP_EXTENSION and os.getenv("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1" and not fp8 and not fuse_wgrad_accumulation
Do we still want to guard the import? It shouldn't affect anything. Let me know!
Hmm, technically we want all our ROCm-specific code changes guarded even during importing, not just when running. This avoids any ROCm-specific dependency when we import general_grouped_gemm_triton.
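A sketch of the guard being asked for, assuming IS_HIP_EXTENSION is taken from torch.utils.cpp_extension (the fallback binding is illustrative):

from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    # ROCm build: the Triton grouped GEMM backend is only imported behind the guard
    from transformer_engine.pytorch.triton_kernels.grouped_gemm import (
        general_grouped_gemm_triton,
    )
else:
    general_grouped_gemm_triton = None  # never selected, since use_grouped_gemm_triton requires IS_HIP_EXTENSION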
# See LICENSE for license information.

"""GroupedLinear API"""
import os
Does NV upstream need this package?
It doesn't. We need this since we want to access the env variable.
Then it's better to put it under IS_HIP_EXTENSION.
Made the changes.
…IP extension support in grouped_linear.py
wangye805 left a comment
LGTM
Description
This PR enhances the GroupedLinear module with Triton kernel support for grouped GEMM operations, improving performance. The implementation includes a complete Triton-based grouped matrix multiplication (GMM) backend that can be enabled via an environment variable, along with pre-tuned configurations for optimal performance.
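A minimal usage sketch of opting in (module and argument names follow TE's public GroupedLinear API as I understand it; sizes are arbitrary, and the path only engages on ROCm builds without FP8):

import os
os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"  # checked at forward time, so set before the forward pass

import torch
import transformer_engine.pytorch as te

m_splits = [1024, 2048, 512, 512]  # rows per group; must sum to the input's first dimension
layer = te.GroupedLinear(len(m_splits), 4096, 4096, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(sum(m_splits), 4096, dtype=torch.bfloat16, device="cuda")
out = layer(inp, m_splits)  # grouped GEMM dispatches to the Triton kernels on ROCm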
Benchmark results:
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3739558113
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3746418683
Type of change
Changes
- Triton grouped GEMM path that can be enabled via an environment variable (NVTE_USE_GROUPED_GEMM_TRITON)
- m_splits_tensor parameter to keep tensor data on GPU and avoid redundant CPU-GPU data transfers for improved performance
- transformer_engine/pytorch/triton_kernels/gmm/ from AITER, including:
  - gmm_common.py: common utilities and helper functions
  - gmm_kernels.py: core Triton kernel implementations for grouped GEMM operations
  - gmm_wrapper.py: high-level wrapper functions from AITER
  - pid_preprocessing.py: process ID preprocessing for efficient kernel scheduling
  - gfx942-GMM.json: pre-tuned configs for gfx942 arch
  - gfx950-GMM.json: pre-tuned configs for gfx950 arch
- grouped_linear.py: updated to support the Triton kernel path with proper tensor handling
- triton_kernels/common.py
- tests/pytorch/triton_kernels/test_grouped_gemm.py (516 lines)
- ci/pytorch.sh: updated to include grouped GEMM tests in the CI pipeline
Checklist: