Add batch normalization folding to QAT quantizer (#3911)
chenbohua3 authored Jul 26, 2021
1 parent 441c5da commit 7fc5af0
Showing 3 changed files with 197 additions and 18 deletions.
21 changes: 18 additions & 3 deletions docs/en_US/Compression/Quantizer.rst
@@ -82,10 +82,25 @@ configuration needed by this algorithm :
disable quantization until the model has been run for a certain number of steps. This allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values. The default value is 0; see the example below.
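
For illustration only, a minimal ``config_list`` for the QAT quantizer might look like the following sketch
(the bit widths, the ``quant_start_step`` value and the ``op_types`` are assumptions, not recommendations):

.. code-block:: python

    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'quant_start_step': 1000,
        'op_types': ['Conv2d', 'Linear']
    }]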

note
^^^^
Batch normalization folding
^^^^^^^^^^^^^^^^^^^^^^^^^^^

batch normalization folding is currently not supported.
Batch normalization folding is supported by the QAT quantizer. It can be enabled simply by passing a ``dummy_input``
argument to the quantizer, for example:

.. code-block:: python

    # assume your model takes an input of shape (1, 1, 28, 28);
    # dummy_input must be on the same device as the model
    dummy_input = torch.randn(1, 1, 28, 28)
    # pass the dummy_input to the quantizer
    quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)

The quantizer will automatically detect Conv-BN patterns and simulate the batch normalization folding process in the
training graph. Note that when quantization-aware training is finished, the folded weight and bias are restored after
calling ``quantizer.export_model``.
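
As a short sketch (the file names are hypothetical), exporting the model after quantization-aware training could look like:

.. code-block:: python

    # restores the folded weight/bias and dumps the calibration parameters
    calibration_config = quantizer.export_model('mnist_qat.pth', 'mnist_calibration.pth')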

----

69 changes: 61 additions & 8 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -6,7 +6,7 @@
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
"""
Parameters
----------
@@ -145,8 +145,13 @@ def __init__(self, model, config_list, optimizer=None):
state where activation quantization ranges do not exclude a significant fraction of values. The default value is 0.
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- dummy_input : tuple of tensor
inputs to the model, which are used to trace the graph of the module. The graph is used to find
Conv-BN patterns so that batch normalization folding can be enabled. If dummy_input is not
given, batch normalization folding is disabled.
"""
super().__init__(model, config_list, optimizer)

super().__init__(model, config_list, optimizer, dummy_input)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
@@ -169,8 +174,9 @@ def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
'activation_bit', 'BN_FOLD_TAG']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -334,6 +340,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config[name]['weight_bit'] = int(module.weight_bit)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)

# Recover weight/bias for batch normalization folding
if hasattr(module, BN_FOLD_TAG):
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend.", name)
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)

actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)

if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
@@ -344,9 +367,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_

return calibration_config

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def fold_bn(self, *inputs, wrapper):
"""
Simulate batch normalization folding in the training graph. The folded weight and bias are
returned and used in the subsequent quantization and forward computation.

Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for the original module

Returns
-------
Tuple of torch.Tensor
the folded weight and bias
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
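# note: despite its name, 'running_var' below holds sqrt(running_var + eps), i.e. the running standard deviation used for folding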
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
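# fold BN into the conv parameters: W_fold = W * gamma / std, b_fold = beta + (b - mu) * gamma / std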
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias

def step_with_optimizer(self):
"""
125 changes: 118 additions & 7 deletions nni/compression/pytorch/compressor.py
@@ -4,6 +4,7 @@
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from . import default_layers

_logger = logging.getLogger(__name__)
@@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0):


class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
"""
Wrap a module to enable data parallelism, forward method customization and buffer registration.
@@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer):
the type of the module to compress
quantizer :quantizer
the quantizer used to calculate mask
bn_module : torch.nn.Module
the batch normalization layer corresponding to the current module, used to simulate batch normalization folding
"""
super().__init__()
# origin layer information
@@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer):
# config and pruner
self.config = config
self.quantizer = quantizer
self.bn_module = bn_module

# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
@@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)

# for batch normalization folding
if self.bn_module is not None:
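# folding always produces a bias term: reuse the original conv bias if it exists, otherwise start from zeros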
if _check_bias(self.module):
self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
init_tensor = self.module.old_bias
else:
init_tensor = torch.zeros_like(self.bn_module.weight)
delattr(self.module, 'bias')
self.module.register_buffer('bias', init_tensor)
setattr(module, BN_FOLD_TAG, True)

def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad(
@@ -509,13 +524,20 @@ def forward(self, *inputs):
self)

if 'weight' in self.config['quant_types'] and _check_weight(self.module):
if self.bn_module is not None:
# simulate batch normalization folding
new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
self.module.bias = new_bias
self.module.weight = new_weight
else:
new_weight = self.module.old_weight

self.quantizer.quant_grad(
self.module.old_weight,
new_weight,
QuantType.QUANT_WEIGHT,
self, inputs[0])
result = self.module(*inputs)
else:
result = self.module(*inputs)

result = self.module(*inputs)

if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
@@ -525,12 +547,35 @@ def forward(self, *inputs):
return result


class QuantizerIdentityWrapper(torch.nn.Module):
def __init__(self, module, module_name):
"""
Used to wrap modules that should be treated as torch.Identity
Parameters
----------
module : pytorch module
the module to be wrapped
module_name : str
the name of the module to be wrapped; the wrapper module shares the same name
"""
super().__init__()
self.module = module
self.module_name = module_name

def forward(self, x):
return x


class Quantizer(Compressor):
"""
Base quantizer for pytorch quantizer
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
@@ -540,6 +585,10 @@ def __init__(self, model, config_list, optimizer=None):
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
# This is for conv with bias + bn. Although this situation is relatively rare,
# we still need to deal with the old_bias when it occurs
if hasattr(wrapper.module, "old_bias"):
self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})

def quantize_weight(self, wrapper, **kwargs):
"""
@@ -597,7 +646,36 @@ def _wrap_modules(self, layer, config):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
# bind the BN module to its corresponding conv module
bn_module = None
if layer.name in self.conv_bn_patterns:
bn_module_name = self.conv_bn_patterns[layer.name]
for name, module in self.bound_model.named_modules():
if name == bn_module_name:
bn_module = module
break
assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)

def _wrap_model(self):
"""
wrap all modules that need to be compressed
"""
# wrap folded bn in order to bypass its forward process
for wrapper in reversed(self.identity_wrappers):
_setattr(self.bound_model, wrapper.module_name, wrapper)
super()._wrap_model()

def _unwrap_model(self):
"""
unwrap all modules that need to be compressed
"""
for wrapper in self.identity_wrappers:
_setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model()

def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
@@ -660,6 +738,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
"""
raise NotImplementedError('Quantizer must overload export_model()')

def find_conv_bn_patterns(self, model, dummy_input):
"""
Find all Conv-BN patterns, used for batch normalization folding
Parameters
----------
model : torch.nn.Module
model to be analyzed.
dummy_input : tuple of torch.Tensor
inputs to the model, used to generate the TorchScript graph
"""
if dummy_input is None:
_logger.debug("Model inputs are not given, batch normalization folding is disabled")
return

graph = build_module_graph(model, dummy_input)
for node_group in graph.nodes_py.nodes_op:
if node_group.op_type in BN_FOLD_OP:
successors = graph.find_successors(node_group.unique_name)
successors = [graph.name_to_node[x] for x in successors]
for successor in successors:
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name

def step_with_optimizer(self):
pass

Expand All @@ -677,6 +779,9 @@ class QuantType:
2: "output"
}

BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'

class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
@@ -773,6 +878,12 @@ def _check_weight(module):
except AttributeError:
return False

def _check_bias(module):
try:
return isinstance(module.bias.data, torch.Tensor)
except AttributeError:
return False

def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)