Commit 345bb63

Only support matmul(x, w.t()) for now

1 parent c0f4b4e

1 file changed

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 20 additions & 6 deletions
@@ -263,7 +263,7 @@ def _(func, types, args, kwargs):
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_linear_impl(input_tensor, weight_tensor.t())
+    return _float8_mm_impl(input_tensor, weight_tensor)


 @implements(aten.addmm_.default)
@@ -273,10 +273,24 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    out = _float8_mm_impl(input_tensor, weight_tensor)
     return output_tensor.copy_(out)


+def _float8_mm_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+    # Only support matmul(x, w.t()) for now
+    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
+    if not is_transposed:
+        raise ValueError("matmul with non-transposed Float8Tensor not supported yet")
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
@@ -286,10 +300,10 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )

-    # During the backward pass, we transpose the weight tensor,
-    # so if the weight tensor was originally rowwise quantized,
-    # now it becomes colwise. In this case, simply dequantize
-    # the tensor and do a bf16 matmul
+    # If we perform a matmul during the backward pass (e.g. in a LoRA matmul
+    # autograd.Function), the weight tensor will be transposed. If the weight
+    # tensor was originally rowwise quantized, now it becomes colwise.
+    # In this case, simply dequantize the tensor and do a bf16 matmul
     is_colwise = (
         weight_tensor.block_size[0] == weight_tensor.shape[0]
         and weight_tensor.block_size[1] == 1
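For reference, the transpose detection in the new _float8_mm_impl works because .t() on a 2D tensor swaps strides without copying data: a contiguous row-major tensor has stride(-2) > stride(-1), and the transposed view flips that. A minimal sketch on plain tensors (the is_transposed helper is illustrative, not part of the commit):

import torch

def is_transposed(t: torch.Tensor) -> bool:
    # A contiguous (N, K) tensor has strides (K, 1), so stride(-2) > stride(-1).
    # t.t() swaps the strides to (1, K), flipping the comparison.
    return t.stride(-2) < t.stride(-1)

w = torch.randn(4, 8)
print(is_transposed(w))      # False: strides (8, 1)
print(is_transposed(w.t()))  # True: strides (1, 8)

So matmul(x, w.t()) reaches _float8_mm_impl with a weight whose qdata strides look transposed, and any other layout raises the ValueError above.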
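Similarly, the is_colwise check in _float8_linear_impl matches a rowwise-quantized weight that has since been transposed: rowwise quantization of an (N, K) weight uses block_size (1, K) (one scale per row), so after .t() the shape is (K, N) and block_size is (K, 1). A sketch of just that condition (looks_colwise is an assumed helper name; the block_size semantics are inferred from the check in the diff):

def looks_colwise(shape, block_size):
    # A block spanning all of dim 0 but a single element of dim 1
    # means one scale per column, i.e. a transposed rowwise tensor.
    return block_size[0] == shape[0] and block_size[1] == 1

N, K = 16, 32
assert not looks_colwise((N, K), (1, K))  # rowwise weight, as quantized
assert looks_colwise((K, N), (K, 1))      # the same weight after .t()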
