[Cherry-Pick] Support Fake GroupWise Quant #61900

Merged 2 commits on Feb 21, 2024
Changes from all commits
39 changes: 37 additions & 2 deletions python/paddle/nn/quant/format.py
@@ -46,7 +46,14 @@ def from_quanter(quanter):


class LinearQuanter(Layer):
def __init__(
self,
scales,
zero_point=None,
quant_axis=None,
bit_length=8,
group_size=128,
):
super().__init__()
scales = paddle.to_tensor(scales, dtype="float32")
scale_attr = paddle.framework.ParamAttr(
@@ -65,9 +72,21 @@ def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
)
self._quant_axis = -1 if quant_axis is None else quant_axis
self._bit_length = bit_length
self._group_size = group_size

def forward(self, input):
if in_dynamic_mode():
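            # A scale tensor with rank > 1 signals group-wise quantization:
            # expand each group's scale over its group_size rows, then
            # fake-quantize to the signed bit_length integer range.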
if len(self._scales.shape) > 1:
bnt = (1 << (self._bit_length - 1)) - 1
new_s = paddle.repeat_interleave(
self._scales, self._group_size, 0
)
quant_weight = paddle.clip(
paddle.round(input.cast('float32') / new_s * bnt),
-bnt - 1,
bnt,
)
return quant_weight.cast(input.dtype)
return _C_ops.quantize_linear(
input.cast('float32'),
self._scales,
@@ -105,7 +124,14 @@ def from_quanter(quanter):


class LinearDequanter(Layer):
def __init__(
self,
scales,
zero_point=None,
quant_axis=None,
bit_length=8,
group_size=128,
):
super().__init__()
scales = paddle.to_tensor(scales, dtype="float32")
scale_attr = paddle.framework.ParamAttr(
@@ -124,9 +150,18 @@ def __init__(self, scales, zero_point=None, quant_axis=None, bit_length=8):
)
self._quant_axis = -1 if quant_axis is None else quant_axis
self._bit_length = bit_length
self._group_size = group_size

def forward(self, input):
if in_dynamic_mode():
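            # Inverse of the group-wise quantize path: expand the per-group
            # scales over their rows and rescale the integer-range values.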
if len(self._scales.shape) > 1:
bnt = (1 << (self._bit_length - 1)) - 1
new_s = paddle.repeat_interleave(
self._scales, self._group_size, 0
)
quant_dequant_weight = input.cast('float32') / bnt * new_s
return quant_dequant_weight.cast(input.dtype)

return _C_ops.dequantize_linear(
input.cast('float32'),
self._scales,
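For context, the two dynamic-mode branches above implement fake group-wise quantization as a round trip: per-group abs-max scales are broadcast over each group's rows, the weight is rounded and clipped to the signed `bit_length` range, and then rescaled back. A minimal standalone sketch of that round trip (the shapes and the 4-bit setting are illustrative, not taken from the diff):

import paddle

bit_length = 4
group_size = 128
bnt = (1 << (bit_length - 1)) - 1  # largest positive quantized value, 7 for 4 bits

weight = paddle.rand([256, 64], dtype="float32")

# Per-group abs-max scales, one per (group, output-channel) pair, matching
# GroupWiseWeightObserverLayer._cal_abs_max: [256, 64] -> [2, 64].
grouped = weight.transpose([1, 0]).reshape([64, 256 // group_size, group_size])
scales = paddle.max(paddle.abs(grouped), axis=2).transpose([1, 0])

# Broadcast each group's scale over its group_size rows, as LinearQuanter does.
new_s = paddle.repeat_interleave(scales, group_size, 0)  # [2, 64] -> [256, 64]
quant = paddle.clip(paddle.round(weight / new_s * bnt), -bnt - 1, bnt)

# LinearDequanter's inverse step recovers the "fake-quantized" weight.
dequant = quant / bnt * new_s
print(paddle.abs(weight - dequant).max())  # small quantization error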
3 changes: 2 additions & 1 deletion python/paddle/quantization/observers/__init__.py
@@ -14,5 +14,6 @@
# limitations under the License.

from .abs_max import AbsmaxObserver
from .groupwise import GroupWiseWeightObserver

__all__ = ["AbsmaxObserver"]
__all__ = ["AbsmaxObserver", "GroupWiseWeightObserver"]
113 changes: 113 additions & 0 deletions python/paddle/quantization/observers/groupwise.py
@@ -0,0 +1,113 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle

from ..base_observer import BaseObserver
from ..factory import ObserverFactory


class GroupWiseWeightObserver(ObserverFactory):
r"""
    It collects group-wise maximum absolute values of target weights.
    Args:
        quant_bits(int, optional): Number of bits to represent a quantized integer in binary. Default is 8.
        group_size(int, optional): Number of rows along axis 0 that share one scale; only 64 and 128 are supported. Default is 128.
    Examples:
        .. code-block:: python
            from paddle.quantization import QuantConfig
            from paddle.quantization.observers import GroupWiseWeightObserver
            observer = GroupWiseWeightObserver(quant_bits=8, group_size=128)
            q_config = QuantConfig(activation=None, weight=observer)
"""

def __init__(self, quant_bits=8, group_size=128):
        super().__init__(quant_bits=quant_bits, group_size=group_size)

def _get_class(self):
return GroupWiseWeightObserverLayer


class GroupWiseWeightObserverLayer(BaseObserver):
def __init__(self, layer, quant_bits=8, group_size=128):
super().__init__()
self.quant_bits = quant_bits
self.group_size = group_size
self._layer = layer
self._max = None
self._scale = None
self._zero_point = None

def forward(self, inputs):
self._max = self._cal_abs_max(inputs)
return inputs

def _cal_abs_max(self, inputs):
"""Use group_size to group the input, then use the
absmax method to calculate the scale
"""
input_shape = inputs.shape
        assert self.group_size in (
            64,
            128,
        ), "group_size only supports 64 or 128"
assert (
inputs.shape[0] % self.group_size == 0
), "group_size must be a factor of input channels"
assert len(inputs.shape) == 2, "Currently only support 2D tensor"
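        # [in, out] -> [out, n_groups, group_size]: group along the input axis.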
input_processed = inputs.transpose([1, 0]).reshape(
[input_shape[1], input_shape[0] // self.group_size, self.group_size]
)

abs_max_values = paddle.max(paddle.abs(input_processed), axis=2).cast(
"float32"
)
abs_max_values = paddle.where(
abs_max_values == np.float32(0), np.float32(1e-8), abs_max_values
)
abs_max_values = abs_max_values.transpose([1, 0])
return abs_max_values

def min_value(self) -> float:
return 0.0

def max_value(self) -> float:
return self._max

def bit_length(self):
        return self.quant_bits

def quant_axis(self):
return -1

def cal_thresholds(self):
"""Compute thresholds for MAX function."""
if self._scale is None:
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)

def scales(self):
"""Return output scales."""
if self._scale is None:
self.cal_thresholds()
return self._scale

def zero_points(self):
"""Return output zero points."""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
69 changes: 69 additions & 0 deletions test/quantization/test_groupwise.py
@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import paddle
from paddle.nn import Linear, Sequential
from paddle.quantization import PTQ, QuantConfig
from paddle.quantization.observers import GroupWiseWeightObserver


class LinearDygraph(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc = Sequential(
Linear(128, 128), Linear(128, 128), Linear(128, 128)
)

def forward(self, inputs):
out = self.fc(inputs)
return out


class TestPTQGroupWise(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.path = os.path.join(self.temp_dir.name, 'ptq')

def tearDown(self):
self.temp_dir.cleanup()

def _get_model_for_ptq(self):
observer = GroupWiseWeightObserver(quant_bits=4, group_size=128)
model = LinearDygraph()
model.eval()
q_config = QuantConfig(activation=None, weight=observer)
ptq = PTQ(q_config)
quant_model = ptq.quantize(model)
return quant_model, ptq

def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count

def test_quantize(self):
ptq_model, _ = self._get_model_for_ptq()
inputs = paddle.rand([128, 128], dtype="float32")
out = ptq_model(inputs)
self.assertIsNotNone(out)


if __name__ == '__main__':
unittest.main()