@@ -259,11 +259,18 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)
 
 
-@implements(aten.mm.default)
+@implements(aten.matmul.default)
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_mm_impl(input_tensor, weight_tensor)
+    return _float8_matmul_impl(input_tensor, weight_tensor)
+
+
+@implements(aten.mm.default)
+@implements_torch_function(torch.mm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = args[0], args[1]
+    return _float8_matmul_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
@@ -275,24 +282,110 @@ def _(func, types, args, kwargs):
     )
     assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
     assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
-    out = _float8_mm_impl(input_tensor, weight_tensor)
+    out = _float8_matmul_impl(input_tensor, weight_tensor)
     return bias_tensor.add_(out)
 
 
-def _float8_mm_impl(
-    input_tensor: torch.Tensor,
-    weight_tensor: torch.Tensor,
+def _get_matmul_kernel_choice(weight_tensor: Float8Tensor) -> str:
+    """
+    Return the kernel choice for matmuls, either "fbgemm" or "torch".
+    """
+    if weight_tensor.kernel_preference == KernelPreference.AUTO:
+        kernel_choice = "torch"
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            kernel_choice = "fbgemm"
+    elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+        kernel_choice = "fbgemm"
+    else:
+        assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
+            f"{weight_tensor.kernel_preference=}"
+        )
+        kernel_choice = "torch"
+    return kernel_choice
+
+
+def _call_fbgemm_f8f8bf16_matmul(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+    weight_is_already_transposed: bool = False,
 ) -> torch.Tensor:
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    """
+    Call `torch.ops.fbgemm.f8f8bf16*` ops.
+
+    These ops expect the weight tensor to be transposed.
+    If `weight_is_already_transposed=True` (e.g. in the linear case),
+    we skip the transpose here to avoid transposing the weight twice.
+    """
+    assert _is_fbgemm_gpu_genai_available(), (
+        "Expected fbgemm_gpu_genai package to be installed"
     )
-    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
-    # For matmul(x, w.t()), just call the linear implementation
-    # For matmul(x, w), just dequantize for now, we can optimize later
-    if is_transposed:
-        return _float8_linear_impl(input_tensor, weight_tensor.t())
+    assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+    mm_config = weight_tensor.mm_config
+    assert mm_config is not None
+
+    if not weight_is_already_transposed:
+        weight_tensor = weight_tensor.t()
+
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    wq = weight_tensor.qdata
+    x_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), (
+            "Input tensor must be rowwise block size"
+        )
+        res = torch.ops.fbgemm.f8f8bf16_rowwise(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            bias=bias,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
     else:
-        return torch.matmul(input_tensor, weight_tensor.dequantize())
+        assert _is_tensorwise_scaled(weight_tensor)
+        assert _is_tensorwise_scaled(input_tensor)
+        res = torch.ops.fbgemm.f8f8bf16(
+            xq,
+            wq,
+            x_scale * w_scale,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
+        if bias is not None:
+            res = res + bias
+    return res
+
+
+def _call_float8_scaled_mm(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    scaled_mm_config = weight_tensor.mm_config
+    assert scaled_mm_config is not None
+    assert weight_tensor.dim() == 2
+    out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])
+
+    # Extract tensor data and scales
+    inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    w_data = weight_tensor.qdata
+    input_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
+    inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)
+
+    return addmm_float8_unwrapped_inference(
+        inpt_data,
+        input_scale,
+        w_data,
+        w_scale,
+        output_dtype=input_tensor.dtype,
+        bias=bias,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
+    ).reshape(out_shape)
 
 
 def _float8_linear_impl(
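Weight-layout convention assumed by the two helpers above, summarized as an illustrative comment sketch (the shapes are inferred from this diff, not stated in it):

# linear path: weight W is stored as (N, K), i.e. already "transposed"
#   fbgemm: _call_fbgemm_f8f8bf16_matmul(x, W, bias, weight_is_already_transposed=True)
#   torch:  _call_float8_scaled_mm(x, W.t(), bias)   # scaled-mm helper expects (K, N)
# matmul path: second operand B already has shape (K, N)
#   fbgemm: _call_fbgemm_f8f8bf16_matmul(x, B)       # helper transposes B to (N, K)
#   torch:  _call_float8_scaled_mm(x, B)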
@@ -312,88 +405,17 @@ def _float8_linear_impl(
         )
 
     if isinstance(input_tensor, Float8Tensor):
-        kernel_choice = None
-
-        if weight_tensor.kernel_preference == KernelPreference.AUTO:
-            kernel_choice = "torch"
-            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
-                kernel_choice = "fbgemm"
-        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
-            kernel_choice = "fbgemm"
-        else:
-            assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
-                f"{weight_tensor.kernel_preference=}"
-            )
-            kernel_choice = "torch"
-
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
         if kernel_choice == "fbgemm":
-            assert _is_fbgemm_gpu_genai_available(), (
-                "Expected fbgemm_gpu_genai package to be installed"
+            return _call_fbgemm_f8f8bf16_matmul(
+                input_tensor,
+                weight_tensor,
+                bias,
+                weight_is_already_transposed=True,
             )
-            assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
-            mm_config = weight_tensor.mm_config
-            assert mm_config is not None
-
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-            xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            wq = weight_tensor.qdata
-            x_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                res = torch.ops.fbgemm.f8f8bf16_rowwise(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    bias=bias,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            else:
-                assert _is_tensorwise_scaled(weight_tensor)
-                assert _is_tensorwise_scaled(input_tensor)
-                res = torch.ops.fbgemm.f8f8bf16(
-                    xq,
-                    wq,
-                    x_scale * w_scale,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-                if bias is not None:
-                    res = res + bias
-            return res
         else:
             assert kernel_choice == "torch"
-            scaled_mm_config = weight_tensor.mm_config
-            assert scaled_mm_config is not None
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-
-            # Extract tensor data and scales
-            inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            w_data = weight_tensor.qdata
-            input_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-
-            # Handle rowwise scaling
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                w_scale = w_scale.transpose(-1, -2)
-
-            input_scale = preprocess_scale(input_scale, input_tensor.shape)
-            inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
-
-            return addmm_float8_unwrapped_inference(
-                inpt_data,
-                input_scale,
-                w_data,
-                w_scale,
-                output_dtype=input_tensor.dtype,
-                bias=bias,
-                use_fast_accum=scaled_mm_config.use_fast_accum,
-            ).reshape(out_shape)
+            return _call_float8_scaled_mm(input_tensor, weight_tensor.t(), bias)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
             "Expecting input_tensor to be unquantized"
@@ -405,6 +427,37 @@ def _float8_linear_impl(
         )
 
 
+def _float8_matmul_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantizing activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    if isinstance(input_tensor, Float8Tensor):
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
+        if kernel_choice == "fbgemm":
+            return _call_fbgemm_f8f8bf16_matmul(input_tensor, weight_tensor)
+        else:
+            assert kernel_choice == "torch"
+            return _call_float8_scaled_mm(input_tensor, weight_tensor)
+    else:
+        assert not isinstance(input_tensor, TorchAOBaseTensor), (
+            "Expecting input_tensor to be unquantized"
+        )
+        # when input is not `Float8Tensor`, we expect that it is not quantized
+        # so this is float8 weight only quantization
+        return torch.matmul(input_tensor, weight_tensor.dequantize())
+
+
 @implements_torch_function(torch.bmm)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = (
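
A minimal smoke-test sketch of the new torch.matmul / torch.mm dispatch. It assumes the torchao quantize_ flow below yields a weight that is this Float8Tensor subclass (config defaults and torchao version requirements are assumptions, not part of the diff) and an FP8-capable GPU:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

# Assumption: with a recent torchao, this config quantizes the weight into the
# Float8Tensor subclass whose overrides are changed in this diff.
m = torch.nn.Linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(m, Float8DynamicActivationFloat8WeightConfig())
w = m.weight  # expected: Float8Tensor of shape (256, 128)

x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)

# Before this change, only the transposed-weight case was fast-pathed through the
# linear impl and torch.mm had no torch-function override; both calls below should
# now route through _float8_matmul_impl.
y_matmul = torch.matmul(x, w.t())
y_mm = torch.mm(x, w.t())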