support fuse layers for ptq #35015
Changes from 35 commits
@@ -0,0 +1,158 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import paddle
import paddle.nn as nn
from . import utils


class Identity(nn.Layer):
    '''a layer to replace bn or relu layers'''

    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


def fuse_layers(model, layers_to_fuse, inplace=False):
    '''fuse layers in layers_to_fuse'''
    if inplace == False:
        model = copy.deepcopy(model)
    for layers in layers_to_fuse:
        _fuse_layers(model, layers)
    return model


def _fuse_layers(model, layers_list):
    '''fuse all the layers in layers_list'''
    lay_list = []
Reviewer: Rename to layer_list?
Author: done
    for layer_name in layers_list:
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, layer_name)
        lay_list.append(getattr(parent_layer, sub_name))
    new_layers = fuse_func(lay_list)
    for i, item in enumerate(layers_list):
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(model,
                                                                      item)
        setattr(parent_layer, sub_name, new_layers[i])


def fuse_func(lay_list):
Reviewer: Please keep the function naming convention consistent: if this is a module-internal function, prefix it with an underscore.
Author: done
    '''choose the fuser method and fuse layers'''
    types = tuple(type(m) for m in lay_list)
    fuser_method = layer_list_to_fuse_method.get(types, None)
Reviewer: Strictly speaking, this is not "layer list to fuse method"; it is types_to_fusion_method.
Author: changed
Reviewer: Please also reconsider whether "fuser_method" is an appropriate name.
Author: renamed to fusion_method
    new_layers = [None] * len(lay_list)
    fused = fuser_method(*lay_list)
Reviewer: Rename fused to fused_layer?
Author: done
    for handle_id, pre_hook_fn in lay_list[0]._forward_pre_hooks.items():
        fused.register_forward_pre_hook(pre_hook_fn)
        del lay_list[0]._forward_pre_hooks[handle_id]
    for handle_id, hook_fn in lay_list[-1]._forward_post_hooks.items():
        fused.register_forward_post_hook(hook_fn)
        del lay_list[-1]._forward_post_hooks[handle_id]
    new_layers[0] = fused
    for i in range(1, len(lay_list)):
        identity = Identity()
        identity.training = lay_list[0].training
        new_layers[i] = identity
Reviewer: Why add an Identity layer here? Couldn't the bn layer simply be removed?
Author: Keeping an Identity makes it easy to move bn's post hooks onto it.
    return new_layers


def fuse_conv_bn(conv, bn):
    '''fuse conv and bn for train or eval'''
    assert (conv.training == bn.training), \
        "Conv and BN both must be in the same mode (train or eval)."
    if conv.training:
        assert bn._num_features == conv._out_channels, \
            'Output channel of Conv2d must match num_features of BatchNorm2d'
        raise NotImplementedError
    else:
        return fuse_conv_bn_eval(conv, bn)


def fuse_conv_bn_eval(conv, bn):
    '''fuse conv and bn for eval'''
    assert (not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_weight, fused_bias = fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_conv.weight.set_value(fused_weight)
    if fused_conv.bias is None:
        fused_conv.bias = paddle.create_parameter(
            shape=[fused_conv._out_channels], is_bias=True, dtype='float32')
Reviewer: Set dtype to bn.bias.dtype?
Author: done
    fused_conv.bias.set_value(fused_bias)
    return fused_conv


def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    '''fuse weights and bias of conv and bn'''
    if conv_b is None:
        conv_b = paddle.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = paddle.ones_like(bn_rm)
    if bn_b is None:
        bn_b = paddle.zeros_like(bn_rm)
    bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * \
        (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return conv_w, conv_b
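For reference, this is the standard conv-bn folding identity: per output channel k, W'_k = W_k * gamma_k / sqrt(var_k + eps) and b'_k = (b_k - mean_k) * gamma_k / sqrt(var_k + eps) + beta_k. A minimal numerical sanity check, not part of the diff, assuming the functions above are importable in a working Paddle environment:

# Sanity-check sketch: in eval mode, the fused conv should match conv
# followed by bn, up to floating-point error.
import paddle
import paddle.nn as nn

conv = nn.Conv2D(3, 8, kernel_size=3)
bn = nn.BatchNorm2D(8)
conv.eval()
bn.eval()

fused_conv = fuse_conv_bn_eval(conv, bn)  # defined in the diff above

x = paddle.randn([1, 3, 16, 16])
max_err = float((bn(conv(x)) - fused_conv(x)).abs().max())
print(max_err)  # expected to be on the order of 1e-6 or smaller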
def fuse_linear_bn(linear, bn):
    '''fuse linear and bn'''
    assert (linear.training == bn.training), \
        "Linear and BN both must be in the same mode (train or eval)."
    if linear.training:
        assert bn._num_features == linear.weight.shape[1], \
            'Output channel of Linear must match num_features of BatchNorm'
        raise NotImplementedError
    else:
        return fuse_linear_bn_eval(linear, bn)


def fuse_linear_bn_eval(linear, bn):
    '''fuse linear and bn for eval'''
    assert (not (linear.training or bn.training)), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    fused_weight, fused_bias = fuse_linear_bn_weights(
        fused_linear.weight, fused_linear.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_linear.weight.set_value(fused_weight)
    if fused_linear.bias is None:
        fused_linear.bias = paddle.create_parameter(
            shape=[fused_linear.weight.shape[1]], is_bias=True,
            dtype='float32')
Reviewer: dtype: same as above (use bn.bias.dtype).
Author: done
    fused_linear.bias.set_value(fused_bias)
    return fused_linear


def fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w,
                           bn_b):
    '''fuse weights and bias of linear and bn'''
    if linear_b is None:
        linear_b = paddle.zeros_like(bn_rm)
    bn_scale = bn_w * paddle.rsqrt(bn_rv + bn_eps)
    fused_w = linear_w * bn_scale.unsqueeze(-1)
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return fused_w, fused_b


layer_list_to_fuse_method = {
    (nn.Conv2D, nn.BatchNorm2D): fuse_conv_bn,
    (nn.Linear, nn.BatchNorm1D): fuse_linear_bn,
}
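To illustrate how fuse_layers and the types-to-fusion-method table fit together, a hypothetical usage sketch follows; the toy network and its sublayer names are illustrative, not from the PR:

import paddle
import paddle.nn as nn


class ConvBNNet(nn.Layer):
    '''toy network used only for this illustration'''

    def __init__(self):
        super(ConvBNNet, self).__init__()
        self.conv = nn.Conv2D(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2D(8)

    def forward(self, x):
        return self.bn(self.conv(x))


net = ConvBNNet()
net.eval()  # fusion is only implemented for eval mode
# Each inner list names a chain of sublayers to fold into one layer; the
# (Conv2D, BatchNorm2D) type tuple selects fuse_conv_bn from the table above.
fused_net = fuse_layers(net, [['conv', 'bn']], inplace=False)
# fused_net.conv now holds the folded Conv2D; fused_net.bn is an Identity.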
@@ -22,6 +22,7 @@
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

from . import fuse_utils
from . import utils
from . import ptq_hooks
from . import ptq_config

@@ -55,23 +56,28 @@ def __init__(self, quant_config=ptq_config.default_ptq_config):

        self._quant_config = quant_config

    def quantize(self, model, inplace=False):  # old signature
    def quantize(self, model, inplace=False, fuse=False, fuse_list=None):  # new signature
        """
        Add quant config and hook to the target layer.

        Args:
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether apply quantization to the input model.
                Default: False.
            fuse(bool): Whether fuse layers.
Reviewer: "Whether fuse layers" should be "Whether to fuse".
Author: done
                Default: False.
            fuse_list(list): The layers to fuse.
                Default: None.
Reviewer: The docstring here is incomplete.
Author: done
        Returns:
            quantized_model(paddle.nn.Layer): The quantized model.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The model must be the instance of paddle.nn.Layer."

        if not inplace:
            model = copy.deepcopy(model)

        if fuse:
            model.eval()
            model = fuse_utils.fuse_layers(model, fuse_list)
        for name, layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(layer) \
                and utils.is_leaf_layer(layer) \
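Putting the new arguments together, a minimal call sequence might look like the following sketch; the model variable is assumed, and the quantizer classes are the ones the tests below use:

from paddle.fluid.contrib.slim.quantization import \
    ImperativePTQ, PTQConfig, AbsmaxQuantizer

ptq = ImperativePTQ(PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()))
# Fold each ['conv', 'bn'] pair before the PTQ hooks are registered, so the
# quantizer observes the fused layers rather than the original conv + bn.
quant_model = ptq.quantize(model, fuse=True,
                           fuse_list=[['features.0', 'features.1']])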
@@ -23,18 +23,45 @@
import copy
import logging

import paddle.nn as nn
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import *
from paddle.fluid.log_helper import get_logger
from paddle.dataset.common import download

from imperative_test_utils import fix_model_dict, ImperativeLenet  # old
from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn  # new
from imperative_test_utils import ImperativeLinearBn_hook  # new

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class TestFuseLinearBn(unittest.TestCase):
    """
Reviewer: Please add a docstring.
Author: done
    """

    def test_fuse(self):
        model = ImperativeLinearBn()
        model_h = ImperativeLinearBn_hook()
        inputs = paddle.randn((3, 10), dtype="float32")
        config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
        ptq = ImperativePTQ(config)
        f_l = [['linear', 'bn']]
        quant_model = ptq.quantize(model, fuse=True, fuse_list=f_l)
        quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l)
        for name, layer in quant_model.named_sublayers():
            print(name, layer)
Reviewer: The test needs to check that quant_model matches expectations after fusing, e.g. check the layer types, instead of only printing them.
Author: Added an assert that checks whether any bn layers remain.
        out = model(inputs)
        out_h = model_h(inputs)
        out_quant = quant_model(inputs)
        out_quant_h = quant_h(inputs)
        cos_sim_func = nn.CosineSimilarity(axis=0)
        print('fuse linear+bn',
              cos_sim_func(out.flatten(), out_quant.flatten()))
        print(cos_sim_func(out_h.flatten(), out_quant_h.flatten()))
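The assert mentioned in the reply above is not part of this 35-commit snapshot; a sketch of what such a check could look like (an assumption, not the PR's exact code):

import paddle.nn as nn

# After fusion, no BatchNorm sublayer should survive in the quantized model.
for name, layer in quant_model.named_sublayers():
    assert not isinstance(layer, (nn.BatchNorm1D, nn.BatchNorm2D)), \
        "sublayer %s should have been fused away" % name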
class TestImperativePTQ(unittest.TestCase):
    """
    """
@@ -177,7 +204,6 @@ def test_ptq(self):
        model = ImperativeLenet()
        model_state_dict = paddle.load(params_path)
        model.set_state_dict(model_state_dict)
        # Quantize, calibrate and save
        quant_model = self.ptq.quantize(model)
        before_acc_top1 = self.model_test(quant_model, self.batch_num,
@@ -216,6 +242,61 @@ def test_ptq(self):
        print("total time: %ss \n" % (end_time - start_time))


class TestImperativePTQfuse(TestImperativePTQ):
    def test_ptq(self):
        start_time = time.time()

        self.set_vars()

        # Load model
        params_path = self.download_model(self.lenet_url, self.lenet_md5,
                                          "lenet")
        params_path += "/lenet_pretrained/lenet.pdparams"

        model = ImperativeLenet()
        model_state_dict = paddle.load(params_path)
        model.set_state_dict(model_state_dict)
        # Quantize, calibrate and save
        f_l = [['features.0', 'features.1'], ['features.4', 'features.5']]
        quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l)
        for name, layer in quant_model.named_sublayers():
            print(name, layer)
Reviewer: Same as above: check the fused layer types rather than only printing.
Author: done
        before_acc_top1 = self.model_test(quant_model, self.batch_num,
                                          self.batch_size)

        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 1, 28, 28], dtype='float32')
        ]
        self.ptq.save_quantized_model(
            model=quant_model, path=self.save_path, input_spec=input_spec)
        print('Quantized model saved in {%s}' % self.save_path)

        after_acc_top1 = self.model_test(quant_model, self.batch_num,
                                         self.batch_size)

        paddle.enable_static()
        infer_acc_top1 = self.program_test(self.save_path, self.batch_num,
                                           self.batch_size)
        paddle.disable_static()

        # Check
        print('Before converted acc_top1: %s' % before_acc_top1)
        print('After converted acc_top1: %s' % after_acc_top1)
        print('Infer acc_top1: %s' % infer_acc_top1)

        self.assertTrue(
            after_acc_top1 >= self.eval_acc_top1,
            msg="The test acc {%f} is less than {%f}." %
            (after_acc_top1, self.eval_acc_top1))
        self.assertTrue(
            infer_acc_top1 >= after_acc_top1,
            msg='The acc is lower after converting model.')
Reviewer: Please add a comment explaining why after_acc_top1, eval_acc_top1, and infer_acc_top1 are expected to satisfy these relations, and how they relate to before_acc_top1.
Author: done. "before" and "after" are the accuracies before and after converting the model; eval_acc_top1 (0.95) is the threshold used to check that the accuracy is correct; "infer" is the accuracy of the saved model after loading it back.
        end_time = time.time()
        print("total time: %ss \n" % (end_time - start_time))


class TestImperativePTQHist(TestImperativePTQ):
    def set_vars(self):
        config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
Reviewer: Please add parameter descriptions.
Author: done