@@ -263,7 +263,7 @@ def _(func, types, args, kwargs):
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_linear_impl(input_tensor, weight_tensor.t())
+    return _float8_mm_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
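
With this change, a plain torch.matmul(x, w) on a Float8Tensor weight no longer assumes the linear layout up front: the torch-function override forwards both operands to the new _float8_mm_impl, which validates the layout before delegating. Below is a minimal sketch of how such an override table plugs into PyTorch's __torch_function__ protocol; the _OVERRIDES registry and DemoTensor are simplified stand-ins of my own, not torchao's actual implementation:

import torch

# Simplified stand-ins for torchao's registry; not the real API.
_OVERRIDES = {}

def implements_torch_function(torch_fn):
    def register(impl):
        _OVERRIDES[torch_fn] = impl
        return impl
    return register

class DemoTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _OVERRIDES:
            return _OVERRIDES[func](func, types, args, kwargs)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

@implements_torch_function(torch.matmul)
def _(func, types, args, kwargs):
    x, w = args[0], args[1]
    # The PR calls _float8_mm_impl(x, w) at this point.
    with torch._C.DisableTorchFunctionSubclass():
        return torch.matmul(x, w)

x = torch.randn(2, 4)
w = torch.randn(4, 3).as_subclass(DemoTensor)
print(torch.matmul(x, w).shape)  # torch.Size([2, 3]), routed via the override
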
@@ -273,10 +273,24 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    out = _float8_mm_impl(input_tensor, weight_tensor)
     return output_tensor.copy_(out)
 
 
+def _float8_mm_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+    # Only support matmul(x, w.t()) for now
+    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
+    if not is_transposed:
+        raise ValueError("matmul with non-transposed Float8Tensor not supported yet")
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
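
The is_transposed check above leans on a simple layout fact: a contiguous 2D tensor is row-major, so stride(-2) > stride(-1), while its .t() view swaps the strides. A quick illustration on plain tensors (assuming Float8Tensor.qdata behaves like an ordinary strided tensor, which the stride() calls imply):

import torch

w = torch.empty(128, 64)   # contiguous: strides (64, 1)
wt = w.t()                 # transposed view: strides (1, 64)

def is_transposed(t: torch.Tensor) -> bool:
    # Same comparison as in _float8_mm_impl: a column-major (transposed)
    # layout has the smaller stride on the second-to-last dimension.
    return t.stride(-2) < t.stride(-1)

print(is_transposed(w))   # False -> _float8_mm_impl raises ValueError
print(is_transposed(wt))  # True  -> forwarded to _float8_linear_impl
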
@@ -286,10 +300,10 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
 
-    # During the backward pass, we transpose the weight tensor,
-    # so if the weight tensor was originally rowwise quantized,
-    # now it becomes colwise. In this case, simply dequantize
-    # the tensor and do a bf16 matmul
+    # If we perform a matmul during the backward pass (e.g. in a LoRA matmul
+    # autograd.Function), the weight tensor will be transposed. If the weight
+    # tensor was originally rowwise quantized, now it becomes colwise.
+    # In this case, simply dequantize the tensor and do a bf16 matmul
     is_colwise = (
         weight_tensor.block_size[0] == weight_tensor.shape[0]
         and weight_tensor.block_size[1] == 1
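
The is_colwise test reads the quantization granularity out of block_size. If the semantics match the comment above (an assumption on my part; the block_size and shape attributes come from the diff), a rowwise-quantized weight of shape (N, K) carries block_size (1, K), and transposing it reverses both tuples, leaving one scale per column:

# Hypothetical metadata for a rowwise-quantized (N, K) = (256, 512) weight.
shape, block_size = (256, 512), (1, 512)

# Transposing reverses both the shape and the quantization block size.
t_shape = shape[::-1]             # (512, 256)
t_block_size = block_size[::-1]   # (512, 1): each scale spans a full column

# The check from _float8_linear_impl, applied to the transposed metadata.
is_colwise = t_block_size[0] == t_shape[0] and t_block_size[1] == 1
print(is_colwise)  # True -> dequantize and fall back to a bf16 matmul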