diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 47dca276..2c393c70 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 
 from functools import wraps
+from math import ceil
 
 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
@@ -32,10 +36,9 @@ def quantize(
     q_min: torch.Tensor,
     q_max: torch.Tensor,
 ) -> torch.Tensor:
+
     return torch.clamp(
-        torch.round(
-            x / scale + zero_point,
-        ),
+        torch.round(x / scale + zero_point),
         q_min,
         q_max,
     )
@@ -57,12 +60,88 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
+    """
+    Fake quantize the input tensor x depending on the group_size.
+    If group_size is greater than 0, q/dq is applied per group; the number
+    of columns must be divisible by group_size.
+    If group_size is -1, q/dq is applied channel-wise. The input scale and
+    zero_points are reshaped to support vectorization (assumes dim 1 is
+    the channel dimension).
+
+    :param x: input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args that contain group_size info
+    :return: fake quantized tensor
+    """
     bit_range = 2**args.num_bits
     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
     min_q = torch.tensor(-bit_range / 2, device=x.device)
-    Q = torch.zeros_like(x)
-    Q = quantize(x, scale, zero_point, min_q, max_q)
-    return dequantize(Q, scale, zero_point)
+
+    group_size = args.group_size
+
+    # group
+    if args.strategy == QuantizationStrategy.GROUP:
+
+        DQ = torch.zeros_like(x)
+
+        # TODO: vectorize the for loop
+        # TODO: fix the generic assumption about the tensor size for computing groups
+        # TODO: add a validation step for the inputs
+
+        while scale.ndim < 2:
+            # pad scale and zero point dims for slicing
+            scale = scale.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1)
+
+        columns = x.shape[1]
+        if columns >= group_size:
+            if columns % group_size != 0:
+                raise ValueError(
+                    "tensor column shape must be divisible "
+                    f"by the given group_size {group_size}"
+                )
+        for i in range(ceil(columns / group_size)):
+            # scale.shape should be [nchan, ndim]
+            # sc.shape should be [nchan, 1] after unsqueeze
+            sc = scale[:, i].unsqueeze(1)
+            zp = zero_point[:, i].unsqueeze(1)
+
+            idx = i * group_size
+            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
+            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+
+    # channel-wise
+    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
+        # before: scale shape = [channel_size]
+        # after: scale shape = [1, channel_size]
+        scale = scale.unsqueeze(0)
+        zero_point = zero_point.unsqueeze(0)
+
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
+    # per-token
+    elif args.strategy == QuantizationStrategy.TOKEN:
+        # before: scale shape = [num_tokens]
+        # after: scale shape = [num_tokens, 1]
+        # x.shape = [1, num_tokens, hidden_dim]
+        # scale is broadcast as expected without needing a [1, num_tokens, 1] shape
+        scale = scale.unsqueeze(1)
+        zero_point = zero_point.unsqueeze(1)
+
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
+    else:
+        Q = quantize(x, scale, zero_point, min_q, max_q)
+        DQ = dequantize(Q, scale, zero_point)
+
+    return DQ
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@@ -138,5 +217,4 @@ def _maybe_calibrate_or_quantize(
         device = next(module.parameters()).device
         scale.data = updated_scale.to(device)
         zero_point.data = updated_zero_point.to(device)
-
     return fake_quantize(value, scale, zero_point, args)
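
Usage sketch (editor's note, not part of the patch): a minimal example of the new
group-strategy path in fake_quantize, assuming compressed-tensors with this change is
installed; the tensor shapes, group_size, and scale values below are illustrative only.

    import torch

    from compressed_tensors.quantization.lifecycle.forward import fake_quantize
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # weight with 8 columns, fake-quantized in 2 groups of 4 columns each
    x = torch.randn(16, 8)
    args = QuantizationArgs(num_bits=8, group_size=4)  # strategy inferred as "group"

    # one (scale, zero_point) pair per row and per group -> shape [16, 2]
    scale = torch.ones(16, 2) * (x.abs().max() / 127)
    zero_point = torch.zeros(16, 2)

    x_fq = fake_quantize(x, scale, zero_point, args)
    assert x_fq.shape == x.shape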
diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py
index 96fe1049..87d7c0e2 100644
--- a/src/compressed_tensors/quantization/observers/base.py
+++ b/src/compressed_tensors/quantization/observers/base.py
@@ -14,7 +14,11 @@
 
 from typing import Optional, Tuple
 
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+import torch
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.registry.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
@@ -52,6 +56,12 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
 
+    def post_calculate_qparams(self) -> None:
+        """
+        Run any observer-specific logic after calculate_qparams has been run
+        """
+        ...
+
     def get_qparams(
         self, observed: Optional[Tensor] = None
     ) -> Tuple[FloatTensor, IntTensor]:
@@ -64,6 +74,57 @@ def get_qparams(
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            # re-calcualte scale and zero point, update the stored value
-            self._scale, self._zero_point = self.calculate_qparams(observed)
+            group_size = self.quantization_args.group_size
+
+            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+
+                # re-calculate scale and zero point, update the stored value
+                self._scale, self._zero_point = self.calculate_qparams(observed)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                columns = observed.shape[1]
+                scales, zero_points = [], []
+                for i in range(0, columns, self.quantization_args.group_size):
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed[:, i : (i + group_size)],
+                        0,
+                    )
+                    scales.append(scale)
+                    zero_points.append(zero_point)
+
+                self._scale = torch.stack(scales, dim=1)
+                self._zero_point = torch.stack(zero_points, dim=1)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+                # assume observed is transposed because it is the output, hence use dim 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+
+                # use dim 1, assume observed.shape = [batch, token, hidden]
+                # resulting scale shape should be [batch, token]
+
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed, dim=1
+                )
+
         return self._scale, self._zero_point
+
+    def get_qparams_along_dim(self, observed, dim: int):
+        # TODO: add documentation that specifies the shape must
+        #   be padded with 1-dims so the scales are along the right channel
+        # TODO: generalize the logic for reduce_dims
+        scales, zero_points = [], []
+
+        # TODO: make a more generic way to get the channel
+        num_dims = observed.shape[dim]
+
+        for dim_idx in range(num_dims):
+            scale, zero_point = self.calculate_qparams(
+                observed.select(dim=dim, index=dim_idx)
+            )
+
+            scales.append(scale)
+            zero_points.append(zero_point)
+
+        return torch.stack(scales), torch.stack(zero_points)
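
Usage sketch (editor's note, not part of the patch): how the per-strategy dispatch in
Observer.get_qparams is expected to behave for channel-wise quantization, assuming the
library's default min-max observer is registered and reachable via
QuantizationArgs.get_observer(); shapes are illustrative.

    import torch

    from compressed_tensors.quantization.quant_args import QuantizationArgs

    args = QuantizationArgs(num_bits=8, group_size=-1)  # strategy inferred as "channel"
    observer = args.get_observer()

    # observed is assumed transposed (channels along dim 0), per the comment above
    weight = torch.randn(32, 64)
    scale, zero_point = observer.get_qparams(observed=weight)

    # get_qparams_along_dim(observed, 0) stacks one qparam pair per channel
    assert scale.shape[0] == 32 and zero_point.shape[0] == 32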
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index acaa00d5..f8c82d8a 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 
 
 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
     CHANNEL = "channel"
     GROUP = "group"
     BLOCK = "block"
+    TOKEN = "token"
 
 
 class QuantizationArgs(BaseModel):
@@ -63,8 +64,8 @@ class QuantizationArgs(BaseModel):
     num_bits: int = 8
     type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
-    strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
     observer: str = Field(
@@ -94,3 +95,31 @@ def get_observer(self):
             self.observer = "memoryless"
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
+
+    @validator("strategy", pre=True, always=True)
+    def validate_strategy(cls, value, values):
+        group_size = values.get("group_size")
+
+        # use group_size to determine the strategy if not given explicitly
+        if group_size is not None and value is None:
+            if group_size > 0:
+                return QuantizationStrategy.GROUP
+
+            elif group_size == -1:
+                return QuantizationStrategy.CHANNEL
+
+            else:
+                raise ValueError(
+                    f"group_size={group_size} with strategy {value} is invalid. "
+                    "group_size > 0 for strategy='group' and "
+                    "group_size = -1 for 'channel'"
+                )
+
+        if value == QuantizationStrategy.GROUP:
+            if group_size is None:
+                raise ValueError(f"strategy {value} requires group_size to be set.")
+
+        if value is None:
+            return QuantizationStrategy.TENSOR
+
+        return value
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 3c00cdbe..8676ef15 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
+
             num_weights = parameter.numel()
             total_compressed += compressed_bits * num_weights
             total_uncompressed += uncompressed_bits * num_weights
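
Usage sketch (editor's note, not part of the patch): the strategy inference added by the
validate_strategy validator, assuming compressed-tensors with this change is installed.

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    # group_size > 0 and no explicit strategy -> "group"
    assert QuantizationArgs(group_size=128).strategy == QuantizationStrategy.GROUP

    # group_size == -1 and no explicit strategy -> "channel"
    assert QuantizationArgs(group_size=-1).strategy == QuantizationStrategy.CHANNEL

    # neither group_size nor strategy -> default "tensor"
    assert QuantizationArgs().strategy == QuantizationStrategy.TENSOR

    # strategy="group" without group_size is rejected
    # (the ValueError surfaces as a pydantic ValidationError)
    try:
        QuantizationArgs(strategy="group")
    except Exception as err:
        print(f"rejected as expected: {err}")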