Enable MXFP8 support in TE JAX integration #424
base: dev
Conversation
…or gfx950 ci enablement
…ed with hipblaslt
Note that the CI failure is unrelated.
wangye805 left a comment:
There are still many other places that need ROCm-specific guards.
    (2048, 1024, 1024),
]

TEST_SHAPES = [(64, 32, 64), (128, 64, 128), (128, 256, 256)]
Guard this by is_hip_extension?
Done
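For reference, a minimal sketch of what such a guard could look like; the `is_hip_extension()` helper is stubbed here because its real import path is not visible in this diff, and which branch carries which shape list is an assumption, not the PR's actual change:

```python
# Minimal sketch, not the PR's exact code: is_hip_extension() stands in for the
# helper the reviewer refers to and is stubbed via an environment variable.
import os

def is_hip_extension() -> bool:
    # Hypothetical stand-in; the real helper would query the TE build/runtime.
    return os.environ.get("TE_ROCM", "0") == "1"

TEST_SHAPES = [
    (2048, 1024, 1024),
]
if is_hip_extension():
    # Smaller shapes on ROCm, as suggested in the review.
    TEST_SHAPES = [(64, 32, 64), (128, 64, 128), (128, 256, 256)]
```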
const int arch = cuda::sm_arch();

#ifndef __HIP_PLATFORM_AMD__
if (arch < 100 && is_fp8_gemm) {
I guess we also need a similar check to filter out non-NT GEMMs for gfx942.
Are there any other conditions we need to guard against? I'm not sure what our support looks like for gfx942 vs gfx950 here.
For gfx942 we only support the NT layout, but for gfx950 we also support the others (NN, TN).
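As an illustration of that constraint, a small predicate; the function name and the bare architecture string are placeholders, not code from this PR:

```python
# Illustrative only: encodes the layout support described above
# (gfx942: NT only; gfx950: NN and TN in addition to NT).
def fp8_gemm_layout_supported(arch: str, layout: str) -> bool:
    if arch == "gfx942":
        return layout == "NT"
    if arch == "gfx950":
        return layout in ("NT", "NN", "TN")
    # Other architectures are out of scope for this sketch.
    return False

# Example: non-NT GEMMs would be filtered out on gfx942.
assert not fp8_gemm_layout_supported("gfx942", "TN")
assert fp8_gemm_layout_supported("gfx950", "TN")
```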
size_t num_non_empty_gemms = lhs_list.size();

if (is_mxfp8_scaling) {
This needs ROCm-specific guards.
Done
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
    if scaling_mode.is_1d_block_scaling():
ROCm guards needed here as well.
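For illustration, one way such a guard might look inside these grouped-GEMM tests; the flags and the skip message are assumptions rather than the PR's actual change:

```python
# Sketch only: skip 1D block (MXFP8) scaling on ROCm devices that do not
# support it. `on_rocm` and `gpu_has_mxfp8` are placeholder flags the real
# test would derive from the device (e.g. gfx950 yes, gfx942 no).
import pytest

def maybe_skip_mxfp8(scaling_mode, on_rocm: bool, gpu_has_mxfp8: bool) -> None:
    if scaling_mode.is_1d_block_scaling() and on_rocm and not gpu_has_mxfp8:
        pytest.skip("MXFP8 grouped GEMM is not supported on this GPU architecture")
```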
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
    if scaling_mode.is_1d_block_scaling():
ROCm guards needed as well
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
from packaging import version
Is this newly added `version` import used anywhere in this test script?
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64

INTERMEDIATE = 128
Does NV upstream also need the intermediate set to 128?
    )
)

An extra empty line
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm"

inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
Do we really need to move this up?
inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
    input_shape, activation_type, use_bias, dtype
)
if (
This needs ROCm-specific guards.
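A sketch of how the failing combination could be marked, in line with the xfail approach mentioned in the description; the flags are placeholders for whatever the test actually checks:

```python
# Sketch only: xfail the known-failing ROCm combination (hipblaslt GEMM + bias
# + no scaling + bf16) pending the follow-up investigation noted in the PR
# description. The boolean flags are placeholders for the test's real checks.
import jax.numpy as jnp
import pytest

def maybe_xfail_hipblaslt_bias_bf16(on_rocm, use_bias, fp8_enabled, dtype):
    if on_rocm and use_bias and not fp8_enabled and dtype == jnp.bfloat16:
        pytest.xfail("hipblaslt GEMM + bias + bf16 without scaling: known failure")
```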
return SdyShardingRule(
    (dz_axes, x_axes, ("…2",)),
    (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
    (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv_rule, amax, dbias),
Put a TODO here for future investigation.
    rowwise_scale_inv_shape
)
# Slice out the padding for mxfp8 -- the kernel writes to strided
# 2D positions, not contiguous.
Are the comments on lines 324 and 325 one sentence? Do we have a lint rule for how long a comment line should be? Line 326 looks longer than lines 324 and 325 combined.
Description
Enables MXFP8 support in the TE JAX integration and significantly modifies tests to account for remaining support gaps.
TODO:
- `test_layernorm_mlp_grad{_shardy}` when using hipblaslt GEMM + bias + no scaling + bf16 dtype

Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `tests/jax/test_custom_call_compute.py`
- `test_custom_call_compute::assert_bitwise_scaled_tensors`
- `test_custom_call_compute`
- `test_dense_grad_fp8` to allow testing of MXFP8 (currently does not support bias)
- `tests/jax/test_distributed_layernorm_mlp.py` with new test shape to allow for MXFP8 usage
- `xfail` to certain configs that fail with the new test case (needs a follow-up investigation)
- `transformer_engine/common/gemm/rocm_gemm.cu`
- `IS_NORM` template parameter from `cast_mxfp8_2D_kernel`
- `scale_inv` swizzling before GEMM on ROCm
- `scale_inv` un-padding behavior in `NormFwdPrimitive`
- `NormFwdPrimitive`
- `GroupedGemmFFI`
- `transformer_engine/jax/layernorm_mlp.py`

Checklist: