From 8dfbf524f9fa5383767a3254452505b2cba937e2 Mon Sep 17 00:00:00 2001
From: ishan-modi
Date: Tue, 1 Apr 2025 15:53:27 +0530
Subject: [PATCH 1/2] Add INT8 realquant support

---
 .../nn/modules/tensor_quantizer.py            |   6 +
 .../torch/quantization/qtensor/__init__.py    |   1 +
 .../torch/quantization/qtensor/int8_tensor.py | 104 ++++++++++++++++++
 3 files changed, 111 insertions(+)
 create mode 100644 modelopt/torch/quantization/qtensor/int8_tensor.py

diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index e74988ab3..86d0c1d9d 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -36,6 +36,7 @@ from ...qtensor import (
     BaseQuantizedTensor,
     FP8QTensor,
     INT4QTensor,
+    INT8QTensor,
     NF4QTensor,
     NVFP4QTensor,
     QTensorWrapper,
@@ -494,6 +495,11 @@ def _real_quantize(self, inputs):
                 inputs, axis=self._axis, block_sizes=self._block_sizes
             )
             buffer_to_register["_scale"] = _scale
+        elif self._num_bits == 8:
+            outputs, _scale = INT8QTensor.quantize(
+                inputs, axis=self._axis, block_sizes=self._block_sizes
+            )
+            buffer_to_register["_scale"] = _scale
         elif self._block_sizes.get("scale_bits", 0) == 8 and self._block_sizes.get(
             "scale_block_sizes", None
         ):
diff --git a/modelopt/torch/quantization/qtensor/__init__.py b/modelopt/torch/quantization/qtensor/__init__.py
index 56c0247fd..8c65e4c55 100644
--- a/modelopt/torch/quantization/qtensor/__init__.py
+++ b/modelopt/torch/quantization/qtensor/__init__.py
@@ -19,5 +19,6 @@

 from .base_qtensor import *
 from .fp8_tensor import *
 from .int4_tensor import *
+from .int8_tensor import *
 from .nf4_tensor import *
diff --git a/modelopt/torch/quantization/qtensor/int8_tensor.py b/modelopt/torch/quantization/qtensor/int8_tensor.py
new file mode 100644
index 000000000..7252a2bd8
--- /dev/null
+++ b/modelopt/torch/quantization/qtensor/int8_tensor.py
@@ -0,0 +1,104 @@
+from typing import Optional, Union
+
+import torch
+
+from ..qtensor.base_qtensor import BaseQuantizedTensor
+from ..utils import reduce_amax, reduce_block_amax, reduce_block_padding
+
+
+class INT8QTensor(BaseQuantizedTensor):
+    """Implements INT8 quantization on tensors for more efficient storage or computation.
+
+    Attributes:
+        quantized_data (torch.Tensor): The quantized data stored as an INT8 tensor.
+    """
+
+    @classmethod
+    def quantize(
+        cls,
+        input: torch.Tensor,
+        scales: Optional[torch.Tensor] = None,
+        axis: Union[tuple, int, None] = None,
+        block_sizes: Optional[dict] = None,
+    ) -> tuple:
+        """Convert a tensor to a quantized format using INT8 quantization.
+
+        Args:
+            input (torch.Tensor): The input tensor to be quantized.
+            scales (torch.Tensor): Pre-computed scales for quantization; computed from input if None.
+            axis: The dimension(s) to reduce for quantization; None, an int, or a tuple of ints.
+            block_sizes (dict): A dictionary specifying the block size for each dimension.
+
+        Note: Provide either axis or block_sizes for INT8 quantization, not both.
+
+        Returns:
+            tuple: (INT8QTensor, scales)
+        """
+        original_input = input
+        if scales is None:
+            if block_sizes:
+                input = reduce_block_padding(input, block_sizes)
+                amax = reduce_block_amax(input, block_sizes)
+            else:
+                amax = reduce_amax(input, axis=axis)
+            scales = amax / 127.0
+
+        # Calculate the scale shape and make sure it aligns with input and block_sizes
+        expected_shape = list(input.shape)
+        expanded_scales = scales.clone()
+        if block_sizes:
+            for dim, block_size in block_sizes.items():
+                dim = dim if dim >= 0 else len(input.shape) + dim  # Convert negative index
+                assert input.shape[dim] % block_size == 0, (
+                    f"Tensor dimension {dim} of size {input.shape[dim]} is not divisible by {block_size}."
+                )
+                expected_shape[dim] = (
+                    input.shape[dim] // block_size
+                )  # Adjust expected shape for blocks
+
+            # Assert that the shape of `scales` matches the expected reduced dimensions
+            assert scales.shape == tuple(expected_shape), (
+                f"Scale shape mismatch: got {scales.shape}, expected {tuple(expected_shape)}."
+            )
+
+            # Expand scales for broadcasting
+            for dim, block_size in block_sizes.items():
+                expanded_scales = expanded_scales.repeat_interleave(block_size, dim=dim)
+
+        # Quantization
+        quantized_data = (input / expanded_scales).round().clamp(-128, 127).to(torch.int8)
+
+        return cls(original_input.shape, original_input.dtype, quantized_data), scales
+
+    def dequantize(self, dtype: torch.dtype = None, **kwarg):
+        """Dequantize an INT8 packed tensor to a target dtype."""
+        if dtype is None:
+            dtype = self.metadata["dtype"]
+        assert "scale" in kwarg, "A 'scale' kwarg is required for INT8 dequantization."
+
+        # Get args
+        scales = kwarg["scale"]
+        block_sizes = kwarg.get("block_sizes", None)
+
+        shape = self._quantized_data.shape
+        if block_sizes:
+            # Compute the expected (reduced) scale shape for block quantization
+            expanded_shape = list(shape)
+            for dim, block_size in block_sizes.items():
+                assert shape[dim] % block_size == 0, (
+                    f"Dimension {dim} of size {shape[dim]} is not divisible by {block_size}."
+                )
+                expanded_shape[dim] //= block_size  # Reduce the dimension size for blocks
+
+            assert tuple(expanded_shape) == scales.shape, (
+                f"Scales shape {scales.shape} must match expected {tuple(expanded_shape)}."
+            )
+
+            # Expand scales for broadcasting
+            for dim, block_size in block_sizes.items():
+                scales = scales.repeat_interleave(block_size, dim=dim)
+
+        # Handle padded tensors
+        slices = tuple(slice(0, dim) for dim in self.metadata["shape"])
+
+        return (self._quantized_data.to(torch.int8) * scales.to(dtype))[slices]

From e61780ca27f28178632aff88b8bf230c40aadcf6 Mon Sep 17 00:00:00 2001
From: ishan-modi
Date: Tue, 1 Apr 2025 16:27:58 +0530
Subject: [PATCH 2/2] Fix dtype handling in INT8 dequantize

---
 modelopt/torch/quantization/qtensor/int8_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/qtensor/int8_tensor.py b/modelopt/torch/quantization/qtensor/int8_tensor.py
index 7252a2bd8..80c451acd 100644
--- a/modelopt/torch/quantization/qtensor/int8_tensor.py
+++ b/modelopt/torch/quantization/qtensor/int8_tensor.py
@@ -101,4 +101,4 @@ def dequantize(self, dtype: torch.dtype = None, **kwarg):
         # Handle padded tensors
         slices = tuple(slice(0, dim) for dim in self.metadata["shape"])

-        return (self._quantized_data.to(torch.int8) * scales.to(dtype))[slices]
+        return (self._quantized_data.view(torch.int8).to(dtype) * scales.to(dtype))[slices]
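
As a usage reference, a minimal, untested sketch of how the new INT8QTensor round trip could be exercised with blockwise quantization. It assumes the BaseQuantizedTensor constructor records the original shape and dtype in `metadata`, as the dequantize code above expects; the tensor shape, block size, and values are purely illustrative.

    import torch

    from modelopt.torch.quantization.qtensor import INT8QTensor

    # Illustrative weight tensor, block-quantized along the last dimension.
    weight = torch.randn(4, 8, dtype=torch.float16)
    block_sizes = {-1: 4}

    # With scales=None, quantize() pads the input to the block grid, computes
    # per-block scales as amax / 127, and packs the data as torch.int8.
    qtensor, scales = INT8QTensor.quantize(weight, block_sizes=block_sizes)
    assert scales.shape == (4, 2)  # one scale per 1x4 block

    # dequantize() needs the scales again, plus block_sizes so each scale can
    # be broadcast (repeat_interleave) back over its block.
    recovered = qtensor.dequantize(scale=scales, block_sizes=block_sizes)
    assert recovered.shape == weight.shape and recovered.dtype == weight.dtype

The round trip is lossy by design: each value is rounded to one of 256 levels per block, so `recovered` only approximates `weight`. Note that PATCH 2/2 reinterprets the stored bytes as int8 via view() rather than value-converting with .to(torch.int8), then casts to the target dtype before the multiply.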