@@ -221,7 +221,6 @@ def do_autoquant_bench(op, *args, **kwargs):
         stream.synchronize()
         torch.cuda.current_stream().wait_stream(stream)
         torch.cuda.synchronize()
-
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
             op(*args, **kwargs)
@@ -492,6 +491,92 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType()
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 and 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
+        w_qtensor = cls.from_float(weight)
+        x_vals_float8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_float8, x_scales, w_vals_float8:
+                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        with torch.no_grad():
+            w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
+        return res_f
+
 
 # here we don't include int4 quantization since int8 tends to be a better apples-to-apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +596,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
 
 
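
With AQFloat8DynamicallyQuantizedLinearWeight appended to OTHER_AUTOQUANT_CLASS_LIST, it becomes one of the candidates that autoquant benchmarks for each linear layer. A minimal usage sketch is below; the model, shapes, and compile mode are illustrative only, and it assumes that OTHER_AUTOQUANT_CLASS_LIST is importable from torchao.quantization.autoquant and that torchao.autoquant accepts a qtensor_class_list argument (check the torchao README for the current API). Float8 matmuls also generally require a recent GPU (e.g., compute capability 8.9+ such as H100).

```python
# Usage sketch (assumptions noted above; not part of this diff).
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to("cuda", torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# autoquant wraps each Linear weight, benchmarks every candidate class in
# qtensor_class_list (now including AQFloat8DynamicallyQuantizedLinearWeight)
# on the first forward pass, and keeps the fastest option per layer.
model = torchao.autoquant(
    torch.compile(model, mode="max-autotune"),
    qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST,
)
out = model(example_input)
```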