diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 8e047985c..544660900 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -73,6 +73,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -753,6 +754,14 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
             AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
         )
 
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 81e7b19b1..885143239 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -37,8 +37,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
     """
    # torch.compile path
    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except Exception:
+            # fallback path, would run on H100 for float8 dtypes
+            # Exception on H100 float8 dtype: "addmm_cuda" not implemented for 'Float8_e4m3fn'
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
    # error checking for cublas path
    assert (
        mat2.device == input.device
@@ -53,7 +57,6 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
        and j_is_nonzero_multiple_of_8
        and k_is_nonzero_multiple_of_8
    )
-
    if device_cpu or bad_dimensions_for_cublas:
        # fallback path
        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 089add1d8..19780088e 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -221,7 +221,6 @@ def do_autoquant_bench(op, *args, **kwargs):
        stream.synchronize()
        torch.cuda.current_stream().wait_stream(stream)
        torch.cuda.synchronize()
-
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            op(*args, **kwargs)
@@ -492,6 +491,92 @@ def from_float(cls, weight):
        block_size = (1, weight.shape[1])
        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
 
+class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
+    """
+    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def from_float(cls, weight):
+
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized_floatx
+        # weight settings
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.float8_e4m3fn
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size) - 1):
+                block_size[i] = 1
+            return block_size
+
+        input_target_dtype = torch.float8_e4m3fn
+        layout_type = Float8LayoutType()
+        input_quant_func = lambda x: to_affine_quantized_floatx(
+            input_float=x,
+            block_size=get_per_token_block_size(x),
+            target_dtype=input_target_dtype,
+            layout_type=layout_type
+        )
+        block_size = get_weight_block_size(weight)
+        weight = to_affine_quantized_floatx(
+            input_float=weight,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            layout_type=layout_type
+        )
+        weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
+        return weight
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        """
+        Tests and benchmarks the autoquantization process with special handling for interpolate mode.
+
+        Args:
+            act_mat (torch.Tensor): The activation matrix.
+            weight (torch.Tensor): The weight tensor.
+            bias (torch.Tensor or None): The bias tensor.
+            best_time (float): The best time to beat for the quantization process.
+            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
+                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
+
+        Returns:
+            float: The benchmarked time for the autoquantization process.
+        """
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 and 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
+        w_qtensor = cls.from_float(weight)
+        x_vals_float8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_float8, x_scales, w_vals_float8:
+                safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
+        with torch.no_grad():
+            w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1, 1), w_vals_float8)
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+
+        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
+        return res_f
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +596,7 @@ def from_float(cls, weight):
 
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQFloat8DynamicallyQuantizedLinearWeight,
 ]
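For context on how the new subclass would be exercised end to end, the sketch below is illustrative only: the toy model, shapes, and the torchao.autoquant call with its qtensor_class_list keyword mirror the existing autoquant entry point in this module rather than anything introduced by this diff, and it needs an H100-class GPU since the float8 candidates (now including AQFloat8DynamicallyQuantizedLinearWeight) require float8 support.

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

# Toy model and input; sizes are arbitrary and chosen only for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).to("cuda", dtype=torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Restrict autoquant to the float8 candidate list, which this diff extends with
# AQFloat8DynamicallyQuantizedLinearWeight (assumes autoquant's qtensor_class_list kwarg).
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)  # first call triggers per-layer benchmarking and kernel selection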