@@ -492,6 +492,93 @@ def from_float(cls, weight):
        block_size = (1, weight.shape[1])
        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType()
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 and 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
+        w_qtensor = cls.from_float(weight)
+        x_vals_float8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_float8, x_scales, w_vals_float8:
+                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        with torch.no_grad():
+            w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
+        return res_f
+

# here we don't include int4 quantization since int8 tends to be a better apples-to-apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +598,7 @@ def from_float(cls, weight):

OTHER_AUTOQUANT_CLASS_LIST = [
    AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
]
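
For context on the interpolation branch of _autoquant_test above: when mode is ["interpolate", c], the reported score blends the matmul-only benchmark (res_matmul) with the full-op benchmark (res) using the constant c taken from mode[1]. The standalone sketch below reproduces only that arithmetic; the timing numbers and the default constant of 0.85 are made up for illustration, and only the three formulas are taken from the diff.

    # Standalone sketch of the interpolation arithmetic in _autoquant_test.
    # Timings are invented for illustration; only the formulas mirror the diff.
    def interpolate_score(best_time, res_matmul, res, interpolation_constant=0.85):
        # time the full op must beat for this option to win overall, given the constant
        to_beat = best_time + interpolation_constant / (1 - interpolation_constant) * (best_time - res_matmul)
        # weighted estimate that autoquant compares against other options
        res_f = interpolation_constant * res + (1 - interpolation_constant) * res_matmul
        # largest constant for which this option would still beat best_time
        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
        return to_beat, res_f, max_float_const_win

    # hypothetical numbers: matmul kernel alone 0.6 ms, full op 1.4 ms, best competitor 1.0 ms
    print(interpolate_score(best_time=1.0, res_matmul=0.6, res=1.4))
    # -> roughly (3.27, 1.28, 0.5)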
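
Since the new class is registered only in OTHER_AUTOQUANT_CLASS_LIST and not in the default list, a caller has to opt in explicitly. The sketch below shows one way that might look; it assumes this file is torchao/quantization/autoquant.py, that torchao.autoquant accepts a qtensor_class_list argument, and that a CUDA device with float8 support is available, none of which is shown in this diff.

    # Opt-in usage sketch (assumed: torchao.autoquant signature, import path, CUDA float8 support)
    import torch
    import torchao
    from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST  # assumed module path

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
    example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

    # autoquant wraps the linear weights; benchmarking (including the _autoquant_test
    # paths above) is triggered by the first forward call
    model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
    model(example_input)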