
[Observer Restructure]: Remove MemoryLess Observer; use helper function for dynamic quantization #187

Merged
merged 7 commits on Oct 11, 2024
Changes from 5 commits
10 changes: 6 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -18,7 +18,10 @@

import torch
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.observers.helpers import (
calculate_range,
compute_memoryless_zp_and_scales,
)
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
@@ -376,9 +379,8 @@ def maybe_calibrate_or_quantize(
g_idx = getattr(module, "weight_g_idx", None)

if args.dynamic:
# dynamic quantization - get scale and zero point directly from observer
observer = getattr(module, f"{base_name}_observer")
scale, zero_point = observer(value, g_idx=g_idx)
# dynamic quantization - no need to invoke observer
scale, zero_point = compute_memoryless_zp_and_scales(value=value, args=args)
else:
# static quantization - get previous scale and zero point from layer
scale = getattr(module, f"{base_name}_scale")
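A minimal sketch of how the new dynamic path can be exercised directly. Only the helper call itself comes from this diff; the `QuantizationArgs` construction below is illustrative and assumes default values for the remaining fields:

```python
import torch
from compressed_tensors.quantization.observers.helpers import (
    compute_memoryless_zp_and_scales,
)
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

# Per-tensor dynamic quantization: scale/zero_point are recomputed from each
# observed tensor, so no observer submodule or persisted parameters are needed.
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR, dynamic=True)
activations = torch.randn(4, 64)
scale, zero_point = compute_memoryless_zp_and_scales(value=activations, args=args)
```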
13 changes: 7 additions & 6 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -153,12 +153,16 @@ def _initialize_scale_zero_point_observer(
weight_shape: Optional[torch.Size] = None,
force_zero_point: bool = True,
):

# initialize observer module and attach as submodule
observer = quantization_args.get_observer()
module.register_module(f"{base_name}_observer", observer)
# no need to register an observer for dynamic quantization
if observer:
module.register_module(f"{base_name}_observer", observer)

# no need to register a scale and zero point for dynamic quantization
if quantization_args.dynamic:
return # no need to register a scale and zero point for a dynamic observer
return

device = next(module.parameters()).device
if is_module_offloaded(module):
@@ -173,10 +177,7 @@
expected_shape = (weight_shape[0], 1)
elif quantization_args.strategy == QuantizationStrategy.GROUP:
num_groups = weight_shape[1] // quantization_args.group_size
expected_shape = (
weight_shape[0],
max(num_groups, 1)
)
expected_shape = (weight_shape[0], max(num_groups, 1))

scale_dtype = module.weight.dtype
if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/observers/__init__.py
@@ -17,6 +17,5 @@

from .helpers import *
from .base import *
from .memoryless import *
from .min_max import *
from .mse import *
41 changes: 39 additions & 2 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -13,18 +13,55 @@
# limitations under the License.

from collections import Counter
from typing import Tuple
from typing import Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import (
FP8_DTYPE,
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from torch import FloatTensor, IntTensor, Tensor


__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
__all__ = [
"calculate_qparams",
"get_observer_token_count",
"calculate_range",
"compute_memoryless_zp_and_scales",
]


def compute_memoryless_zp_and_scales(value: Tensor, args: QuantizationArgs):
"""
Returns the min and max values of observed tensor

:param value: tensor to calculate quantization parameters for
:param args: quantization args
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
if args.strategy == QuantizationStrategy.TOKEN:
dim = {1, 2}
reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
elif args.strategy == QuantizationStrategy.TENSOR:
reduce_dims = None
else:
raise ValueError(
f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} "
"must be used for dynamic quantization"
)

if not reduce_dims:
min_val, max_val = torch.aminmax(value)
else:
min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
max_val = torch.amax(value, dim=reduce_dims, keepdims=True)

return calculate_qparams(min_val, max_val, args)


def get_observer_token_count(module: torch.nn.Module) -> Counter:
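For the TOKEN branch above, a small standalone sketch of how the reduction dimensions are derived; the tensor shape is hypothetical and this simply mirrors the logic inside `compute_memoryless_zp_and_scales`:

```python
import torch

value = torch.randn(2, 8, 32)  # hypothetical activation tensor

# TOKEN strategy keeps dims {1, 2} and reduces over the rest, per the helper above
keep_dims = {1, 2}
reduce_dims = tuple(i for i in range(value.ndim) if i not in keep_dims)  # (0,)

min_val = torch.amin(value, dim=reduce_dims, keepdim=True)  # shape (1, 8, 32)
max_val = torch.amax(value, dim=reduce_dims, keepdim=True)
# min_val / max_val then feed calculate_qparams(min_val, max_val, args)
```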
56 changes: 0 additions & 56 deletions src/compressed_tensors/quantization/observers/memoryless.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/compressed_tensors/quantization/quant_args.py
@@ -94,7 +94,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
block_structure: Optional[str] = None
dynamic: bool = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: str = Field(
observer: Optional[str] = Field(
default="minmax",
description=(
"The class to use to compute the quantization param - "
@@ -115,10 +115,10 @@ def get_observer(self):
"""
from compressed_tensors.quantization.observers.base import Observer

# No observer required for the dynamic case
if self.dynamic:
# override defualt observer for dynamic, you never want minmax which
# keeps state across samples for dynamic
self.observer = "memoryless"
self.observer = None
return self.observer

return Observer.load_from_registry(self.observer, quantization_args=self)

Expand Down
@@ -73,7 +73,7 @@ def _test_layer_dynamic_quantization_status(
# check dynamically quantized inputs never have an observer or scale/zp
assert not hasattr(module, "input_scale")
assert not hasattr(module, "input_zero_point")
assert hasattr(module, "input_observer") == inputs
assert not hasattr(module, "input_observer")

# check weights always have scale/zp and observer only if not frozen
assert hasattr(module, "weight_scale") == weights