
Commit e6d1d27: "refine"
chenbohua3 committed Jul 14, 2021 (1 parent: acdc478)

Showing 2 changed files with 32 additions and 13 deletions.
nni/algorithms/compression/pytorch/quantization/quantizers.py (23 additions, 6 deletions)

@@ -6,7 +6,7 @@
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None, model_inputs=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
"""
Parameters
----------
@@ -145,11 +145,11 @@ def __init__(self, model, config_list, optimizer=None, model_inputs=None):
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- model_inputs : tuple of tensor
- dummy_input : tuple of tensor
inputs to the model, which are used to get the graph of the module
"""

super().__init__(model, config_list, optimizer, model_inputs)
super().__init__(model, config_list, optimizer, dummy_input)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
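
For context on the renamed argument: dummy_input is what lets the quantizer trace the model and find Conv-BN pairs, so batch-normalization folding is only active when it is supplied. A minimal usage sketch, assuming a user-defined model and optimizer; the config values and input shape below are illustrative, not taken from this commit:

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d']
    }]
    # dummy_input is only used to build the module graph; without it,
    # Conv-BN folding is silently disabled (see find_conv_bn_patterns below).
    dummy_input = torch.randn(1, 3, 224, 224)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()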
@@ -174,7 +174,7 @@ def _del_simulated_attr(self, module):
"""
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
'activation_bit']
'activation_bit', 'BN_FOLD_TAG']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -338,6 +338,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config[name]['weight_bit'] = int(module.weight_bit)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)

# Recover weight/bias for batch normalization folding
if hasattr(module, BN_FOLD_TAG):
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer {}. "
"This may lead to a wrong accuracy performance on the backend.".format(name))
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)

actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)

if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
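
The recovery block above is needed because, while folding is simulated during training, the wrapper removes the module's real weight/bias parameters and registers plain buffers in their place (see the compressor.py change below) so folded values can be overwritten every step; at export time the original nn.Parameter objects must be re-registered so the saved model carries true parameters again. A generic PyTorch sketch of that buffer-to-parameter swap, not the repository's exact code:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3)
    original_weight = conv.weight                      # keep a handle on the real parameter

    # During compression: replace the parameter with a writable buffer.
    delattr(conv, 'weight')
    conv.register_buffer('weight', torch.zeros_like(original_weight))

    # At export: drop the buffer and restore the true parameter.
    delattr(conv, 'weight')
    conv.register_parameter('weight', original_weight)
    assert isinstance(conv.weight, nn.Parameter)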
@@ -370,7 +387,7 @@ def fold_bn(self, *inputs, wrapper):
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
running_var = torch.sqrt(bn_module.running_var + 1e-10)
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
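
Using bn_module.eps instead of a hard-coded 1e-10 matters because an inference backend that merges BN into the preceding convolution does so with the layer's own epsilon, so the simulated folding should match it exactly. The underlying folding arithmetic, as a self-contained sketch rather than the repository's fold_bn itself:

    import torch

    def fold_conv_bn_params(conv_weight, conv_bias, bn):
        # Weight/bias equivalent to Conv2d followed by this BatchNorm2d in eval mode.
        if conv_bias is None:
            conv_bias = torch.zeros_like(bn.running_mean)
        std = torch.sqrt(bn.running_var + bn.eps)          # the layer's own eps, not a magic constant
        scale = bn.weight / std                            # per-output-channel scale
        folded_weight = conv_weight * scale.reshape(-1, 1, 1, 1)
        folded_bias = bn.bias + (conv_bias - bn.running_mean) * scale
        return folded_weight, folded_bias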
nni/compression/pytorch/compressor.py (9 additions, 7 deletions)

@@ -514,7 +514,7 @@ def __init__(self, module, module_name, module_type, config, quantizer, bn_module
init_tensor = torch.zeros_like(self.bn_module.weight)
delattr(self.module, 'bias')
self.module.register_buffer('bias', init_tensor)

setattr(module, BN_FOLD_TAG, True)

def forward(self, *inputs):
if 'input' in self.config['quant_types']:
@@ -528,6 +528,7 @@ def forward(self, *inputs):
# simulate batch normalization folding
new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
self.module.bias = new_bias
self.module.weight = new_weight
else:
new_weight = self.module.old_weight

@@ -571,10 +572,10 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""

def __init__(self, model, config_list, optimizer=None, model_inputs=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, model_inputs)
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
@@ -737,22 +738,22 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
"""
raise NotImplementedError('Quantizer must overload export_model()')

def find_conv_bn_patterns(self, model, model_inputs):
def find_conv_bn_patterns(self, model, dummy_input):
"""
Find all Conv-BN patterns, used for batch normalization folding
Parameters
----------
model : torch.nn.Module
model to be analyzed.
model_inputs : tupel of torch.tensor
dummy_input : tupel of torch.tensor
inputs to the model, used for generating the torchscript
"""
if model_inputs is None:
if dummy_input is None:
_logger.debug("Model inputs are not given, batch normalization folding is disabled")
return

graph = build_module_graph(model, model_inputs)
graph = build_module_graph(model, dummy_input)
for node_group in graph.nodes_py.nodes_op:
if node_group.op_type in BN_FOLD_OP:
successors = graph.find_successors(node_group.unique_name)
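
The loop above walks the traced graph and, for every op whose type is listed in BN_FOLD_OP, inspects its successors to find the batch-normalization layer that consumes its output; conceptually, self.conv_bn_patterns ends up mapping each convolution to its following BN layer. A hypothetical example of its contents, with layer names invented for illustration:

    conv_bn_patterns = {
        'features.conv1': 'features.bn1',                     # conv layer name -> following BN layer name
        'features.layer1.0.conv2': 'features.layer1.0.bn2',
    }
    # When a conv layer is wrapped for quantization, its BN partner is looked up here
    # so the wrapper can simulate folding that BN's statistics into the conv weight.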
@@ -779,6 +780,7 @@ class QuantType:
}

BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'

class QuantGrad(torch.autograd.Function):
"""
