Add QuantizedMatmul in QAT (#47997)
RachelXu7 authored Dec 8, 2022
1 parent 94fe929 commit 01f5210
Showing 3 changed files with 298 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -253,6 +253,8 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq)
list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul)

endif()

if(LINUX AND WITH_MKLDNN)
@@ -507,6 +509,7 @@ if(WIN32)
test_imperative_qat_channelwise
test_imperative_qat
test_imperative_qat_lsq
test_imperative_qat_matmul
test_imperative_out_scale
test_graph)
list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS})
@@ -547,6 +550,7 @@ set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_lsq PROPERTIES TIMEOUT 300)
set_tests_properties(test_imperative_qat_matmul PROPERTIES TIMEOUT 300)

if(LINUX AND WITH_MKLDNN)
set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
234 changes: 234 additions & 0 deletions python/paddle/fluid/contrib/slim/tests/test_imperative_qat_matmul.py
@@ -0,0 +1,234 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
import random
import time
import tempfile
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import (
SGDOptimizer,
AdamOptimizer,
MomentumOptimizer,
)
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.nn import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import (
QuantizedConv2D,
QuantizedMatmul,
)
from paddle.fluid.framework import _test_eager_guard
from imperative_test_utils import fix_model_dict

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class ImperativeLenet(fluid.dygraph.Layer):
def __init__(self, num_classes=10):
super().__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1,
weight_attr=conv2d_w1_attr,
bias_attr=False,
),
BatchNorm2D(6),
ReLU(),
MaxPool2D(kernel_size=2, stride=2),
Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0,
weight_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr,
),
BatchNorm2D(16),
PReLU(),
MaxPool2D(kernel_size=2, stride=2),
)
self.matmul = QuantizedMatmul()
self.fc = Sequential(
Linear(
in_features=400,
out_features=120,
weight_attr=fc_w1_attr,
bias_attr=fc_b1_attr,
),
LeakyReLU(),
Linear(
in_features=120,
out_features=84,
weight_attr=fc_w2_attr,
bias_attr=fc_b2_attr,
),
Sigmoid(),
Linear(
in_features=84,
out_features=num_classes,
weight_attr=fc_w3_attr,
bias_attr=fc_b3_attr,
),
Softmax(),
)

def forward(self, inputs):
inputs = self.features(inputs)
inputs = self.matmul(inputs, inputs, transpose_y=True)
inputs = paddle.flatten(inputs, 1)
x = self.fc(inputs)
return x


class TestImperativeQatMatmul(unittest.TestCase):
def set_vars(self):
self.weight_quantize_type = 'abs_max'
self.activation_quantize_type = 'moving_average_abs_max'
self.onnx_format = True
self.fuse_conv_bn = False

def func_qat(self):
self.set_vars()

imperative_qat = ImperativeQuantAware(
weight_quantize_type=self.weight_quantize_type,
activation_quantize_type=self.activation_quantize_type,
fuse_conv_bn=self.fuse_conv_bn,
)

seed = 100
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
paddle.disable_static()
lenet = ImperativeLenet()
lenet = fix_model_dict(lenet)
imperative_qat.quantize(lenet)

optimizer = MomentumOptimizer(
learning_rate=0.1, parameter_list=lenet.parameters(), momentum=0.9
)

train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=64, drop_last=True
)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=32)
epoch_num = 1
for epoch in range(epoch_num):
lenet.train()
for batch_id, data in enumerate(train_reader()):
x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]
).astype('float32')
y_data = (
np.array([x[1] for x in data])
.astype('int64')
.reshape(-1, 1)
)

img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = paddle.static.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = paddle.mean(loss)

avg_loss.backward()
optimizer.minimize(avg_loss)
lenet.clear_gradients()

if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".format(
epoch, batch_id, avg_loss.numpy(), acc.numpy()
)
)

lenet.eval()
eval_acc_top1_list = []
with paddle.no_grad():
for batch_id, data in enumerate(test_reader()):

x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]
).astype('float32')
y_data = (
np.array([x[1] for x in data])
.astype('int64')
.reshape(-1, 1)
)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)

out = lenet(img)
acc_top1 = paddle.static.accuracy(
input=out, label=label, k=1
)
acc_top5 = paddle.static.accuracy(
input=out, label=label, k=5
)

if batch_id % 100 == 0:
eval_acc_top1_list.append(float(acc_top1.numpy()))
_logger.info(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".format(
epoch,
batch_id,
acc_top1.numpy(),
acc_top5.numpy(),
)
)

# check eval acc
eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list)
print('eval_acc_top1', eval_acc_top1)

def test_qat(self):
self.func_qat()


if __name__ == '__main__':
unittest.main()
60 changes: 60 additions & 0 deletions python/paddle/nn/quant/quant_layers.py
@@ -39,6 +39,7 @@
'QuantStub',
'QuantizedRowParallelLinear',
'QuantizedColumnParallelLinear',
'QuantizedMatmul',
]

_logger = get_logger(
@@ -999,6 +1000,65 @@ def forward(self, input):
return output


class QuantizedMatmul(Layer):
"""
The computational logic of QuantizedMatmul is the same as that of paddle.matmul.
The only difference is that its inputs are fake-quantized before the matmul is computed.
"""

def __init__(
self,
layer=None,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_quantize_type='abs_max',
activation_quantize_type='abs_max',
weight_pre_layer=None,
act_pre_layer=None,
weight_quant_layer=None,
act_quant_layer=None,
):
super().__init__()

# For FakeQuant
if act_quant_layer is not None:
self._fake_quant_x = act_quant_layer()
self._fake_quant_y = act_quant_layer()
else:
self._fake_quant_x = _get_fake_quant_type(
activation_quantize_type,
moving_rate=moving_rate,
quant_bits=activation_bits,
quant_on_weight=False,
)
self._fake_quant_y = _get_fake_quant_type(
activation_quantize_type,
moving_rate=moving_rate,
quant_bits=activation_bits,
quant_on_weight=False,
)

self._act_preprocess_x = (
act_pre_layer() if act_pre_layer is not None else None
)
self._act_preprocess_y = (
act_pre_layer() if act_pre_layer is not None else None
)

def forward(self, x, y, transpose_x=False, transpose_y=False, name=None):
if self._act_preprocess_x is not None:
x = self._act_preprocess_x(x)
quant_x = self._fake_quant_x(x)

if self._act_preprocess_y is not None:
y = self._act_preprocess_y(y)
quant_y = self._fake_quant_y(y)

out = paddle.matmul(quant_x, quant_y, transpose_x, transpose_y, name)
return out


class MAOutputScaleLayer(Layer):
"""
Add a MovingAverageMaxScale layer behind the input layer.
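For reference, here is a minimal sketch of how the new QuantizedMatmul layer can be exercised on its own in dygraph mode, mirroring the pattern used by the test above. The tensor shapes and the standalone use outside of ImperativeQuantAware are illustrative assumptions, not part of this commit:

import paddle
from paddle.nn.quant.quant_layers import QuantizedMatmul

paddle.disable_static()

# Two activation tensors; the shapes below are only illustrative.
x = paddle.rand([4, 16, 5, 5])
y = paddle.rand([4, 16, 5, 5])

# Both inputs pass through fake-quant observers before the
# underlying paddle.matmul call.
quant_matmul = QuantizedMatmul(activation_quantize_type='moving_average_abs_max')
out = quant_matmul(x, y, transpose_y=True)
print(out.shape)  # [4, 16, 5, 5]

This matches how the LeNet test model above wires the layer: it is constructed in __init__ and then called in forward on the feature map with transpose_y=True.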
