From 3163dea4a5185d92082aadb3966a0af8088bb380 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 17 May 2024 17:29:31 +0000 Subject: [PATCH 1/2] speed fix --- .../quantization/lifecycle/forward.py | 25 +------------------ .../quantization/observers/base.py | 19 ++------------ .../quantization/observers/memoryless.py | 24 +++++++++++------- .../quantization/observers/min_max.py | 4 --- .../test_configs/test_strategies.py | 1 - 5 files changed, 18 insertions(+), 55 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e4c627c4..25e6b076 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -150,7 +150,6 @@ def _process_quantization( q_min = torch.tensor(-bit_range / 2, device=x.device) group_size = args.group_size - # group if args.strategy == QuantizationStrategy.GROUP: if do_dequantize: # if dequantizing the output should be a fp type @@ -195,29 +194,7 @@ def _process_quantization( ) output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp) - # channel-wise - elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1 - if do_quantize: - output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype) - if do_dequantize: - output = _dequantize(output if do_quantize else x, scale, zero_point) - - # per-token - elif args.strategy == QuantizationStrategy.TOKEN: - # before: scale shape = [num_tokens] - # after: scale shape = [num_tokens, 1] - # x.shape = 1, num_tokens, 1] - # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape - - scale = scale.unsqueeze(1) - zero_point = zero_point.unsqueeze(1) - - if do_quantize: - output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype) - if do_dequantize: - output = _dequantize(output if do_quantize else x, scale, zero_point) - - else: + else: # covers channel, token and tensor strategies if do_quantize: output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype) if do_dequantize: diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index 93d229d5..d96fabc0 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -111,20 +111,5 @@ def get_qparams( return self._scale, self._zero_point def get_qparams_along_dim(self, observed, dim: int): - # TODO: add documentation that specifies the shape must - # be padded with 1-dims so the scales are along the right channel - # TODO: generalize the logic for reduce_dims - scales, zero_points = [], [] - - # TODO: make a more generic way to get the channel - num_dims = observed.shape[dim] - - for dim_idx in range(num_dims): - scale, zero_point = self.calculate_qparams( - observed.select(dim=dim, index=dim_idx) - ) - - scales.append(scale) - zero_points.append(zero_point) - # breakpoint() - return torch.stack(scales), torch.stack(zero_points) + reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) + return self.calculate_qparams(observed, reduce_dims=reduce_dims) diff --git a/src/compressed_tensors/quantization/observers/memoryless.py b/src/compressed_tensors/quantization/observers/memoryless.py index 4287452b..9abfbc74 100644 --- a/src/compressed_tensors/quantization/observers/memoryless.py +++ b/src/compressed_tensors/quantization/observers/memoryless.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import torch from compressed_tensors.quantization.observers.base import Observer @@ -30,19 +30,25 @@ class MemorylessObserver(Observer): zero point based on the latest observed value without tracking state """ - def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + def calculate_qparams( + self, + observed: Tensor, + reduce_dims: Optional[Tuple[int]] = None, + ) -> Tuple[FloatTensor, IntTensor]: """ - Returns the min and max values of observed + Returns the min and max values of observed tensor :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions :return: tuple of scale and zero point derived from the observed tensor """ - # TODO: Add support for full range of quantization Args, only supports 8bit - # per tensor - min_val, max_val = torch.aminmax(observed) - # ensure zero is in the range - min_val = torch.min(min_val, torch.zeros_like(min_val)) - max_val = torch.max(max_val, torch.zeros_like(max_val)) + if not reduce_dims: + min_val, max_val = torch.aminmax(observed) + else: + min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) + max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) return calculate_qparams(min_val, max_val, self.quantization_args) diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py index a754ac5c..1e8f1ddf 100644 --- a/src/compressed_tensors/quantization/observers/min_max.py +++ b/src/compressed_tensors/quantization/observers/min_max.py @@ -74,7 +74,3 @@ def calculate_qparams( ) return calculate_qparams(self.min_val, self.max_val, self.quantization_args) - - def get_qparams_along_dim(self, observed, dim: int): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams(observed, reduce_dims=reduce_dims) diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py index 4e8a2ca5..fc48ba60 100644 --- a/tests/test_quantization/test_configs/test_strategies.py +++ b/tests/test_quantization/test_configs/test_strategies.py @@ -22,7 +22,6 @@ QuantizationStrategy, apply_quantization_config, ) -from compressed_tensors.quantization.lifecycle.forward import fake_quantize from torch.nn import Linear From 87d32050de19b31adbc65baad975bdddc76bae5c Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 17 May 2024 19:22:21 +0000 Subject: [PATCH 2/2] update calculate_qparams sig --- .../quantization/observers/base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/base.py b/src/compressed_tensors/quantization/observers/base.py index d96fabc0..0297ff7a 100644 --- a/src/compressed_tensors/quantization/observers/base.py +++ b/src/compressed_tensors/quantization/observers/base.py @@ -50,9 +50,16 @@ def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: """ return self.get_qparams(observed=observed) - def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]: + def calculate_qparams( + self, + observed: Tensor, + reduce_dims: Optional[Tuple[int]] = None, + ) -> Tuple[FloatTensor, IntTensor]: """ :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions :return: tuple of scale and zero point derived from the observed tensor """ raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") @@ -70,6 +77,7 @@ def get_qparams( Convenience function to wrap overwritten calculate_qparams adds support to make observed tensor optional and support for tracking latest calculated scale and zero point + :param observed: optional observed tensor to calculate quantization parameters from :return: tuple of scale and zero point based on last observed value @@ -100,10 +108,8 @@ def get_qparams( self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - # use dim 1, assume the obsersed.shape = [batch, token, hidden] # should be batch, token - self._scale, self._zero_point = self.get_qparams_along_dim( observed, dim=1 )