@@ -15,11 +15,15 @@
 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -66,6 +70,51 @@ def _triton_to_mxfp8_dim1_wrapper(
     return mx_tensor
 
 
+def _cuda_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    _, a_data, _, a_scale = mxfp8_quantize_cuda(
+        a,
+        rowwise=False,
+        colwise=True,
+        scaling_mode="floor",
+    )
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
@@ -86,15 +135,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -119,7 +168,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
        grad_elem_dtype = ctx.grad_elem_dtype
        block_size = ctx.block_size
        gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice
 
        grad_output_orig_shape = grad_output_hp.shape
        grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,10 +184,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            weight_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+            )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -153,14 +206,22 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )
 
         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            grad_output_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
+                block_size,
+                grad_elem_dtype,
+                grad_output_hp_r.dtype,
+                gemm_kernel_choice,
+            )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
                 grad_output_hp_r.t().contiguous(),
@@ -169,7 +230,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
                 input_hp_r,
                 block_size,
@@ -178,6 +239,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            input_t_mx_dim0_tmp = _cuda_to_mxfp8_dim1_wrapper(
+                input_hp_r,
+                block_size,
+                in_elem_dtype,
+                input_hp_r.dtype,
+                gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
             input_t_mx_dim0_tmp = MXTensor.to_mx(
                 input_hp_r.t().contiguous(),
@@ -232,7 +302,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
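
Usage sketch (not part of the diff above): a minimal way to exercise the new CUDA dim1 cast path by setting `mxfp8_cast_kernel_choice` on `MXLinearConfig`. Only `block_size` and `mxfp8_cast_kernel_choice` are visible in this diff; the `quantize_` entry point and the remaining setup are assumptions.

# Hypothetical usage sketch, assuming the standard torchao quantize_ flow applies.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

config = MXLinearConfig(
    block_size=32,
    # Route dim1 casts through the mxfp8_quantize_cuda kernel added in this diff.
    mxfp8_cast_kernel_choice=MXFP8CastKernelChoice.CUDA,
)
m = nn.Sequential(nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda"))
quantize_(m, config)
y = m(torch.randn(64, 256, dtype=torch.bfloat16, device="cuda"))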