 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.float8.inference import addmm_float8_unwrapped_inference
 
 import torch.nn.functional as F
 
@@ -518,14 +519,16 @@ def get_per_token_block_size(x):
             input_float=x,
             block_size=get_per_token_block_size(x),
             target_dtype=input_target_dtype,
-            layout_type=layout_type
+            layout_type=layout_type,
+            scale_dtype=torch.float32,
         )
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
-            layout_type=layout_type
+            layout_type=layout_type,
+            scale_dtype=torch.float32,
         )
         weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
         return weight
@@ -555,14 +558,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         x_vals_float8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
         )
-        quantized_matmul = (
-            lambda x_vals_float8, x_scales, w_vals_float8:
-                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
-        )
-        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        q_c_matmul = torch.compile(addmm_float8_unwrapped_inference, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
             w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
-            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+            w_scales = w_qtensor.original_weight_tensor.layout_tensor.scale
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8, w_scales.reshape(1, -1), torch.float32)
         print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
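For orientation, the compiled benchmark above swaps the previous int8 safe_int_mm-based lambda for torchao's fused, scaled float8 GEMM (addmm_float8_unwrapped_inference). Below is a rough dequantize-then-matmul reference for what that call computes; the helper name and the float32 casts are illustrative assumptions, not part of this diff, and the scale shapes mirror the reshape(-1, 1) / reshape(1, -1) calls used in the benchmark.

import torch

def scaled_fp8_mm_reference(x_fp8, x_scales, w_fp8_t, w_scales, out_dtype=torch.float32):
    # Illustrative reference only: dequantize both operands and matmul in float32,
    # applying per-token activation scales of shape (M, 1) and per-channel weight
    # scales of shape (1, N). The real path fuses the scaling into one float8 GEMM.
    return ((x_fp8.to(torch.float32) * x_scales) @ (w_fp8_t.to(torch.float32) * w_scales)).to(out_dtype)

x = torch.randn(4, 8).to(torch.float8_e4m3fn)      # quantized activations (M, K)
w_t = torch.randn(8, 16).to(torch.float8_e4m3fn)   # transposed quantized weight (K, N)
out = scaled_fp8_mm_reference(x, torch.ones(4, 1), w_t, torch.ones(1, 16))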
@@ -586,6 +586,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
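With the float8 subclass added to the default candidate list, a minimal usage sketch follows. Assumptions not shown in this diff: the edited file is torchao/quantization/autoquant.py, the list above is DEFAULT_AUTOQUANT_CLASS_LIST, autoquant() accepts it through its qtensor_class_list argument, and a CUDA device with float8 support is available.

import torch
import torchao
from torchao.quantization.autoquant import DEFAULT_AUTOQUANT_CLASS_LIST  # assumed name of the list edited above

# Toy model; any module containing nn.Linear layers is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)

# Wrap with autoquant so the candidate weight subclasses (now including the float8 one)
# are benchmarked per shape; qtensor_class_list is assumed to be the selection knob.
model = torchao.autoquant(torch.compile(model), qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
model(x)  # running the model records shapes and triggers the per-kernel benchmarking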