Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add batch normalization folding to QAT quantizer #3911

Merged
merged 9 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions docs/en_US/Compression/Quantizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,25 @@ configuration needed by this algorithm :
disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0

note
^^^^
Batch normalization folding
^^^^^^^^^^^^^^^^^^^^^^^^^^^

batch normalization folding is currently not supported.
Batch normalization folding is supported in QAT quantizer. It can be easily enabled by passing an argument `dummy_input` to
the quantizer, like:

.. code-block:: python

# assume your model takes an input of shape (1, 1, 28, 28)
# and dummy_input must be on the same device as the model
dummy_input = torch.randn(1, 1, 28, 28)

# pass the dummy_input to the quantizer
quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)


The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training
graph. Note that when the quantization aware training process is finished, the folded weight/bias would be restored after calling
`quantizer.export_model`.

----

Expand Down
69 changes: 61 additions & 8 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

Expand Down Expand Up @@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
"""
Parameters
----------
Expand All @@ -145,8 +145,13 @@ def __init__(self, model, config_list, optimizer=None):
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
- dummy_input : tuple of tensor
inputs to the model, which are used to get the graph of the module. The graph is used to find
Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not
given, the batch normalization folding would be disabled.
"""
super().__init__(model, config_list, optimizer)

super().__init__(model, config_list, optimizer, dummy_input)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
Expand All @@ -169,8 +174,9 @@ def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
'activation_bit', 'BN_FOLD_TAG']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
Expand Down Expand Up @@ -334,6 +340,23 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config[name]['weight_bit'] = int(module.weight_bit)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)

# Recover weight/bias for batch normalization folding
if hasattr(module, BN_FOLD_TAG):
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend.", name)
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)

actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)

if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
Expand All @@ -344,9 +367,39 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_

return calibration_config

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass
def fold_bn(self, *inputs, wrapper):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is QAT_Quantizer specific? other quantizers may have a different fold_bn function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should also work well for other quantizers. (at least for lsq quantizer I think:) ). I will make it a common utility function in the pr that enables batch normalization folding for other quantizers.

"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.

Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for origin module

Returns
-------
Tuple of torch.Tensor
"""
module = wrapper.module
bn_module = wrapper.bn_module
with torch.no_grad():
output = module(*inputs)
_ = bn_module(output)
running_mean = bn_module.running_mean
running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
bn_weight = bn_module.weight
bn_bias = bn_module.bias
dimensions = len(module.weight.shape)
shape = [-1] + [1] * (dimensions - 1)
new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
if hasattr(module, 'old_bias'):
new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
else:
new_bias = bn_bias - running_mean / running_var * bn_weight
return new_weight, new_bias

def step_with_optimizer(self):
"""
Expand Down
125 changes: 118 additions & 7 deletions nni/compression/pytorch/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import types
import logging
import torch
from nni.common.graph_utils import build_module_graph
from . import default_layers

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -463,7 +464,7 @@ def get_pruned_weights(self, dim=0):


class QuantizerModuleWrapper(torch.nn.Module):
def __init__(self, module, module_name, module_type, config, quantizer):
def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
"""
Wrap an module to enable data parallel, forward method customization and buffer registeration.

Expand All @@ -479,6 +480,8 @@ def __init__(self, module, module_name, module_type, config, quantizer):
the type of the module to compress
quantizer :quantizer
the quantizer used to calculate mask
bn_module : torch.nn.Module
batch norm layer corresponding to current module, used for simulating batch normalization folding
"""
super().__init__()
# origin layer information
Expand All @@ -488,6 +491,7 @@ def __init__(self, module, module_name, module_type, config, quantizer):
# config and pruner
self.config = config
self.quantizer = quantizer
self.bn_module = bn_module

# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
Expand All @@ -501,6 +505,17 @@ def __init__(self, module, module_name, module_type, config, quantizer):
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)

# for batch normalization folding
if self.bn_module is not None:
if _check_bias(self.module):
self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
init_tensor = self.module.old_bias
else:
init_tensor = torch.zeros_like(self.bn_module.weight)
delattr(self.module, 'bias')
self.module.register_buffer('bias', init_tensor)
setattr(module, BN_FOLD_TAG, True)

def forward(self, *inputs):
if 'input' in self.config['quant_types']:
inputs = self.quantizer.quant_grad(
Expand All @@ -509,13 +524,20 @@ def forward(self, *inputs):
self)

if 'weight' in self.config['quant_types'] and _check_weight(self.module):
if self.bn_module is not None:
# simulate batch normalization folding
new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
self.module.bias = new_bias
self.module.weight = new_weight
else:
new_weight = self.module.old_weight

self.quantizer.quant_grad(
self.module.old_weight,
new_weight,
QuantType.QUANT_WEIGHT,
self, inputs[0])
result = self.module(*inputs)
else:
result = self.module(*inputs)

result = self.module(*inputs)

if 'output' in self.config['quant_types']:
result = self.quantizer.quant_grad(
Expand All @@ -525,12 +547,35 @@ def forward(self, *inputs):
return result


class QuantizerIdentityWrapper(torch.nn.Module):
def __init__(self, module, module_name):
"""
Used to wrap modules that should be treated as torch.Identity

Parameters
----------
module : pytorch module
the module to be wrapped
module_name : str
the name of the module to wrapped, wrapper module shares same name
"""
super().__init__()
self.module = module
self.module_name = module_name

def forward(self, x):
return x


class Quantizer(Compressor):
"""
Base quantizer for pytorch quantizer
"""

def __init__(self, model, config_list, optimizer=None):
def __init__(self, model, config_list, optimizer=None, dummy_input=None):
J-shang marked this conversation as resolved.
Show resolved Hide resolved
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantGrad.apply
if self.optimizer is not None:
Expand All @@ -540,6 +585,10 @@ def __init__(self, model, config_list, optimizer=None):
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})
# This is for conv with bias + bn. Although this situation is relatively rare,
# we still need to deal with the old_bias when it occurs
if hasattr(wrapper.module, "old_bias"):
self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})

def quantize_weight(self, wrapper, **kwargs):
"""
Expand Down Expand Up @@ -597,7 +646,36 @@ def _wrap_modules(self, layer, config):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
# bound bn module to corresponding conv module
bn_module = None
if layer.name in self.conv_bn_patterns:
bn_module_name = self.conv_bn_patterns[layer.name]
for name, module in self.bound_model.named_modules():
if name == bn_module_name:
bn_module = module
break
assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)

def _wrap_model(self):
"""
wrap all modules that needed to be compressed

"""
# wrap folded bn in order to bypass its forward process
for wrapper in reversed(self.identity_wrappers):
_setattr(self.bound_model, wrapper.module_name, wrapper)
super()._wrap_model()

def _unwrap_model(self):
"""
unwrap all modules that needed to be compressed

"""
for wrapper in self.identity_wrappers:
_setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model()

def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
Expand Down Expand Up @@ -660,6 +738,30 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
"""
raise NotImplementedError('Quantizer must overload export_model()')

def find_conv_bn_patterns(self, model, dummy_input):
"""
Find all Conv-BN patterns, used for batch normalization folding

Parameters
----------
model : torch.nn.Module
model to be analyzed.
dummy_input : tupel of torch.tensor
inputs to the model, used for generating the torchscript
"""
if dummy_input is None:
_logger.debug("Model inputs are not given, batch normalization folding is disabled")
return

graph = build_module_graph(model, dummy_input)
for node_group in graph.nodes_py.nodes_op:
if node_group.op_type in BN_FOLD_OP:
successors = graph.find_successors(node_group.unique_name)
successors = [graph.name_to_node[x] for x in successors]
for successor in successors:
if successor.op_type == 'BatchNorm2d':
self.conv_bn_patterns[node_group.name] = successor.name

def step_with_optimizer(self):
pass

Expand All @@ -677,6 +779,9 @@ class QuantType:
2: "output"
}

BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'

class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
Expand Down Expand Up @@ -773,6 +878,12 @@ def _check_weight(module):
except AttributeError:
return False

def _check_bias(module):
try:
return isinstance(module.bias.data, torch.Tensor)
except AttributeError:
return False

def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
Expand Down