Skip to content

Conversation

@Micky774
Copy link
Contributor

Description

Enables MXFP8 support in the TE JAX integration and significantly modifies tests to account for remaining support gaps

TODO:

  • Investigate grouped-GEMM test failures for MXFP8
  • Investigate newly exposed failures in test_layernorm_mlp_grad{_shardy} when using hipblaslt GEMM + bias + no scaling + bf16 dtype

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Adds new test cases to tests/jax/test_custom_call_compute.py
  • Corrects comparison bug in test_custom_call_compute::assert_bitwise_scaled_tensors
  • Removes dtype-based skip on MXFP8 GEMMs in test_custom_call_compute
  • Adds shape-based skips for MXFP8 GEMMs across several files
  • Added bias parameterization to test_dense_grad_fp8 to allow testing of MXFP8 (currently does not support bias)
  • Updated tests/jax/test_distributed_layernorm_mlp.py with new test shape to allow for MXFP8 usage
  • Added xfail to certain configs that fail with the new test case (needs a follow-up investigation)
  • Added explicit shape checks to transformer_engine/common/gemm/rocm_gemm.cu
  • Removed IS_NORM template parameter from cast_mxfp8_2D_kernel
  • Disabled scale_inv swizzling before GEMM on ROCm
  • Corrected scale_inv un-padding behavior in NormFwdPrimitive
  • Removed redundant un-padding in NormFwdPrimitive
  • Removed swizzling in GroupedGemmFFI
  • Skipped grouped GEMM MXFP8 tests due to outstanding failures (needs follow-up investigation)
  • Corrected bias bug in transformer_engine/jax/layernorm_mlp.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Micky774
Copy link
Contributor Author

Note that the CI failure is unrelated.

Copy link
Collaborator

@wangye805 wangye805 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still many other places needs rocm specific guards

(2048, 1024, 1024),
]

TEST_SHAPES = [(64, 32, 64), (128, 64, 128), (128, 256, 256)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guard this by is_hip_extension?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const int arch = cuda::sm_arch();

#ifndef __HIP_PLATFORM_AMD__
if (arch < 100 && is_fp8_gemm) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we also need similar check to filter non-NT gemms for gfx942

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any other conditions we need to guard against? I'm not sure what our support looks like for gfx942 vs gfx950 here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For gfx942, we only support NT layout, but for gfx950, we support others (NN, TN)


size_t num_non_empty_gemms = lhs_list.size();

if (is_mxfp8_scaling) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This need rocm specific guards

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode.is_1d_block_scaling():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCM guards needed here as well

)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode.is_1d_block_scaling():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm guards needed as well

#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
from packaging import version
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this newly added version used anywhere in this test script?

BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64

INTERMEDIATE = 128
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does NV upstream also need the intermediate set to 128?

)
)


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An extra empty line

device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm"

inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to move this up?

inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
input_shape, activation_type, use_bias, dtype
)
if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need rocm specific guards

return SdyShardingRule(
(dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv_rule, amax, dbias),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a to-do here for future investigation

rowwise_scale_inv_shape
)
# Slice out the padding for mxfp8 -- the kernel writes to strided
# 2D positions, not contiguous.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 324 and 325 comments are from one sentence? Do we have a lint specification for how long a comment should be? Here Line 326 looks longer than 324+325

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants