Skip to content

Commit

Permalink
Fix per_token slowdown (#57)
Browse files Browse the repository at this point in the history
* speed fix

* update calculate_qparams sig
  • Loading branch information
Sara Adkins authored May 20, 2024
1 parent 88da7f7 commit f9d8d8b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 58 deletions.
25 changes: 1 addition & 24 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def _process_quantization(
q_min = torch.tensor(-bit_range / 2, device=x.device)
group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

if do_dequantize: # if dequantizing the output should be a fp type
Expand Down Expand Up @@ -195,29 +194,7 @@ def _process_quantization(
)
output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [1, num_tokens, 1]
# scale gets broadcasted as expected without having [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

else:
else: # covers channel, token and tensor strategies
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
Expand Down
31 changes: 11 additions & 20 deletions src/compressed_tensors/quantization/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
return self.get_qparams(observed=observed)

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
Expand All @@ -70,6 +77,7 @@ def get_qparams(
Convenience function to wrap overwritten calculate_qparams
adds support to make observed tensor optional and support for tracking latest
calculated scale and zero point
:param observed: optional observed tensor to calculate quantization parameters
from
:return: tuple of scale and zero point based on last observed value
Expand Down Expand Up @@ -100,31 +108,14 @@ def get_qparams(
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume the observed.shape = [batch, token, hidden]
# should be batch, token

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
# breakpoint()
return torch.stack(scales), torch.stack(zero_points)
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
24 changes: 15 additions & 9 deletions src/compressed_tensors/quantization/observers/memoryless.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
Expand All @@ -30,19 +30,25 @@ class MemorylessObserver(Observer):
zero point based on the latest observed value without tracking state
"""

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Returns the min and max values of observed
Returns the min and max values of the observed tensor
:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
# TODO: Add support for full range of quantization Args, only supports 8bit
# per tensor
min_val, max_val = torch.aminmax(observed)

# ensure zero is in the range
min_val = torch.min(min_val, torch.zeros_like(min_val))
max_val = torch.max(max_val, torch.zeros_like(max_val))
if not reduce_dims:
min_val, max_val = torch.aminmax(observed)
else:
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

return calculate_qparams(min_val, max_val, self.quantization_args)
4 changes: 0 additions & 4 deletions src/compressed_tensors/quantization/observers/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,3 @@ def calculate_qparams(
)

return calculate_qparams(self.min_val, self.max_val, self.quantization_args)

def get_qparams_along_dim(self, observed, dim: int):
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
1 change: 0 additions & 1 deletion tests/test_quantization/test_configs/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
QuantizationStrategy,
apply_quantization_config,
)
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from torch.nn import Linear


Expand Down

0 comments on commit f9d8d8b

Please sign in to comment.