[WIP] Dynamic Quantization
Benjamin committed Apr 17, 2024
1 parent ee6a913 commit 2395833
Showing 6 changed files with 68 additions and 19 deletions.
30 changes: 17 additions & 13 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -112,18 +112,22 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
-    device = next(module.parameters()).device
-    scale = getattr(module, f"{base_name}_scale")
-    # zero_point = getattr(module, f"{base_name}_zero_point").data
-    zero_point = getattr(module, f"{base_name}_zero_point")
-
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
-        # get observer and get new quant params from observation
-        observer = getattr(module, f"{base_name}_observer")
-        updated_scale, updated_zero_point = observer(value)
-
-        # update scale and zero point
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+    observer = getattr(module, f"{base_name}_observer")
+    if observer.DYNAMIC:
+        # dynamic quantization - get scale and zero point directly from observer
+        scale, zero_point = observer(value)
+    else:
+        # static quantization - get previous scale and zero point from layer
+        scale = getattr(module, f"{base_name}_scale")
+        zero_point = getattr(module, f"{base_name}_zero_point")
+
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            # calibration mode - get new quant params from observer
+            updated_scale, updated_zero_point = observer(value)
+
+            # update scale and zero point
+            device = next(module.parameters()).device
+            scale.data = updated_scale.to(device)
+            zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)
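To make the dynamic branch above concrete, here is a minimal, standalone sketch of per-tensor fake quantization where the scale and zero point are recomputed from the observed tensor on every call, roughly what a memoryless/dynamic observer feeds into `fake_quantize`. The 8-bit asymmetric range, the clamping, and the zero-point convention are assumptions of this sketch, not the library's implementation:

```python
import torch


def dynamic_fake_quantize(value: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # derive qparams from the observed tensor itself, as a dynamic
    # (memoryless) observer would on each forward pass
    qmin, qmax = 0, 2**num_bits - 1
    min_val = torch.clamp(value.min(), max=0.0)
    max_val = torch.clamp(value.max(), min=0.0)
    scale = torch.clamp(
        (max_val - min_val) / (qmax - qmin), min=torch.finfo(torch.float32).eps
    )
    zero_point = (qmin - torch.round(min_val / scale)).to(torch.int)

    # quantize then dequantize ("fake" quantization keeps the floating-point dtype)
    q = torch.clamp(torch.round(value / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale
```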
7 changes: 5 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -36,9 +36,12 @@ def freeze_module_quantization(module: Module):
 
     # delete observers from module
     observer_names = []
-    for submodule_name, _ in module.named_modules():
+    for submodule_name, submodule in module.named_modules():
         if "." not in submodule_name and submodule_name.endswith("_observer"):
-            # delete any observers that belong directly to this module
+            if getattr(submodule, "DYNAMIC", False):
+                continue  # do not delete dynamic observers
+
+            # delete any non-dynamic observers that belong directly to this module
             observer_names.append(submodule_name)
     for observer_name in observer_names:
         delattr(module, observer_name)
11 changes: 7 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -78,6 +78,13 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
+
+    if observer.DYNAMIC:
+        return  # no need to register a scale and zero point for a dynamic observer
+
     device = next(module.parameters()).device
 
     # initializes empty scale and zero point parameters for the module
@@ -88,7 +95,3 @@ def _initialize_scale_zero_point_observer(
         torch.empty(0, device=device, dtype=int), requires_grad=False
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/observers/__init__.py
@@ -19,3 +19,4 @@
 from .base import *
 from .memoryless import *
 from .min_max import *
+from .dynamic import *
3 changes: 3 additions & 0 deletions src/compressed_tensors/quantization/observers/base.py
@@ -30,6 +30,9 @@ class Observer(Module, RegistryMixin):
     pair
     """
 
+    # child classes should set this to True if they are meant to be used as dynamic observers
+    DYNAMIC = False
+
     def __init__(self, quantization_args: QuantizationArgs):
         self.quantization_args: QuantizationArgs = quantization_args
         super().__init__()
35 changes: 35 additions & 0 deletions src/compressed_tensors/quantization/observers/dynamic.py
@@ -0,0 +1,35 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.memoryless import MemorylessObserver


__all__ = ["DynamicObserver"]


@Observer.register("dynamic")
class DynamicObserver(MemorylessObserver):
"""
Values targted for a dyanmic observer do not require calibration,
this observer will persist in the model through the lifecycle, calculating
the quantization parameters on the fly for each observed Tensor.
This base dynamic observer uses the `calculate_qparams` from MemorylessObserver
where each scale and zero point is based solely on the currently observed
Tensor.
"""

DYNAMIC = False
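A hedged usage sketch of the new observer: the `get_observer()` call and the `(scale, zero_point)` pair returned by calling an observer are taken from the diffs above, while the exact `QuantizationArgs` constructor fields (including an `observer="dynamic"` selector) are assumptions of this sketch.

```python
import torch
from compressed_tensors.quantization import QuantizationArgs

# assumed constructor fields; "dynamic" refers to the registry name used above
args = QuantizationArgs(num_bits=8, symmetric=False, observer="dynamic")
observer = args.get_observer()

# a dynamic observer recomputes qparams for every tensor it sees, so two
# differently scaled activations generally produce two different scales
scale_a, zp_a = observer(torch.randn(4, 16))
scale_b, zp_b = observer(10 * torch.randn(4, 16))
```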
