Commit ffca55c: float8 dynamic autoquant

1 parent fbe97a0
2 files changed: +97 -0 lines changed

test/integration/test_integration.py

Lines changed: 9 additions & 0 deletions
@@ -73,6 +73,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -753,6 +754,14 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
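For reference, a minimal sketch of what the new test exercises (hedged: it assumes torch >= 2.5 and an H100-class GPU, per the skip conditions above; the shapes are illustrative, and the 25 passed to the harness is read here as a minimum SQNR bound, mirroring the call to _test_lin_weight_subclass_impl):

import torch
from torchao.quantization.autoquant import AQFloat8DynamicallyQuantizedLinearWeight

lin = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
ref = lin(x)

# swap the plain weight for the autoquantizable float8-dynamic subclass
lin.weight = torch.nn.Parameter(
    AQFloat8DynamicallyQuantizedLinearWeight.from_float(lin.weight),
    requires_grad=False,
)
out = lin(x)
# the harness then asserts SQNR(ref, out) >= 25 dB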

torchao/quantization/autoquant.py

Lines changed: 88 additions & 0 deletions
@@ -492,6 +492,93 @@ def from_float(cls, weight):
         block_size = (1, weight.shape[1])
         return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType()
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 and 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
+        w_qtensor = cls.from_float(weight)
+        x_vals_float8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_float8, x_scales, w_vals_float8:
+                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        with torch.no_grad():
+            w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
+        return res_f
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +598,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
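The interpolate branch of _autoquant_test scores a candidate by blending the full-op time with the bare matmul time, crediting kernels that are likely to benefit from fusion. A worked example with made-up timings (the numbers below are illustrative, not measurements from this commit):

INTERPOLATION_CONSTANT = 0.85        # mode[1]; the comment above says .8-1 works best for SAM/SDXL
best_time, res_matmul = 1.00, 0.40   # ms: incumbent full op vs. bare float8 matmul kernel
# the full op may take up to to_beat and still win after interpolation
to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
assert abs(to_beat - 4.40) < 1e-6
res = 1.05                           # measured full dynamic-quant op, ms
res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
assert abs(res_f - 0.9525) < 1e-6    # interpolated score beats best_time = 1.00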
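With the class appended to OTHER_AUTOQUANT_CLASS_LIST, it can be opted into via autoquant. A hedged usage sketch (assumes the torchao.autoquant entry point and its qtensor_class_list argument; the toy model and shapes are illustrative):

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
# the first call benchmarks the candidate classes (including the new float8
# dynamic one) and swaps in whichever weight subclass timed best per layer
model(torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16))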