diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py
index 14e2af24fe4..b0dc88c7eca 100644
--- a/nni/algorithms/compression/pytorch/quantization/quantizers.py
+++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -6,7 +6,7 @@
 import torch
 from schema import Schema, And, Or, Optional
 from nni.compression.pytorch.utils.config_validation import QuantizerSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
 
 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
 
@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
 
-    def __init__(self, model, config_list, optimizer=None, model_inputs=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         """
         Parameters
         ----------
@@ -145,11 +145,11 @@ def __init__(self, model, config_list, optimizer=None, model_inputs=None):
                 state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
             - op_types : list of string
                 types of nn.module you want to apply quantization, eg. 'Conv2d'
-        model_inputs : tuple of tensor
+        dummy_input : tuple of tensor
             inputs to the model, which are used to get the graph of the module
         """
-        super().__init__(model, config_list, optimizer, model_inputs)
+        super().__init__(model, config_list, optimizer, dummy_input)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device
@@ -174,7 +174,7 @@ def _del_simulated_attr(self, module):
         """
         del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
                          'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
-                         'activation_bit']
+                         'activation_bit', 'BN_FOLD_TAG']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
@@ -338,6 +338,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
                 calibration_config[name]['weight_bit'] = int(module.weight_bit)
                 calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                 calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
+
+                # Recover weight/bias for batch normalization folding
+                if hasattr(module, BN_FOLD_TAG):
+                    actual_weight = getattr(module, 'old_weight', None)
+                    if actual_weight is None:
+                        logger.warning("Cannot recover weight for layer {}. "
+                                       "This may lead to incorrect accuracy on the backend.".format(name))
+                    delattr(module, 'weight')
+                    module.register_parameter('weight', actual_weight)
+
+                    actual_bias = getattr(module, 'old_bias', None)
+                    delattr(module, 'bias')
+                    if actual_bias is not None:
+                        module.register_parameter('bias', actual_bias)
+                    else:
+                        setattr(module, 'bias', None)
+
             if hasattr(module, 'activation_bit'):
                 calibration_config[name]['activation_bit'] = int(module.activation_bit)
                 calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
@@ -370,7 +387,7 @@ def fold_bn(self, *inputs, wrapper):
         output = module(*inputs)
         _ = bn_module(output)
         running_mean = bn_module.running_mean
-        running_var = torch.sqrt(bn_module.running_var + 1e-10)
+        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
         bn_weight = bn_module.weight
         bn_bias = bn_module.bias
         dimensions = len(module.weight.shape)
diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py
index c93600b39db..4dcde454224 100644
--- a/nni/compression/pytorch/compressor.py
+++ b/nni/compression/pytorch/compressor.py
@@ -514,7 +514,7 @@ def __init__(self, module, module_name, module_type, config, quantizer, bn_modul
             init_tensor = torch.zeros_like(self.bn_module.weight)
             delattr(self.module, 'bias')
             self.module.register_buffer('bias', init_tensor)
-
+        setattr(module, BN_FOLD_TAG, True)
 
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
@@ -528,6 +528,7 @@ def forward(self, *inputs):
                 # simulate batch normalization folding
                 new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
                 self.module.bias = new_bias
+                self.module.weight = new_weight
             else:
                 new_weight = self.module.old_weight
@@ -571,10 +572,10 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """
 
-    def __init__(self, model, config_list, optimizer=None, model_inputs=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
-        self.find_conv_bn_patterns(model, model_inputs)
+        self.find_conv_bn_patterns(model, dummy_input)
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QuantGrad.apply
         if self.optimizer is not None:
@@ -737,7 +738,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
         """
         raise NotImplementedError('Quantizer must overload export_model()')
 
-    def find_conv_bn_patterns(self, model, model_inputs):
+    def find_conv_bn_patterns(self, model, dummy_input):
         """
         Find all Conv-BN patterns, used for batch normalization folding
 
@@ -745,14 +746,14 @@
         Parameters
         ----------
         model : torch.nn.Module
             model to be analyzed.
-        model_inputs : tupel of torch.tensor
+        dummy_input : tuple of torch.tensor
             inputs to the model, used for generating the torchscript
         """
-        if model_inputs is None:
+        if dummy_input is None:
             _logger.debug("Model inputs are not given, batch normalization folding is disabled")
             return
-        graph = build_module_graph(model, model_inputs)
+        graph = build_module_graph(model, dummy_input)
         for node_group in graph.nodes_py.nodes_op:
             if node_group.op_type in BN_FOLD_OP:
                 successors = graph.find_successors(node_group.unique_name)
@@ -779,6 +780,7 @@ class QuantType:
 }
 
 BN_FOLD_OP = ["Conv2d"]
+BN_FOLD_TAG = 'BN_FOLD_TAG'
 
 class QuantGrad(torch.autograd.Function):
     """