Commit ebcfb9e

float8 dynamic autoquant
1 parent f646519 commit ebcfb9e

File tree

1 file changed (+9, -8)


torchao/quantization/autoquant.py

Lines changed: 9 additions & 8 deletions
@@ -17,6 +17,7 @@
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.float8.inference import addmm_float8_unwrapped_inference
 
 import torch.nn.functional as F
 

@@ -518,14 +519,16 @@ def get_per_token_block_size(x):
             input_float=x,
             block_size=get_per_token_block_size(x),
             target_dtype=input_target_dtype,
-            layout_type=layout_type
+            layout_type=layout_type,
+            scale_dtype=torch.float32,
         )
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
-            layout_type=layout_type
+            layout_type=layout_type,
+            scale_dtype=torch.float32,
         )
         weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
         return weight
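
Note: `get_per_token_block_size` (visible in the hunk context above) is not part of this diff; the sketch below is a hypothetical re-implementation of the behavior its name implies, for illustration only. The new `scale_dtype=torch.float32` arguments pin the per-token activation scales and the weight scales to fp32.

```python
import torch

# Hypothetical re-implementation, for illustration only: block size is 1
# in every dim except the last, so each token (row) gets its own scale.
def get_per_token_block_size(x: torch.Tensor) -> list[int]:
    return [1] * (x.dim() - 1) + [x.shape[-1]]

x = torch.randn(2, 4, 8)
print(get_per_token_block_size(x))  # [1, 1, 8]
```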
@@ -555,14 +558,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         x_vals_float8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
         )
-        quantized_matmul = (
-            lambda x_vals_float8, x_scales, w_vals_float8:
-                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
-        )
-        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        q_c_matmul=torch.compile(addmm_float8_unwrapped_inference, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
             w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1,1), w_vals_float8)
+            w_scales = w_qtensor.original_weight_tensor.layout_tensor.scale
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8, w_scales.reshape(1, -1), torch.float32)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
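
The benchmark previously timed an int8 stand-in (`safe_int_mm` scaled only by the activation scales); it now times the real float8 path, passing the weight scales and an fp32 output dtype as well. As a rough numerical reference, the call computes something like the plain-PyTorch sketch below (an assumption for illustration, not torchao's kernel, which dispatches to a fused scaled GEMM):

```python
import torch

def ref_scaled_mm(x_fp8, x_scale, w_fp8_t, w_scale, out_dtype):
    # Dequantize, then matmul: out ~= (x * x_scale) @ (w_t * w_scale).
    # x_scale is (M, 1), per token; w_scale is (1, N), per channel,
    # matching the reshape(-1, 1) / reshape(1, -1) in the diff above.
    x = x_fp8.to(torch.float32) * x_scale
    w = w_fp8_t.to(torch.float32) * w_scale
    return (x @ w).to(out_dtype)
```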
@@ -586,6 +586,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
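
With `AQFloat8DynamicallyQuantizedLinearWeight` added to `DEFAULT_AUTOQUANT_CLASS_LIST`, autoquant now benchmarks the float8 dynamic option alongside the int8 ones and keeps the fastest per layer. A minimal usage sketch, assuming torchao's documented `autoquant` entry point and a float8-capable GPU (e.g. H100):

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().bfloat16()
# Wraps eligible linear weights, benchmarks each candidate subclass from
# the default list (now including float8), and picks the fastest kernel.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))
```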
