Quantization Lifecycle Implementation (#5)
* draft

* add memoryless

* run bin.quant

* before tests, correctness verified

* specify sparsezoo version

* remove sparsezoo

Co-authored-by: Benjamin Fineran <benjaminfineran@gmail.com>
horheynm and bfineran authored Apr 12, 2024
1 parent 318ad21 commit 373ff85
Showing 12 changed files with 577 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def _setup_install_requires() -> List:
     return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]

 def _setup_extras() -> Dict:
-    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
+    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0",]}

 setup(
     name="sparsetensors",
21 changes: 21 additions & 0 deletions src/sparsetensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
from .status import *
43 changes: 43 additions & 0 deletions src/sparsetensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module


__all__ = [
    "set_module_for_calibration",
]


_LOGGER = logging.getLogger(__name__)


def set_module_for_calibration(module: Module):
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme, nothing to do
        return
    status = getattr(module, "quantization_status", None)
    if not status or status != QuantizationStatus.INITIALIZED:
        _LOGGER.warning(
            f"Attempting to set module with status {status} to calibration mode, "
            f"but status is not {QuantizationStatus.INITIALIZED} - you may be "
            "calibrating an uninitialized module, which may fail, or attempting "
            "to re-calibrate a frozen module"
        )

    module.quantization_status = QuantizationStatus.CALIBRATION
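
For context, a minimal usage sketch of this entry point (hypothetical example, not part of the commit):

from torch.nn import Linear
from sparsetensors.quantization.lifecycle.calibration import set_module_for_calibration

linear = Linear(4, 4)
set_module_for_calibration(linear)  # no-op: module has no quantization_scheme

# after initialize_module_for_quantization(linear, scheme) the status is
# INITIALIZED, and calling set_module_for_calibration(linear) again advances
# it to QuantizationStatus.CALIBRATION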
127 changes: 127 additions & 0 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -0,0 +1,127 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = ["wrap_module_forward_quantized"]


def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_max: torch.Tensor,
) -> torch.Tensor:
    return torch.clamp(
        torch.round(x / scale + zero_point),
        0,
        q_max,
    )


def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
) -> torch.Tensor:
    return (x_q - zero_point) * scale


def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
) -> torch.Tensor:
    max_q = torch.tensor(2**args.num_bits - 1)
    x_q = quantize(x, scale, zero_point, max_q)
    return dequantize(x_q, scale, zero_point)


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    # expects a module already initialized and injected with the parameters in
    # initialize_module_for_quantization
    forward_func_orig = module.forward.__func__

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        input_ = args[0]

        if scheme.input_activations is not None:
            # calibrate and (fake) quantize input activations when applicable
            input_ = _maybe_calibrate_or_quantize(
                module, input_, "input", scheme.input_activations
            )

        if scheme.weights is not None:
            # calibrate and (fake) quantize weights when applicable
            self.weight.data = _maybe_calibrate_or_quantize(
                module, self.weight, "weight", scheme.weights
            )

        # perform wrapped forward call
        output = forward_func_orig.__get__(module, module.__class__)(
            input_, *args[1:], **kwargs
        )

        if scheme.output_activations is not None:
            # calibrate and (fake) quantize output activations when applicable
            output = _maybe_calibrate_or_quantize(
                module, output, "output", scheme.output_activations
            )

        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)


def _maybe_calibrate_or_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: QuantizationArgs
) -> torch.Tensor:
    # only run quantized for the included stages
    if module.quantization_status not in {
        QuantizationStatus.CALIBRATION,
        QuantizationStatus.FROZEN,
    }:
        return value

    scale = getattr(module, f"{base_name}_scale")
    zero_point = getattr(module, f"{base_name}_zero_point")

    if module.quantization_status == QuantizationStatus.CALIBRATION:
        # get observer and get new quant params from observation
        observer = getattr(module, f"{base_name}_observer")
        updated_scale, updated_zero_point = observer(value)

        # update scale and zero point
        scale.data = updated_scale
        zero_point.data = updated_zero_point

    return fake_quantize(value, scale, zero_point, args)
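
As a quick sanity check of the round trip above (illustrative values, not from the commit): with scale 0.1, zero point 128, and 8 bits, the quantized grid covers [0, 255], so in-range values come back to within half a step (0.05) of the input:

import torch
from sparsetensors.quantization.lifecycle.forward import dequantize, quantize

scale = torch.tensor(0.1)
zero_point = torch.tensor(128)
q_max = torch.tensor(255)  # 2**8 - 1

x = torch.tensor([-1.0, 0.0, 0.57])
x_q = quantize(x, scale, zero_point, q_max)  # tensor([118., 128., 134.])
dequantize(x_q, scale, zero_point)           # tensor([-1.0000, 0.0000, 0.6000])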
36 changes: 36 additions & 0 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module


__all__ = [
    "freeze_module_quantization",
]


def freeze_module_quantization(module: Module):
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme, nothing to do
        return

    # delete observers from module
    for submodule_name, _ in module.named_modules():
        if "." not in submodule_name and submodule_name.endswith("_observer"):
            # delete any observers that belong directly to this module
            delattr(module, submodule_name)

    module.quantization_status = QuantizationStatus.FROZEN
71 changes: 71 additions & 0 deletions src/sparsetensors/quantization/lifecycle/initialize.py
@@ -0,0 +1,71 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter


__all__ = [
    "initialize_module_for_quantization",
]


_LOGGER = logging.getLogger(__name__)


def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
    if scheme.input_activations is not None:
        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
    if scheme.weights is not None:
        if hasattr(module, "weight"):
            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
        else:
            _LOGGER.warning(
                f"module type {type(module)} targeted for weight quantization but "
                "has no attribute weight, skipping weight quantization "
                f"for {type(module)}"
            )
    if scheme.output_activations is not None:
        _initialize_scale_zero_point_observer(
            module, "output", scheme.output_activations
        )

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED

    # wrap forward call of module to perform quantized actions based on calltime status
    wrap_module_forward_quantized(module, scheme)


def _initialize_scale_zero_point_observer(
    module: Module, base_name: str, quantization_args: QuantizationArgs
):
    # initialize empty scale and zero point parameters for the module
    init_scale = Parameter(torch.empty(0), requires_grad=False)
    module.register_parameter(f"{base_name}_scale", init_scale)

    init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False)
    module.register_parameter(f"{base_name}_zero_point", init_zero_point)

    # initialize observer module and attach as submodule
    observer = quantization_args.get_observer()
    module.register_module(f"{base_name}_observer", observer)
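
Putting the lifecycle together, a hedged end-to-end sketch (the QuantizationScheme and QuantizationArgs constructor arguments below are assumptions; their definitions live in quant_args.py and quant_scheme.py, which this commit does not touch):

import torch
from sparsetensors.quantization.lifecycle import (
    freeze_module_quantization,
    initialize_module_for_quantization,
    set_module_for_calibration,
)
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Linear

module = Linear(8, 8)
scheme = QuantizationScheme(
    targets=["Linear"],                    # assumed field name
    weights=QuantizationArgs(),            # assumed defaults
    input_activations=QuantizationArgs(),
)

initialize_module_for_quantization(module, scheme)  # status: INITIALIZED
set_module_for_calibration(module)                  # status: CALIBRATION
module(torch.randn(1, 8))                           # observers update scale/zero_point
freeze_module_quantization(module)                  # status: FROZEN, observers deleted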
26 changes: 26 additions & 0 deletions src/sparsetensors/quantization/lifecycle/status.py
@@ -0,0 +1,26 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


__all__ = [
    "QuantizationStatus",
]


class QuantizationStatus(Enum):
    INITIALIZED = "INITIALIZED"
    CALIBRATION = "CALIBRATION"
    FROZEN = "FROZEN"
19 changes: 19 additions & 0 deletions src/sparsetensors/quantization/observers/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import *
from .memoryless import *
from .min_max import *
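
The observer implementations (base.py, memoryless.py, min_max.py) are among the changed files not rendered in this excerpt. Their contract, as used by _maybe_calibrate_or_quantize in forward.py, is: call the observer on a tensor and receive an updated (scale, zero_point) pair. A hypothetical toy observer honoring that contract (the name and internals are assumptions, not the commit's code):

import torch
from torch.nn import Module

class ToyMinMaxObserver(Module):
    # hypothetical sketch: derives asymmetric qparams from the observed min/max
    def __init__(self, num_bits: int = 8):
        super().__init__()
        self.num_bits = num_bits

    def forward(self, x: torch.Tensor):
        min_val, max_val = x.min(), x.max()
        q_max = 2**self.num_bits - 1
        scale = ((max_val - min_val) / q_max).clamp(min=1e-8)
        zero_point = torch.round(-min_val / scale)
        return scale, zero_point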