Add CUDA kernel for MXFP8 dim1 casting #2513
Changes from all commits
```diff
@@ -42,7 +42,7 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
```
```diff
@@ -56,6 +56,15 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+# TODO: shared utils file for benchmarking and testing
+def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
+    return data_d1.t(), scale_d1
+
+
 @pytest.mark.skip(
     reason="TODO debug CI failure, low pri since this is not used in the MX code"  # noqa: E501
 )
```
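For context, here is a minimal usage sketch of the reference helper added above. This is a sketch only: it assumes the imports shown in the first hunk resolve, that a CUDA device is available, and the input shape (128, 256) is a hypothetical example rather than anything used by the tests.

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx


def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
    # Cast along dim1 by transposing, scaling along the (now) last dim, then
    # transposing the quantized data back to the original orientation.
    x_hp = x_hp.t().contiguous()
    scale_d1, data_d1 = to_mx(
        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
    )
    return data_d1.t(), scale_d1


x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
data_d1, scale_d1 = to_mx_dim1_reference(
    x, block_size=32, scaling_mode=ScaleCalculationMode.FLOOR
)
# data_d1 is float8_e4m3fn data with the same logical shape as x; scale_d1 holds
# one shared scale per 32-element block along dim1 of x.
print(data_d1.dtype, data_d1.shape)
```

The parametrized test in the next hunk exercises this same helper as the reference against the new CUDA kernel.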
```diff
@@ -488,3 +497,99 @@ def test_rearrange(shape):
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("M", (32, 64, 2048))
+@pytest.mark.parametrize("K", (32, 64, 2048))
+@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
+    from torchao.prototype import mxfp8_cuda
+
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+
+    y_d1_ref, s_d1_ref = to_mx_dim1_reference(
+        x,
+        block_size=block_size,
+        scaling_mode=scaling_mode,
+    )
+
+    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        x,
+        rowwise=False,
+        colwise=True,
+        scaling_mode=scaling_mode_str,
+        scale_dim_x=1,
+        scale_dim_y=block_size,
+    )
+
+    # check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
```
Contributor: I think we should also test the memory layout of all the tensors vs the reference.

Contributor (Author): Updated to verify the memory layout of the quantized tensor is identical. However, I deliberately wrote the scale tensor in a different memory layout, to avoid uncoalesced global accesses. This is also why I modified the triton_scale_swizzle kernel to accept column-major inputs in danielvegamyhre/private-torchao#19.
```diff
+    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
```
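To make the layout point above concrete, here is a small stride-checking sketch. It uses only stock PyTorch with hypothetical shapes and does not call the mxfp8_cuda extension; it only illustrates the difference between the row-major layout preserved for the quantized data and a column-major layout like the one described for the scale tensor.

```python
import torch

M, K, block_size = 64, 64, 32  # hypothetical sizes, chosen only for illustration

# Row-major (contiguous) layout: the last dimension has stride 1.
row_major_scales = torch.empty(M, K // block_size)
assert row_major_scales.stride() == (K // block_size, 1)

# Column-major layout: the first dimension has stride 1. One way to produce it in
# PyTorch is to allocate the transposed shape contiguously and then transpose.
col_major_scales = torch.empty(K // block_size, M).t()
assert col_major_scales.shape == row_major_scales.shape  # same logical shape
assert col_major_scales.stride() == (1, M)  # different physical layout

# The assertion added in the diff above compares layouts the same way: the
# quantized output must keep the reference tensor's strides.
ref = torch.empty(M, K, dtype=torch.bfloat16)
out = torch.empty(M, K, dtype=torch.bfloat16)
assert out.stride() == ref.stride(), "quantized tensor strides do not match"
```

In CUDA, global accesses coalesce when adjacent threads touch adjacent addresses, which is the concern behind the column-major choice described above.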
```diff
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+def test_cuda_mx_dim0_not_supported():
+    from torchao.prototype import mxfp8_cuda
+
+    M, K = 64, 64
+    block_size = 32
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=True,
+            colwise=False,
+            scale_dim_x=block_size,
+            scale_dim_y=1,
+        )
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+def test_cuda_mx_dim1_invalid_block_size():
+    from torchao.prototype import mxfp8_cuda
+
+    M, K = 64, 64
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    invalid_block_size = 4
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=False,
+            colwise=True,
+            scale_dim_x=1,
+            scale_dim_y=invalid_block_size,
+        )
```
Review comment: nit: remove "floor" to match all the others, or add "floor" to all the others.

Author reply: Added "floor" to the others; I like the more explicit naming.