[Observer Restructure]: Remove MemoryLess Observer; use helper function for dynamic quantization (#187)

* remove memoryless observer; use helper function for dynamic quantization

* update init

* clean-up

* update test case

* fix arg

* validation + update name

* update preset schemes; swap condition check
dsikka authored Oct 11, 2024
1 parent b876a60 commit b2abe72
Showing 8 changed files with 85 additions and 74 deletions.
10 changes: 6 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -18,7 +18,10 @@

 import torch
 from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import calculate_range
+from compressed_tensors.quantization.observers.helpers import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,

@@ -376,9 +379,8 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)

     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - no need to invoke observer
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
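The change above swaps the per-module observer lookup for a stateless helper call. A minimal sketch of the resulting control flow, assuming a module that carries calibrated `{base_name}_scale`/`{base_name}_zero_point` parameters; `resolve_qparams` is a hypothetical wrapper, not a function from this diff:

    from compressed_tensors.quantization.observers.helpers import (
        compute_dynamic_scales_and_zp,
    )

    def resolve_qparams(module, base_name, value, args):
        # hypothetical helper mirroring the branch above
        if args.dynamic:
            # dynamic: derive scale/zero point from the current batch, no state
            return compute_dynamic_scales_and_zp(value=value, args=args)
        # static: reuse the calibrated parameters registered on the layer
        scale = getattr(module, f"{base_name}_scale")
        zero_point = getattr(module, f"{base_name}_zero_point")
        return scale, zero_point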
13 changes: 7 additions & 6 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -153,12 +153,16 @@ def _initialize_scale_zero_point_observer(
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):

     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+    # no need to register an observer for dynamic quantization
+    if observer:
+        module.register_module(f"{base_name}_observer", observer)

+    # no need to register a scale and zero point for dynamic quantization
     if quantization_args.dynamic:
-        return  # no need to register a scale and zero point for a dynamic observer
+        return

     device = next(module.parameters()).device
     if is_module_offloaded(module):

@@ -173,10 +177,7 @@
         expected_shape = (weight_shape[0], 1)
     elif quantization_args.strategy == QuantizationStrategy.GROUP:
         num_groups = weight_shape[1] // quantization_args.group_size
-        expected_shape = (
-            weight_shape[0],
-            max(num_groups, 1)
-        )
+        expected_shape = (weight_shape[0], max(num_groups, 1))

     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
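For reference, the second hunk's shape logic keys the scale/zero-point shape off the strategy. A standalone sketch under stated assumptions: `expected_qparam_shape` is a hypothetical name, the CHANNEL branch is inferred from the `(weight_shape[0], 1)` context line, and the `(1,)` fallback stands in for the per-tensor default not shown in this hunk:

    import torch
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    def expected_qparam_shape(args: QuantizationArgs, weight_shape: torch.Size):
        # hypothetical sketch of the expected_shape selection above
        if args.strategy == QuantizationStrategy.CHANNEL:
            return (weight_shape[0], 1)  # one scale per output channel
        if args.strategy == QuantizationStrategy.GROUP:
            num_groups = weight_shape[1] // args.group_size
            return (weight_shape[0], max(num_groups, 1))
        return (1,)  # assumed per-tensor fallback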
1 change: 0 additions & 1 deletion src/compressed_tensors/quantization/observers/__init__.py
@@ -17,6 +17,5 @@

 from .helpers import *
 from .base import *
-from .memoryless import *
 from .min_max import *
 from .mse import *
42 changes: 40 additions & 2 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -13,18 +13,56 @@
 # limitations under the License.

 from collections import Counter
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 from compressed_tensors.quantization.quant_args import (
     FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from torch import FloatTensor, IntTensor, Tensor


-__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+__all__ = [
+    "calculate_qparams",
+    "get_observer_token_count",
+    "calculate_range",
+    "compute_dynamic_scales_and_zp",
+]
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    quantization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} "
+            "must be used for dynamic quantization"
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)


 def get_observer_token_count(module: torch.nn.Module) -> Counter:
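A hedged usage example for the new helper. The `QuantizationArgs` keywords follow the field names in quant_args.py below; passing `observer=None` explicitly avoids the validation warning this commit adds for dynamic args:

    import torch
    from compressed_tensors.quantization.observers.helpers import (
        compute_dynamic_scales_and_zp,
    )
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    value = torch.randn(2, 8, 64)  # e.g. (batch, tokens, hidden)

    # per-tensor dynamic quantization: a single scale/zero point pair
    tensor_args = QuantizationArgs(dynamic=True, strategy="tensor", observer=None)
    scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=tensor_args)

    # TOKEN keeps dims {1, 2} and reduces the rest (keepdims=True), so the
    # returned qparams broadcast against the original activation tensor
    token_args = QuantizationArgs(dynamic=True, strategy="token", observer=None)
    scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=token_args)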
56 changes: 0 additions & 56 deletions src/compressed_tensors/quantization/observers/memoryless.py

This file was deleted.

32 changes: 28 additions & 4 deletions src/compressed_tensors/quantization/quant_args.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union

@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "

@@ -115,10 +116,10 @@ def get_observer(self):
         """
         from compressed_tensors.quantization.observers.base import Observer

+        # No observer required for the dynamic case
         if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
+            self.observer = None
             return self.observer

         return Observer.load_from_registry(self.observer, quantization_args=self)
@@ -171,6 +172,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer

         # infer strategy
         if strategy is None:

@@ -207,6 +210,27 @@
                     "activation ordering"
                 )

+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic "
+                    "quantization"
+                )
+            if observer is not None:
+                warnings.warn(
+                    "No observer is used for dynamic quantization, setting to None"
+                )
+                model.observer = None
+
+        # if we have not set an observer and we
+        # are running static quantization, use minmax
+        if not observer and not dynamic:
+            model.observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
         return model
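The new validation rules, exercised as a hypothetical snippet; pydantic surfaces the ValueError raised above as a ValidationError, which subclasses ValueError:

    import warnings
    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # dynamic quantization with a group strategy is rejected
    try:
        QuantizationArgs(dynamic=True, strategy="group", group_size=128)
    except ValueError as err:
        print(err)  # ...must be used for dynamic quantization

    # an explicit observer on dynamic args is warned about and cleared
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        args = QuantizationArgs(dynamic=True, strategy="token", observer="minmax")
    assert args.observer is None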
3 changes: 3 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -122,6 +122,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )

@@ -164,6 +165,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )

@@ -200,6 +202,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
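In miniature, what each updated preset now pins down for its dynamic input activations, shown as a hypothetical standalone construction mirroring the three hunks above:

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    input_activations = QuantizationArgs(
        num_bits=8,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
        observer=None,  # explicit, so validation has nothing to warn about
    )
    assert input_activations.observer is None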
@@ -73,7 +73,7 @@ def _test_layer_dynamic_quantization_status(
     # check inputs always have an observer if quantized but never scale/zp
     assert not hasattr(module, "input_scale")
     assert not hasattr(module, "input_zero_point")
-    assert hasattr(module, "input_observer") == inputs
+    assert not hasattr(module, "input_observer")

     # check weights always have scale/zp and observer only if not frozen
     assert hasattr(module, "weight_scale") == weights
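Condensed, the updated expectation for dynamically quantized inputs, written as a hypothetical helper rather than the actual test:

    def assert_no_input_state(module):
        # dynamic input quantization keeps no parameters or observers on the module
        for attr in ("input_scale", "input_zero_point", "input_observer"):
            assert not hasattr(module, attr)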
