From acdc47816806da0ce29e02fe602b2b8d27bc551e Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Wed, 7 Jul 2021 10:20:28 +0800 Subject: [PATCH] Add batch normalization folding to QAT quantizer --- .../pytorch/quantization/quantizers.py | 48 ++++++- nni/compression/pytorch/compressor.py | 123 +++++++++++++++++- 2 files changed, 157 insertions(+), 14 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index dca1ef778e4..14e2af24fe4 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer): http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf """ - def __init__(self, model, config_list, optimizer=None): + def __init__(self, model, config_list, optimizer=None, model_inputs=None): """ Parameters ---------- @@ -145,8 +145,11 @@ def __init__(self, model, config_list, optimizer=None): state where activation quantization ranges do not exclude a significant fraction of values, default value is 0 - op_types : list of string types of nn.module you want to apply quantization, eg. 'Conv2d' + - model_inputs : tuple of tensor + inputs to the model, which are used to get the graph of the module """ - super().__init__(model, config_list, optimizer) + + super().__init__(model, config_list, optimizer, model_inputs) self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() device = next(model.parameters()).device @@ -169,8 +172,9 @@ def _del_simulated_attr(self, module): """ delete redundant parameters in quantize module """ - del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \ - 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit'] + del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', + 'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', + 'activation_bit'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -344,9 +348,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ return calibration_config - def fold_bn(self, config, **kwargs): - # TODO simulate folded weight - pass + def fold_bn(self, *inputs, wrapper): + """ + Simulate batch normalization folding in the training graph. Folded weight and bias are + returned for the following operations. + + Parameters + ---------- + inputs : tuple of torch.Tensor + inputs for the module + wrapper : QuantizerModuleWrapper + the wrapper for origin module + + Returns + ------- + Tuple of torch.Tensor + """ + module = wrapper.module + bn_module = wrapper.bn_module + with torch.no_grad(): + output = module(*inputs) + _ = bn_module(output) + running_mean = bn_module.running_mean + running_var = torch.sqrt(bn_module.running_var + 1e-10) + bn_weight = bn_module.weight + bn_bias = bn_module.bias + dimensions = len(module.weight.shape) + shape = [-1] + [1] * (dimensions - 1) + new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape) + if hasattr(module, 'old_bias'): + new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight + else: + new_bias = bn_bias - running_mean / running_var * bn_weight + return new_weight, new_bias def step_with_optimizer(self): """ diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 01b8bb24e45..c93600b39db 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -4,6 +4,7 @@ import types import logging import torch +from nni.common.graph_utils import build_module_graph from . import default_layers _logger = logging.getLogger(__name__) @@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0): class QuantizerModuleWrapper(torch.nn.Module): - def __init__(self, module, module_name, module_type, config, quantizer): + def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None): """ Wrap an module to enable data parallel, forward method customization and buffer registeration. @@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer): the type of the module to compress quantizer :quantizer the quantizer used to calculate mask + bn_module : torch.nn.Module + batch norm layer corresponding to current module, used for simulating batch normalization folding """ super().__init__() # origin layer information @@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer): # config and pruner self.config = config self.quantizer = quantizer + self.bn_module = bn_module # register buffer and parameter # old_weight is used to store origin weight and weight is used to store quantized weight @@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer): delattr(self.module, 'weight') self.module.register_buffer('weight', self.module.old_weight) + # for batch normalization folding + if self.bn_module is not None: + if _check_bias(self.module): + self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias)) + init_tensor = self.module.old_bias + else: + init_tensor = torch.zeros_like(self.bn_module.weight) + delattr(self.module, 'bias') + self.module.register_buffer('bias', init_tensor) + + def forward(self, *inputs): if 'input' in self.config['quant_types']: inputs = self.quantizer.quant_grad( @@ -509,13 +524,19 @@ def forward(self, *inputs): self) if 'weight' in self.config['quant_types'] and _check_weight(self.module): + if self.bn_module is not None: + # simulate batch normalization folding + new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self) + self.module.bias = new_bias + else: + new_weight = self.module.old_weight + self.quantizer.quant_grad( - self.module.old_weight, + new_weight, QuantType.QUANT_WEIGHT, self, inputs[0]) - result = self.module(*inputs) - else: - result = self.module(*inputs) + + result = self.module(*inputs) if 'output' in self.config['quant_types']: result = self.quantizer.quant_grad( @@ -525,12 +546,35 @@ def forward(self, *inputs): return result +class QuantizerIdentityWrapper(torch.nn.Module): + def __init__(self, module, module_name): + """ + Used to wrap modules that should be treated as torch.Identity + + Parameters + ---------- + module : pytorch module + the module to be wrapped + module_name : str + the name of the module to wrapped, wrapper module shares same name + """ + super().__init__() + self.module = module + self.module_name = module_name + + def forward(self, x): + return x + + class Quantizer(Compressor): """ Base quantizer for pytorch quantizer """ - def __init__(self, model, config_list, optimizer=None): + def __init__(self, model, config_list, optimizer=None, model_inputs=None): + self.identity_wrappers = [] + self.conv_bn_patterns = {} + self.find_conv_bn_patterns(model, model_inputs) super().__init__(model, config_list, optimizer) self.quant_grad = QuantGrad.apply if self.optimizer is not None: @@ -540,6 +584,10 @@ def __init__(self, model, config_list, optimizer=None): # old_weight is registered to keep track of weight before quantization # and it is trainable, therefore, it should be added to optimizer. self.optimizer.add_param_group({"params": wrapper.module.old_weight}) + # This is for conv with bias + bn. Although this situation is relatively rare, + # we still need to deal with the old_bias when it occurs + if hasattr(wrapper.module, "old_bias"): + self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")}) def quantize_weight(self, wrapper, **kwargs): """ @@ -597,7 +645,36 @@ def _wrap_modules(self, layer, config): for quant_type in config['quant_types']: assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type - return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self) + # bound bn module to corresponding conv module + bn_module = None + if layer.name in self.conv_bn_patterns: + bn_module_name = self.conv_bn_patterns[layer.name] + for name, module in self.bound_model.named_modules(): + if name == bn_module_name: + bn_module = module + break + assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name) + self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name)) + return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module) + + def _wrap_model(self): + """ + wrap all modules that needed to be compressed + + """ + # wrap folded bn in order to bypass its forward process + for wrapper in reversed(self.identity_wrappers): + _setattr(self.bound_model, wrapper.module_name, wrapper) + super()._wrap_model() + + def _unwrap_model(self): + """ + unwrap all modules that needed to be compressed + + """ + for wrapper in self.identity_wrappers: + _setattr(self.bound_model, wrapper.module_name, wrapper.module) + super()._unwrap_model() def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, input_shape=None, device=None): @@ -660,6 +737,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ """ raise NotImplementedError('Quantizer must overload export_model()') + def find_conv_bn_patterns(self, model, model_inputs): + """ + Find all Conv-BN patterns, used for batch normalization folding + + Parameters + ---------- + model : torch.nn.Module + model to be analyzed. + model_inputs : tupel of torch.tensor + inputs to the model, used for generating the torchscript + """ + if model_inputs is None: + _logger.debug("Model inputs are not given, batch normalization folding is disabled") + return + + graph = build_module_graph(model, model_inputs) + for node_group in graph.nodes_py.nodes_op: + if node_group.op_type in BN_FOLD_OP: + successors = graph.find_successors(node_group.unique_name) + successors = [graph.name_to_node[x] for x in successors] + for successor in successors: + if successor.op_type == 'BatchNorm2d': + self.conv_bn_patterns[node_group.name] = successor.name + def step_with_optimizer(self): pass @@ -677,6 +778,8 @@ class QuantType: 2: "output" } +BN_FOLD_OP = ["Conv2d"] + class QuantGrad(torch.autograd.Function): """ Base class for overriding backward function of quantization operation. @@ -773,6 +876,12 @@ def _check_weight(module): except AttributeError: return False +def _check_bias(module): + try: + return isinstance(module.bias.data, torch.Tensor) + except AttributeError: + return False + def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): if quant_type == QuantType.QUANT_INPUT: output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)