 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
 
 
-def _triton_to_mxfp8_dim1_wrapper(
-    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+def _to_mxfp8_dim1_kernel_wrapper(
+    a,
+    block_size,
+    elem_dtype,
+    hp_dtype,
+    gemm_kernel_choice,
+    cast_kernel_choice,
 ):
-    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        _, a_data, _, a_scale = mxfp8_quantize_cuda(
+            a,
+            rowwise=False,
+            colwise=True,
+            scaling_mode="floor",
+        )
+    else:
+        raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")
+
     if isinstance(a_data, DTensor):
         assert isinstance(a_scale, DTensor)
         a_data_local = a_data.to_local()
@@ -86,15 +106,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -119,7 +139,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice
 
         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,9 +155,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )
 
-        if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
-                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            weight_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
+                weight_hp,
+                block_size,
+                w_elem_dtype,
+                weight_hp.dtype,
+                gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
@@ -153,13 +178,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )
 
         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -169,13 +195,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )
 
-        if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper(
                 input_hp_r,
                 block_size,
                 in_elem_dtype,
                 input_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
            )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
@@ -232,7 +259,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
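
Not part of the diff: a minimal usage sketch of the new option. It assumes MXLinearConfig accepts mxfp8_cast_kernel_choice as a constructor argument (the diff only shows it being read as config.mxfp8_cast_kernel_choice) and that the config is applied through torchao's quantize_ entry point; the toy model below is a placeholder.

# Usage sketch, not from this commit: select the CUDA dim1-cast kernel.
# TORCH keeps the pure-PyTorch cast path; TRITON and CUDA route through
# the _to_mxfp8_dim1_kernel_wrapper added above.
import torch
from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
)
from torchao.quantization import quantize_

# Placeholder model; MX block size requires dims divisible by the block size.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
model = model.cuda().to(torch.bfloat16)

# Assumed constructor keyword; the diff confirms the attribute name only.
config = MXLinearConfig(mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA)
quantize_(model, config)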