[StaticQuant] Update how block_size is calculated with Observers (#815)
drisspg authored and HDCharles committed Sep 8, 2024
1 parent ea15567 commit 1cc9245
Showing 3 changed files with 125 additions and 13 deletions.
74 changes: 74 additions & 0 deletions test/quantization/test_observer.py
@@ -1,3 +1,4 @@
import re
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
@@ -34,6 +35,79 @@ def test_min_max_per_channel_affine(self):
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
self._test_obs_helper(obs, ref_obs)

def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2048),
torch.randn(7, 2048),
]
for example_input in example_inputs:
obs(example_input)

obs.calculate_qparams()

obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
for example_input in example_inputs:
obs(example_input)

obs.calculate_qparams()

def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2048),
]
expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([10]) != min_val:torch.Size([9])"
escaped_error_msg = re.escape(expected_error_msg)
with self.assertRaisesRegex(AssertionError, escaped_error_msg):
for example_input in example_inputs:
obs(example_input)

obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2047),
]
expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([2048]) != min_val:torch.Size([2047])"
escaped_error_msg = re.escape(expected_error_msg)
with self.assertRaisesRegex(AssertionError, escaped_error_msg):
for example_input in example_inputs:
obs(example_input)


if __name__ == "__main__":
unittest.main()
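Why the two failures above differ: with PerAxis(0) the observer keeps one running (min, max) per row, so its state shape follows the batch dimension, which varies between the example inputs; with PerAxis(1) it keeps one per column, which only breaks when the feature dimension changes. A minimal sketch of that shape bookkeeping in plain torch (illustration only, not the observer API):

    import torch

    x1, x2 = torch.randn(10, 2048), torch.randn(9, 2048)

    # PerAxis(0): reduce over dim 1, one running value per row
    torch.amin(x1, dim=1).shape  # torch.Size([10])
    torch.amin(x2, dim=1).shape  # torch.Size([9])  -> mismatch, the assert fires

    # PerAxis(1): reduce over dim 0, one running value per column
    torch.amin(x1, dim=0).shape  # torch.Size([2048])
    torch.amin(x2, dim=0).shape  # torch.Size([2048]) -> shapes agree, the update succeeds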
62 changes: 50 additions & 12 deletions torchao/quantization/observer.py
@@ -16,14 +16,40 @@

@dataclass(frozen=True)
class GranularityType:
"""
Base class for representing the granularity of quantization.
This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""
pass

@dataclass(frozen=True)
class PerTensor(GranularityType):
"""
Represents per-tensor granularity in quantization.
This granularity type calculates the quantization parameters
based on the entire tensor.
"""
pass

@dataclass(frozen=True)
class PerAxis(GranularityType):
"""
Represents per-axis granularity in quantization.
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
For example, if the input tensor has shape [8, 16] and axis=0, then
the quantization parameters are calculated for each row of the tensor,
giving a total of 8 quantization parameters.
Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

# borrowed from torch.ao.quantization.observer
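To make the two granularities concrete, a small sketch (illustration only; it assumes PerTensor and PerAxis are importable from torchao.quantization.observer, as in the tests above):

    from torchao.quantization.observer import PerAxis, PerTensor

    # One set of quantization parameters for the whole tensor, whatever its shape.
    per_tensor = PerTensor()

    # For an [8, 16] tensor: axis=0 gives one parameter set per row (8 sets),
    # axis=1 gives one parameter set per column (16 sets).
    per_row = PerAxis(axis=0)
    per_col = PerAxis(axis=1)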
@@ -59,7 +85,16 @@ def _with_args(cls_or_self, *args, **kwargs):
r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
return r

-def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:

+def get_block_size(
+    input_shape: Tuple[int, ...], granularity_type: GranularityType
+) -> Tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.
Args:
input_shape: The input tensor shape, possibly with more than 2 dimensions
granularity_type: The granularity type of the quantization
"""
if isinstance(granularity_type, PerTensor):
return input_shape
elif isinstance(granularity_type, PerAxis):
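For illustration, the outputs one would expect here, assuming the PerAxis branch (collapsed above) pins the block dimension along the given axis to 1, which is what the docstring's [8, 16] example implies:

    shape = (10, 2048)
    get_block_size(shape, PerTensor())  # (10, 2048): a single block spans the tensor
    get_block_size(shape, PerAxis(0))   # (1, 2048): one block, hence one qparam set, per row
    get_block_size(shape, PerAxis(1))   # (10, 1): one block, hence one qparam set, per column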
@@ -84,8 +119,7 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
def __init__(self,
mapping_type: MappingType,
target_dtype: torch.dtype,
-        block_size: Optional[Tuple[int, ...]] = None,
-        granularity_type: Optional[GranularityType] = None,
+        granularity_type: GranularityType,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
@@ -95,12 +129,10 @@ def __init__(self,
zero_point_domain = ZeroPointDomain.INT,
):
super().__init__()
-        assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
-        if block_size is not None and granularity_type is not None:
-            logger.warning("Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}")
+        assert granularity_type is not None, "granularity_type is None"

self.mapping_type = mapping_type
self.target_dtype = target_dtype
-        self.block_size = block_size
self.granularity_type = granularity_type
self.quant_min = quant_min
self.quant_max = quant_max
@@ -130,17 +162,21 @@ def forward(self, input: torch.Tensor):
return input

input_detached = input.detach()
-        if self.block_size is None:
-            self.block_size = get_block_size(input_detached.shape, self.granularity_type)
+        assert self.granularity_type is not None, "granularity_type is None"
+        block_size = get_block_size(input_detached.shape, self.granularity_type)

-        shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size())
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, input_detached.size()
+        )
input_detached = input_detached.view(shape_for_reduction)
min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
self.min_val = min_val
self.max_val = max_val
else:
assert self.min_val.shape == min_val.shape, f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
assert self.max_val.shape == max_val.shape, f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
min_val = torch.min(self.min_val, min_val)
max_val = torch.max(self.max_val, max_val)
self.min_val.copy_(min_val)
@@ -149,12 +185,14 @@ def forward(self, input: torch.Tensor):
return input

def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
+        assert (
+            hasattr(self, "min_val") and hasattr(self, "max_val")
+        ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
return choose_qparams_affine_with_min_max(
self.min_val,
self.max_val,
self.mapping_type,
-            self.block_size,
+            [], # BlockSize is not needed because the min/max are already reduced
self.target_dtype,
self.quant_min,
self.quant_max,
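Taken together, the observer changes mean block_size is no longer cached on the module at first call but re-derived from each input's shape, and calculate_qparams no longer needs a block size at all, since min_val/max_val are already reduced to one entry per block. A sketch of the per-call derivation (same assumed get_block_size semantics as above):

    # First forward sees a (10, 2048) input, the next a (9, 2048) input.
    get_block_size((10, 2048), PerTensor())  # (10, 2048)
    get_block_size((9, 2048), PerTensor())   # (9, 2048) -- previously a stale cached (10, 2048) would be reused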
2 changes: 1 addition & 1 deletion torchao/quantization/quant_primitives.py
@@ -664,7 +664,7 @@ def choose_qparams_affine_with_min_max(
scale_dtype,
zero_point_dtype,
preserve_zero,
-        zero_point_domain.name,
+        zero_point_domain.name if zero_point_domain is not None else None,
min_val,
max_val,
)
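This guard matters for the float8 observers in the tests above, which pass zero_point_domain=None; the call site previously assumed an enum member. A minimal illustration:

    zero_point_domain = None  # as constructed in the float8 test observers

    # before: zero_point_domain.name raises AttributeError on None
    # after: the name is forwarded only when a domain is set
    name = zero_point_domain.name if zero_point_domain is not None else None  # -> None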
