Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix per_token slowdown #57

Merged
merged 2 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 1 addition & 24 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def _process_quantization(
q_min = torch.tensor(-bit_range / 2, device=x.device)
group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

if do_dequantize: # if dequantizing the output should be a fp type
Expand Down Expand Up @@ -195,29 +194,7 @@ def _process_quantization(
)
output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [1, num_tokens, 1]
# scale gets broadcasted as expected without having [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

else:
else: # covers channel, token and tensor strategies
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
Expand Down
31 changes: 11 additions & 20 deletions src/compressed_tensors/quantization/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
return self.get_qparams(observed=observed)

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
Expand All @@ -70,6 +77,7 @@ def get_qparams(
Convenience function to wrap overwritten calculate_qparams
adds support to make observed tensor optional and support for tracking latest
calculated scale and zero point

:param observed: optional observed tensor to calculate quantization parameters
from
:return: tuple of scale and zero point based on last observed value
Expand Down Expand Up @@ -100,31 +108,14 @@ def get_qparams(
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume the observed.shape = [batch, token, hidden]
# should be batch, token

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
# breakpoint()
return torch.stack(scales), torch.stack(zero_points)
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
Satrat marked this conversation as resolved.
Show resolved Hide resolved
24 changes: 15 additions & 9 deletions src/compressed_tensors/quantization/observers/memoryless.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
Expand All @@ -30,19 +30,25 @@ class MemorylessObserver(Observer):
zero point based on the latest observed value without tracking state
"""

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Returns the min and max values of observed
Returns the min and max values of observed tensor

:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
# TODO: Add support for full range of quantization Args, only supports 8bit
# per tensor
min_val, max_val = torch.aminmax(observed)

# ensure zero is in the range
min_val = torch.min(min_val, torch.zeros_like(min_val))
max_val = torch.max(max_val, torch.zeros_like(max_val))
if not reduce_dims:
min_val, max_val = torch.aminmax(observed)
else:
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

return calculate_qparams(min_val, max_val, self.quantization_args)
4 changes: 0 additions & 4 deletions src/compressed_tensors/quantization/observers/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,3 @@ def calculate_qparams(
)

return calculate_qparams(self.min_val, self.max_val, self.quantization_args)

def get_qparams_along_dim(self, observed, dim: int):
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
1 change: 0 additions & 1 deletion tests/test_quantization/test_configs/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
QuantizationStrategy,
apply_quantization_config,
)
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from torch.nn import Linear


Expand Down
Loading