diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 2ea1638c9..409385006 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -28,6 +28,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     is_fbcode,
 )
 
@@ -98,7 +99,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
-    w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 5c9d225d0..9e783b5ce 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -24,6 +24,7 @@
 )
 from typing import ClassVar
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 aten = torch.ops.aten
 
@@ -500,8 +501,11 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-        assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
+        if TORCH_VERSION_AFTER_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+        else:
+            assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py
index aaa137bda..1e8c5fc38 100644
--- a/torchao/prototype/hqq/hqq_tinygemm_linear.py
+++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,6 +12,7 @@
 from hqq.core.utils import *
 
 import torch.nn.functional as F
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,7 +199,8 @@ def hqq_quants_to_torch_quants(
             .reshape(shape)
             .contiguous()
         )
-        W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+        if TORCH_VERSION_AFTER_2_5:
+            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 4c0c78a8b..2442e3d5c 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -17,6 +17,7 @@
     dequantize_affine,
     int_scaled_matmul,
 )
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 __all__ = [
     "compute_error",
@@ -349,7 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1
 
     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
-    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
 
     return int_data
 
 def groupwise_affine_dequantize_tensor_from_qparams(
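
For reference, below is a standalone sketch (not part of the patch) of the nibble-packing expression that every hunk above gates behind TORCH_VERSION_AFTER_2_5: two int4 values are packed into one uint8 byte, even columns in the high nibble and odd columns in the low nibble, halving the last dimension. The tensor name and shapes are illustrative only.

import torch

# Illustrative int4 values stored in an int32 tensor, as quantize_affine would produce.
w_int4 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# The packing expression used throughout the patch: even columns -> high nibble,
# odd columns -> low nibble; the last dimension is halved.
packed = (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)
assert packed.shape == (4, 4)

# Round-trip check: split each byte back into its two nibbles and re-interleave.
high = (packed >> 4).to(torch.int32)
low = (packed & 0xF).to(torch.int32)
unpacked = torch.stack([high, low], dim=-1).reshape(w_int4.shape)
assert torch.equal(unpacked, w_int4)

On torch 2.4 and earlier the packing is skipped, so the tensors stay int32 with one int4 value per element, which is what the `else` branch in affine_quantized_tensor.py asserts before calling torch.ops.aten._convert_weight_to_int4pack.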