support quantization of conv2d_transpose (#34547)
XGZhang11 authored Aug 18, 2021
1 parent 4d88cdb commit 8967a66
Showing 6 changed files with 225 additions and 32 deletions.
101 changes: 76 additions & 25 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -42,17 +42,18 @@ class ImperativeQuantAware(object):
Applying quantization aware training (QAT) to the dygraph model.
"""

def __init__(self,
quantizable_layer_type=['Conv2D', 'Linear'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
def __init__(
self,
quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
"""
The constructor for ImperativeQuantAware.
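As a quick orientation for the new default, a hedged sketch of constructing the trainer after this commit; the Sequential toy model below is an illustrative assumption, not part of the change:

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

# 'Conv2DTranspose' is now part of the default quantizable_layer_type,
# so transposed convolutions are quantized without extra configuration.
imperative_qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')

model = paddle.nn.Sequential(paddle.nn.Conv2DTranspose(4, 6, (3, 3)))
imperative_qat.quantize(model)  # rewrites the matching layers in place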
@@ -212,9 +213,44 @@ def quantize(self, model):
the out_scale value of outputs would be calculated.
Args:
model(fluid.dygraph.Layer): the model to be quantized.
model(paddle.nn.Layer): the model to be quantized.
Returns:
None
Examples:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
def __init__(self):
super(ImperativeModel, self).__init__()
# self.linear_0 would skip the quantization.
self.linear_0 = paddle.nn.Linear(784, 400)
self.linear_0.skip_quant = True
# self.linear_1 would not skip the quantization.
self.linear_1 = paddle.nn.Linear(400, 10)
self.linear_1.skip_quant = False
def forward(self, inputs):
x = self.linear_0(inputs)
x = self.linear_1(x)
return x
model = ImperativeModel()
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
# Add the fake quant logic.
# The original model will be rewritten.
#
# Only one layer (self.linear_1) will have the fake
# quant logic added.
imperative_qat.quantize(model)
"""
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."
@@ -232,17 +268,18 @@ class ImperativeQuantizeInputs(object):
logic both for activation inputs and weight inputs.
"""

def __init__(self,
quantizable_layer_type=['Conv2D', 'Linear'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
def __init__(
self,
quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
act_quantize_layer=None):
"""
The constructor for ImperativeQuantizeInputs.
@@ -303,6 +340,18 @@ def __init__(self,
}

def apply(self, model):
"""
Quantize the weights and activations of the specified
layers.
Args:
model(paddle.nn.Layer): The target model whose input
quantization scales would be calculated.
Returns:
None
"""

assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."

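Since apply() is the entry point when using ImperativeQuantizeInputs on its own, a minimal standalone sketch follows; it imports the class from its defining module, and the toy model is an assumption for illustration:

import paddle
from paddle.fluid.contrib.slim.quantization.imperative.qat \
    import ImperativeQuantizeInputs

model = paddle.nn.Sequential(paddle.nn.Conv2DTranspose(4, 6, (3, 3)))
quantizer = ImperativeQuantizeInputs(
    quantizable_layer_type=['Conv2DTranspose'])
quantizer.apply(model)  # wraps the layer so its inputs are fake-quantized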
@@ -354,7 +403,7 @@ def apply(self, model):
output scales for specific layers in the dygraph model.
Args:
model(fluid.dygraph.Layer): The target model which would be calculate the output quantization scale.
model(paddle.nn.Layer): The target model whose output quantization scales would be calculated.
Returns:
@@ -544,7 +593,9 @@ def _is_skip_quant_op(self, block, in_op):
1. the type of input op should be conv2d, depthwise_conv2d, matmul or conv2d_transpose
2. the previous ops of the input op are not fake_quantize_dequantize ops
"""
target_op_types = ["conv2d", "depthwise_conv2d", "matmul"]
target_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose"
]
if in_op.type not in target_op_types:
return False

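The two conditions in _is_skip_quant_op can be restated compactly. The sketch below assumes a hypothetical helper previous_ops(block, in_op) that yields the ops feeding in_op; the real implementation walks the fluid block to find them:

TARGET_OP_TYPES = ["conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose"]

def is_skip_quant_op(block, in_op):
    # Condition 1: the op itself must be a quantizable compute op.
    if in_op.type not in TARGET_OP_TYPES:
        return False
    # Condition 2: none of its producers may already be
    # fake_quantize_dequantize ops.
    return all('fake_quantize_dequantize' not in op.type
               for op in previous_ops(block, in_op))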
19 changes: 14 additions & 5 deletions python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
@@ -24,6 +24,7 @@
from ..quantization_pass import _get_input_name_index

layer_name_map = {
'Conv2DTranspose': paddle.nn.Conv2DTranspose,
'Conv2D': paddle.nn.Conv2D,
'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
@@ -46,8 +47,9 @@
}

# Apply fake quant for the inputs of these layers
# TODO (jc): support paddle.nn.Conv2DTranspose
fake_quant_input_layers = [paddle.nn.Conv2D, paddle.nn.Linear]
fake_quant_input_layers = [
paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose
]

# Apply fake quant for the output of these layers
# TODO(jc): fix the problem of adding duplicate fake_quant ops
@@ -65,7 +67,8 @@
]

fake_quant_wrap_layers = [
quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear
quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear,
quant_layers.QuantizedConv2DTranspose
]
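Together, fake_quant_input_layers and fake_quant_wrap_layers define which layer types get their inputs fake-quantized and which wrapper class does it. A hedged sketch of the swap they drive, with an assumed mapping dict (the real pass lives in ImperativeQuantizeInputs.apply()):

import paddle
import paddle.nn.quant.quant_layers as quant_layers

# Illustrative mapping from raw layer type to its quantized wrapper.
_WRAP = {
    paddle.nn.Conv2D: quant_layers.QuantizedConv2D,
    paddle.nn.Linear: quant_layers.QuantizedLinear,
    paddle.nn.Conv2DTranspose: quant_layers.QuantizedConv2DTranspose,
}

def wrap_inputs(model):
    for name, layer in model.named_sublayers():
        if type(layer) in _WRAP:
            parent, sub_name = find_parent_layer_and_sub_name(model, name)
            setattr(parent, sub_name, _WRAP[type(layer)](layer))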

# The weight format of these layers is Cin * Cout * H * W
@@ -84,9 +87,9 @@


def load_variable_data(scope, var_name):
'''
"""
Load variable value from scope
'''
"""
var_node = scope.find_var(var_name)
assert var_node is not None, \
"Can not find " + var_name + " in the scope."
@@ -120,6 +123,12 @@ def find_parent_layer_and_sub_name(model, name):
the sub_name of the layer.
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
'block_1/convbn_1' and the sub_name is `conv_1`.
Args:
model(paddle.nn.Layer): the model to be quantized.
name(string): the name of a layer.
Returns:
parent_layer, sub_name
"""
assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
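For intuition, a minimal sketch of the traversal this helper performs, assuming the separator between sub-layer names is known ('.' for names produced by named_sublayers()); the in-tree implementation remains authoritative:

def _find_parent_and_sub_name(model, name, sep='.'):
    # Walk down every component except the last; the last one is
    # the sub_name held by the parent layer.
    parts = name.split(sep)
    parent = model
    for part in parts[:-1]:
        parent = getattr(parent, part)
    return parent, parts[-1]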
10 changes: 8 additions & 2 deletions python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
@@ -28,10 +28,10 @@
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.dygraph.container import Sequential
from paddle.nn import Linear, Conv2D, Softmax
from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import QuantizedConv2D
from paddle.nn.quant.quant_layers import QuantizedConv2D, QuantizedConv2DTranspose

from imperative_test_utils import fix_model_dict, ImperativeLenet

@@ -75,6 +75,12 @@ def test_qat(self):
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
quant_conv1(fluid.dygraph.to_variable(data))

conv_transpose = Conv2DTranspose(4, 6, (3, 3))
quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
x_var = paddle.uniform(
(2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0)
quant_conv_transpose(x_var)

seed = 1
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
@@ -28,6 +28,7 @@
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import Pool2D
from paddle.fluid.dygraph import Linear
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
from paddle.fluid.log_helper import get_logger

os.environ["CPU_NUM"] = "1"
@@ -100,6 +101,19 @@ def dequantize(x, lower_bound, delta, interval):
return x


class ModelForConv2dT(nn.Layer):
def __init__(self, num_classes=10):
super(ModelForConv2dT, self).__init__()
self.features = nn.Conv2DTranspose(4, 6, (3, 3))
self.fc = Linear(input_dim=600, output_dim=num_classes)

def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x


class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
@@ -168,6 +182,11 @@ def test_quant_aware_training(self):
imperative_qat.quantize(lenet)
adam = Adam(learning_rate=0.001, parameters=lenet.parameters())
dynamic_loss_rec = []
# for CI coverage
conv_transpose = ModelForConv2dT()
imperative_qat.quantize(conv_transpose)
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv_transpose(x_var)

def train(model):
adam = Adam(learning_rate=0.001, parameters=model.parameters())
107 changes: 107 additions & 0 deletions python/paddle/nn/quant/quant_layers.py
@@ -31,6 +31,7 @@
'FakeQuantMovingAverageAbsMax',
'FakeQuantChannelWiseAbsMax',
'QuantizedConv2D',
'QuantizedConv2DTranspose',
'QuantizedLinear',
'MovingAverageAbsMaxScale',
'MAOutputScaleLayer',
@@ -481,6 +482,112 @@ def forward(self, input):
data_format=self._data_format)


class QuantizedConv2DTranspose(layers.Layer):
"""
The computational logic of QuantizedConv2DTranspose is the same as that of Conv2DTranspose.
The only difference is that its inputs are all fake-quantized.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = nn.Conv2DTranspose(4, 6, (3, 3))
conv_quantized = QuantizedConv2DTranspose(conv)
y_quantized = conv_quantized(x_var)
y_var = conv(x_var)
y_quantized_np = y_quantized.numpy()
y_np = y_var.numpy()
print(y_np.shape, y_quantized_np.shape)
# (2, 6, 10, 10), (2, 6, 10, 10)
"""

def __init__(self,
layer,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_quantize_type='abs_max',
activation_quantize_type='abs_max',
weight_pre_layer=None,
act_pre_layer=None,
weight_quant_layer=None,
act_quant_layer=None):
r"""
Constructor.
The arguments are the same as ImperativeQuantAware.
"""
super(QuantizedConv2DTranspose, self).__init__()
# For Conv2DTranspose
self._groups = getattr(layer, '_groups')
self._stride = getattr(layer, '_stride')
self._padding = getattr(layer, '_padding')
self._output_padding = getattr(layer, 'output_padding')
self._dilation = getattr(layer, '_dilation')
self._data_format = getattr(layer, '_data_format')
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
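# Transposed-conv weights are laid out Cin * Cout * H * W (see the note
# in imperative/utils.py), so axis 1 is the output-channel axis used for
# channel-wise quantization.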
self._conv2d_transpose_quant_axis = 1
if weight_quant_layer is not None:
self._fake_quant_weight = weight_quant_layer()
else:
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True,
channel_num=self.weight.shape[
self._conv2d_transpose_quant_axis],
quant_axis=self._conv2d_transpose_quant_axis)
if act_quant_layer is not None:
self._fake_quant_input = act_quant_layer()
else:
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype,
quant_on_weight=False)

self._act_preprocess = act_pre_layer(
) if act_pre_layer is not None else None
self._weight_preprocess = weight_pre_layer(
) if weight_pre_layer is not None else None

def forward(self, input, output_size=None):
if self._act_preprocess is not None:
input = self._act_preprocess(input)
quant_input = self._fake_quant_input(input)

weight = self.weight
if self._weight_preprocess is not None:
weight = self._weight_preprocess(self.weight)
quant_weight = self._fake_quant_weight(weight)

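# When output_size is given, F.conv2d_transpose derives the output
# shape from it directly, so the layer's stored output_padding is
# dropped (mirroring paddle.nn.Conv2DTranspose.forward).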
if output_size is None:
output_padding = self._output_padding
else:
output_padding = 0

return F.conv2d_transpose(
quant_input,
quant_weight,
bias=self.bias,
padding=self._padding,
output_padding=output_padding,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
output_size=output_size,
data_format=self._data_format)


class QuantizedLinear(layers.Layer):
"""
The computational logic of QuantizedLinear is the same as that of Linear.
1 change: 1 addition & 0 deletions tools/sampcd_processor.py
@@ -440,6 +440,7 @@ def get_filenames(full_test=False):
'''
global whl_error
import paddle
import paddle.fluid.contrib.slim.quantization
whl_error = []
if full_test:
get_full_api_from_pr_spec()