@@ -263,7 +263,7 @@ def _(func, types, args, kwargs):
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_linear_impl(input_tensor, weight_tensor.t())
+    return _float8_mm_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
@@ -273,10 +273,24 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    out = _float8_mm_impl(input_tensor, weight_tensor)
     return output_tensor.copy_(out)
 
 
+def _float8_mm_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+    # Only support matmul(x, w.t()) for now
+    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
+    if not is_transposed:
+        raise ValueError("Only matmul(x, w.t()) is supported for now")
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
@@ -286,10 +300,10 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
 
-    # During the backward pass, we transpose the weight tensor,
-    # so if the weight tensor was originally rowwise quantized,
-    # now it becomes colwise. In this case, simply dequantize
-    # the tensor and do a bf16 matmul
+    # If we perform a matmul during the backward pass (e.g. in a LoRA matmul
+    # autograd.Function), the weight tensor will be transposed. If the weight
+    # tensor was originally rowwise quantized, now it becomes colwise.
+    # In this case, simply dequantize the tensor and do a bf16 matmul.
     is_colwise = (
         weight_tensor.block_size[0] == weight_tensor.shape[0]
         and weight_tensor.block_size[1] == 1
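
A note on the stride check in `_float8_mm_impl`: `.t()` returns a view that swaps the last two strides instead of copying data, so a transposed matrix strides fastest along dim `-2`. A minimal standalone demonstration with a plain `torch.Tensor` (not the PR's `Float8Tensor`):

```python
import torch

# A contiguous (row-major) matrix strides fastest along its last dim.
w = torch.randn(4, 8)
print(w.stride())   # (8, 1)

# .t() swaps the two strides without copying, so the transposed view
# strides fastest along dim -2 instead.
wt = w.t()
print(wt.stride())  # (1, 8)

# This is exactly the layout that stride(-2) < stride(-1) detects.
assert not (w.stride(-2) < w.stride(-1))
assert wt.stride(-2) < wt.stride(-1)
```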
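Similarly, the `is_colwise` condition in `_float8_linear_impl` follows from the comment above it: a rowwise-quantized weight of shape `(N, K)` carries one scale per row, i.e. `block_size == (1, K)`, and transposing swaps both the shape and the block size. A sketch with the metadata written out as plain tuples (hypothetical values; the real `Float8Tensor` stores these as attributes):

```python
# Hypothetical metadata for a rowwise-quantized weight of shape (N, K):
# one scale per row means each quantization block covers (1, K) elements.
shape, block_size = (32, 64), (1, 64)

# Transposing the tensor swaps both tuples, so the same data viewed as
# (K, N) now carries one scale per column: block_size == (K, 1).
t_shape = (shape[1], shape[0])                 # (64, 32)
t_block_size = (block_size[1], block_size[0])  # (64, 1)

# This is the condition that triggers the dequantize-then-bf16-matmul
# fallback in _float8_linear_impl.
is_colwise = t_block_size[0] == t_shape[0] and t_block_size[1] == 1
assert is_colwise
```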