diff --git a/docs/en_US/Compression/CustomizeCompressor.rst b/docs/en_US/Compression/CustomizeCompressor.rst index 103bff818c..55efd5e88c 100644 --- a/docs/en_US/Compression/CustomizeCompressor.rst +++ b/docs/en_US/Compression/CustomizeCompressor.rst @@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw grad_output : Tensor gradient of the output of quantization operation quant_type : QuantType - the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, + the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`, you can define different behavior for different types. Returns ------- @@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw """ # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1 - if quant_type == QuantType.QUANT_OUTPUT: + if quant_type == QuantType.OUTPUT: grad_output[torch.abs(tensor) > 1] = 0 return grad_output diff --git a/examples/model_compress/quantization/QAT_torch_quantizer.py b/examples/model_compress/quantization/QAT_torch_quantizer.py index fd8e248a53..fd174588c2 100644 --- a/examples/model_compress/quantization/QAT_torch_quantizer.py +++ b/examples/model_compress/quantization/QAT_torch_quantizer.py @@ -2,11 +2,13 @@ import torch.nn.functional as F from torchvision import datasets, transforms from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer +from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype import sys sys.path.append('../models') from mnist.naive import NaiveModel + def train(model, device, train_loader, optimizer): model.train() for batch_idx, (data, target) in enumerate(train_loader): @@ -58,22 +60,32 @@ def main(): # {'quant_types': ['input'], 'op_names': ['b']} in the configure_list. configure_list = [{ - 'quant_types': ['weight', 'input'], - 'quant_bits': {'weight': 8, 'input': 8}, - 'op_names': ['conv1', 'conv2'] - }, { - 'quant_types': ['output'], - 'quant_bits': {'output': 8, }, - 'op_names': ['relu1', 'relu2'] - }, { - 'quant_types': ['output', 'weight', 'input'], - 'quant_bits': {'output': 8, 'weight': 8, 'input': 8}, - 'op_names': ['fc1'], - }, { - 'quant_types': ['output', 'weight', 'input'], - 'quant_bits': {'output': 8, 'weight': 8, 'input': 8}, - 'op_names': ['fc2'], - }] + 'quant_types': ['weight', 'input'], + 'quant_bits': {'weight': 8, 'input': 8}, + 'op_names': ['conv1', 'conv2'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8, }, + 'op_names': ['relu1', 'relu2'] + }, { + 'quant_types': ['output', 'weight', 'input'], + 'quant_bits': {'output': 8, 'weight': 8, 'input': 8}, + 'op_names': ['fc1', 'fc2'], + }] + + # you can also set the quantization dtype and scheme layer-wise through configure_list like: + # configure_list = [{ + # 'quant_types': ['weight', 'input'], + # 'quant_bits': {'weight': 8, 'input': 8}, + # 'op_names': ['conv1', 'conv2'], + # 'quant_dtype': 'int', + # 'quant_scheme': 'per_channel_symmetric' + # }] + # For now quant_dtype's options are 'int' and 'uint. And quant_scheme's options are per_tensor_affine, + # per_tensor_symmetric, per_channel_affine and per_channel_symmetric. 
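+    # Both keys also accept a dict keyed by tensor type ('weight', 'input', 'output'); a sketch
+    # of that form, based on the config schema below and not exercised in this script:
+    # 'quant_dtype': {'weight': 'int', 'input': 'uint'},
+    # 'quant_scheme': {'weight': 'per_channel_symmetric', 'input': 'per_tensor_affine'}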
+ set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int') + set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int') + set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int') model = NaiveModel().to(device) dummy_input = torch.randn(1, 1, 28, 28).to(device) @@ -98,5 +110,6 @@ def main(): calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device) print("Generated calibration config is: ", calibration_config) + if __name__ == '__main__': main() diff --git a/nni/algorithms/compression/pytorch/quantization/observers.py b/nni/algorithms/compression/pytorch/quantization/observers.py deleted file mode 100644 index 7631f46ccd..0000000000 --- a/nni/algorithms/compression/pytorch/quantization/observers.py +++ /dev/null @@ -1,3 +0,0 @@ -from torch.quantization import default_weight_observer, default_histogram_observer - -__all__ = ["default_weight_observer", "default_histogram_observer"] diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index fb92d0f64a..e273b4b196 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -6,9 +6,21 @@ import torch from schema import Schema, And, Or, Optional from nni.compression.pytorch.utils.config_validation import QuantizerSchema -from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType - -from .observers import default_weight_observer, default_histogram_observer +from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad +from nni.compression.pytorch.quantization.literal import ( + PER_CHANNEL_QUANT_SCHEME, + QuantScheme, + QuantDtype, + QuantType +) +from nni.compression.pytorch.quantization.observers import default_weight_observer, default_histogram_observer +from nni.compression.pytorch.quantization.settings import LayerQuantSetting +from nni.compression.pytorch.quantization.utils import ( + calculate_qmin_qmax, + get_bits_length, + get_min_max_value, + get_quant_shape +) __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer'] @@ -65,7 +77,7 @@ def update_ema(biased_ema, value, decay): return biased_ema -def update_quantization_param(bits, rmin, rmax): +def update_quantization_param(bits, rmin, rmax, dtype, scheme): """ calculate the `zero_point` and `scale`. @@ -77,41 +89,46 @@ def update_quantization_param(bits, rmin, rmax): min value of real value rmax : Tensor max value of real value - + dtype : QuantDtype + quantized data type + scheme : QuantScheme + quantization scheme to be used Returns ------- float, float """ + # extend the [min, max] interval to ensure that it contains 0. # Otherwise, we would not meet the requirement that 0 be an exactly # representable value. - rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device)) - rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device)) - qmin = torch.Tensor([0]).to(rmin.device) - qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device) - - # First determine the scale. - scale = (rmax - rmin) / (qmax - qmin) - - # Zero-point computation. 
- initial_zero_point = qmin - rmin / scale - - # Now we need to nudge the zero point to be an integer - if initial_zero_point < qmin: - nudged_zero_point = qmin - elif initial_zero_point > qmax: - nudged_zero_point = qmax + # I think this is for activations that need to be pad in the training. + # However this is a default behavior in PyTorch quantization observer. + # So we also make it a default behavior + rmin = torch.min(rmin, torch.zeros_like(rmin)) + rmax = torch.max(rmax, torch.zeros_like(rmax)) + zero_point = torch.zeros_like(rmin) + + # todo: there is no need to calculate qmin and qmax again + qmin, qmax = calculate_qmin_qmax(bits, dtype) + + if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]: + abs_max = torch.max(torch.abs(rmin), torch.abs(rmax)) + scale = abs_max / (float(qmax - qmin) / 2) + if dtype == QuantDtype.UINT: + zero_point_val = (qmin + qmax) // 2 + zero_point = zero_point.new_full(zero_point.size(), zero_point_val) else: - nudged_zero_point = torch.round(initial_zero_point) + scale = (rmax - rmin) / float(qmax - qmin) + zero_point = qmin - torch.round(rmin / scale) - return scale, nudged_zero_point + zero_point = torch.clamp(zero_point, qmin, qmax) + # todo: add these lines + # eps = torch.finfo(torch.float32).eps + # scale = torch.max(scale, eps) + + return scale, zero_point -def get_bits_length(config, quant_type): - if isinstance(config["quant_bits"], int): - return config["quant_bits"] - else: - return config["quant_bits"].get(quant_type) class QATGrad(QuantGrad): @staticmethod @@ -384,22 +401,49 @@ def __init__(self, model, config_list, optimizer, dummy_input=None): self.bound_model.register_buffer("steps", torch.tensor(1)) for layer, config in modules_to_compress: module = layer.module - module.register_buffer("zero_point", torch.tensor([0.0])) - module.register_buffer("scale", torch.tensor([1.0])) - module.register_buffer('ema_decay', torch.tensor([0.99])) + name = layer.name + # TODO: may relax this limitation? + assert name in self.all_shapes, "Could not found shapes for layer {}".format(name) + input_shape, output_shape = self.all_shapes[name] + layer_quant_setting = LayerQuantSetting(config) + layer_quant_setting.ema_decay = 0.99 + quant_start_step = config.get('quant_start_step', 0) + layer_quant_setting.quant_start_step = quant_start_step + # todo: support other ranks and remove this check + if isinstance(module, torch.nn.Linear): + if "input" in config.get("quant_types", []) and \ + layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME: + if len(input_shape) != 2: + logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs " + "of the layer is 2. Skip quantization of layer %s.", name) + continue + if "output" in config.get("quant_types", []) and \ + layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME: + if len(output_shape) != 2: + logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs " + "of the layer is 2. 
Skip quantization of layer %s.", name) + continue + if "weight" in config.get("quant_types", []): - weight_bits = get_bits_length(config, 'weight') - layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)])) + quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme) + module.register_buffer('weight_scale', torch.zeros(quant_shape)) + module.register_buffer('weight_zero_point', torch.zeros(quant_shape)) + if "input" in config.get("quant_types", []): - input_bits = get_bits_length(config, 'input') - layer.module.register_buffer('tracked_min_input', torch.zeros(1)) - layer.module.register_buffer('tracked_max_input', torch.zeros(1)) - layer.module.register_buffer('input_bits', torch.Tensor([int(input_bits)])) + quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme) + module.register_buffer('tracked_min_input', torch.zeros(quant_shape)) + module.register_buffer('tracked_max_input', torch.zeros(quant_shape)) + module.register_buffer('input_scale', torch.zeros(quant_shape)) + module.register_buffer('input_zero_point', torch.zeros(quant_shape)) + if "output" in config.get("quant_types", []): - output_bits = get_bits_length(config, 'output') - layer.module.register_buffer('output_bits', torch.Tensor([int(output_bits)])) - layer.module.register_buffer('tracked_min_output', torch.zeros(1)) - layer.module.register_buffer('tracked_max_output', torch.zeros(1)) + quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme) + module.register_buffer('tracked_min_output', torch.zeros(quant_shape)) + module.register_buffer('tracked_max_output', torch.zeros(quant_shape)) + module.register_buffer('output_scale', torch.zeros(quant_shape)) + module.register_buffer('output_zero_point', torch.zeros(quant_shape)) + + setattr(module, "layer_quant_setting", layer_quant_setting) self.bound_model.to(device) def _del_simulated_attr(self, module): @@ -407,8 +451,9 @@ def _del_simulated_attr(self, module): delete redundant parameters in quantize module """ del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output', - 'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits', - 'output_bits', 'BN_FOLD_TAG', 'input_bits'] + 'tracked_min_input', 'tracked_max_input', 'BN_FOLD_TAG', + 'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point', + 'output_scale', 'output_zero_point', 'layer_quant_setting'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -422,6 +467,7 @@ def validate_config(self, model, config_list): config_list : list of dict List of configurations """ + SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6'] schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]), Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({ @@ -429,41 +475,51 @@ def validate_config(self, model, config_list): Optional('weight'): And(int, lambda n: 0 < n < 32), Optional('output'): And(int, lambda n: 0 < n < 32), })), + Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({ + Optional('input'): lambda x: x in QuantScheme, + Optional('weight'): lambda x: x in QuantScheme, + Optional('output'): lambda x: x in QuantScheme + })), + Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({ + Optional('input'): lambda x: x in QuantDtype, + Optional('weight'): lambda x: x in QuantDtype, + 
Optional('output'): lambda x: x in QuantDtype + })), Optional('quant_start_step'): And(int, lambda n: n >= 0), - Optional('op_types'): [str], + Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)], Optional('op_names'): [str], Optional('exclude'): bool }], model, logger) schema.validate(config_list) - def _quantize(self, bits, op, real_val): + def _quantize(self, real_value, scale, zero_point, qmin, qmax): """ quantize real value. Parameters ---------- - bits : int - quantization bits length - op : torch.nn.Module - target module - real_val : Tensor - real value to be quantized + real_value : torch.Tensor + the real value to be quantized + scale : torch.Tensor + quantization scale + zero_point : torch.Tensor + quantization zero point + qmin : int + lower bound of the int range + qmax : int + upper bound of the int range Returns ------- Tensor """ - op.zero_point = op.zero_point.to(real_val.device) - op.scale = op.scale.to(real_val.device) - transformed_val = op.zero_point + real_val / op.scale - qmin = 0 - qmax = (1 << bits) - 1 + transformed_val = zero_point + real_value / scale clamped_val = torch.clamp(transformed_val, qmin, qmax) quantized_val = torch.round(clamped_val) return quantized_val - def _dequantize(self, op, quantized_val): + def _dequantize(self, quantized_val, scale, zero_point): """ dequantize quantized value. Because we simulate quantization in training process, all the computations still happen as float point computations, which means we @@ -471,103 +527,149 @@ def _dequantize(self, op, quantized_val): Parameters ---------- - op : torch.nn.Module - target module - quantized_val : float - quantized_val value to be dequantized + quantized_val : torch.Tensor + the quantized value to be de-quantized + scale : torch.Tensor + quantization scale + zero_point : torch.Tensor + quantization zero point Returns ------- - float + Tensor """ - real_val = op.scale * (quantized_val - op.zero_point) + real_val = scale * (quantized_val - zero_point) return real_val def quantize_weight(self, wrapper, **kwargs): - config = wrapper.config module = wrapper.module weight = module.weight - weight_bits = int(module.weight_bits) - quant_start_step = config.get('quant_start_step', 0) - assert weight_bits >= 1, "quant bits length should be at least 1" + layer_quant_setting = module.layer_quant_setting + tensor_quant_setting = layer_quant_setting.weight - if quant_start_step > int(self.bound_model.steps): - return weight + # layer-wise settings + quant_start_step = layer_quant_setting.quant_start_step + + # tensor-wise settings + dtype = tensor_quant_setting.quant_dtype + scheme = tensor_quant_setting.quant_scheme + qmin, qmax = tensor_quant_setting.get_qmin_qmax() + bits = tensor_quant_setting.bits + # In evaluation mode, we only quantize weight without updating statistics if not wrapper.training: + scale, zero_point = module.weight_scale, module.weight_zero_point + weight = self._quantize(weight, scale, zero_point, qmin, qmax) + weight = self._dequantize(weight, scale, zero_point) + module.weight = weight return weight - # quantize weight - rmin, rmax = torch.min(weight), torch.max(weight) - scale, zero_point = update_quantization_param(weight_bits, rmin, rmax) - module.scale.copy_(scale) - module.zero_point.copy_(zero_point) - weight = self._quantize(weight_bits, module, weight) - weight = self._dequantize(module, weight) + if quant_start_step > int(self.bound_model.steps): + return weight + + current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme) + scale, 
zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme) + module.weight_scale.copy_(scale) + module.weight_zero_point.copy_(zero_point) + weight = self._quantize(weight, scale, zero_point, qmin, qmax) + weight = self._dequantize(weight, scale, zero_point) + # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update + # will be lost after each forward process. However, this update takes effect on each + # replicated module during each forward process, which will make the quantized weight + # be used correctly. wrapper.module.weight = weight return weight def quantize_input(self, inputs, wrapper, **kwargs): - config = wrapper.config module = wrapper.module - input_bits = int(module.input_bits) - quant_start_step = config.get('quant_start_step', 0) - assert input_bits >= 1, "quant bits length should be at least 1" - if quant_start_step > int(self.bound_model.steps): - current_min, current_max = torch.min(inputs), torch.max(inputs) - module.tracked_min_input.copy_(current_min) - module.tracked_max_input.copy_(current_max) + layer_quant_setting = module.layer_quant_setting + tensor_quant_setting = layer_quant_setting.input + + # layer-wise settings + quant_start_step = layer_quant_setting.quant_start_step + ema_decay = layer_quant_setting.ema_decay + + # tensor-wise settings + dtype = tensor_quant_setting.quant_dtype + scheme = tensor_quant_setting.quant_scheme + qmin, qmax = tensor_quant_setting.get_qmin_qmax() + bits = tensor_quant_setting.bits + + if not wrapper.training: + scale = module.input_scale + zero_point = module.input_zero_point + inputs = self._quantize(inputs, scale, zero_point, qmin, qmax) + inputs = self._dequantize(inputs, scale, zero_point) return inputs - # we dont update output quantization parameters in evaluation stage - if wrapper.training: - current_min, current_max = torch.min(inputs), torch.max(inputs) - current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay) - current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay) + current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme) + + if int(self.bound_model.steps) == 1: module.tracked_min_input.copy_(current_min) module.tracked_max_input.copy_(current_max) + tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay) + tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay) + module.tracked_min_input.copy_(tracked_min_input) + module.tracked_max_input.copy_(tracked_max_input) + + if quant_start_step > int(self.bound_model.steps): + return inputs + scale, zero_point = update_quantization_param( - input_bits, module.tracked_min_input, module.tracked_max_input) - module.scale.copy_(scale) - module.zero_point.copy_(zero_point) + bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme) + module.input_scale.copy_(scale) + module.input_zero_point.copy_(zero_point) - inp = self._quantize(input_bits, module, inputs) - inp = self._dequantize(module, inp) - return inp + inputs = self._quantize(inputs, scale, zero_point, qmin, qmax) + inputs = self._dequantize(inputs, scale, zero_point) + return inputs def quantize_output(self, output, wrapper, **kwargs): - config = wrapper.config module = wrapper.module - output_bits = int(module.output_bits) - quant_start_step = config.get('quant_start_step', 0) - assert output_bits >= 1, "quant bits length should be at least 1" + layer_quant_setting = module.layer_quant_setting + tensor_quant_setting = 
layer_quant_setting.output - if quant_start_step > int(self.bound_model.steps): - current_min, current_max = torch.min(output), torch.max(output) + # layer-wise settings + quant_start_step = layer_quant_setting.quant_start_step + ema_decay = layer_quant_setting.ema_decay + + # tensor-wise settings + dtype = tensor_quant_setting.quant_dtype + scheme = tensor_quant_setting.quant_scheme + qmin, qmax = tensor_quant_setting.get_qmin_qmax() + bits = tensor_quant_setting.bits + + if not wrapper.training: + scale = module.output_scale + zero_point = module.output_zero_point + output = self._quantize(output, scale, zero_point, qmin, qmax) + output = self._dequantize(output, scale, zero_point) + return output + + current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme) + + if int(self.bound_model.steps) == 1: module.tracked_min_output.copy_(current_min) module.tracked_max_output.copy_(current_max) - return output - # we dont update output quantization parameters in evaluation stage - if wrapper.training: - current_min, current_max = torch.min(output), torch.max(output) - tracked_min_output = update_ema(module.tracked_min_output, current_min, - module.ema_decay) - tracked_max_output = update_ema(module.tracked_max_output, current_max, - module.ema_decay) - module.tracked_min_output.copy_(tracked_min_output) - module.tracked_max_output.copy_(tracked_max_output) + tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay) + tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay) + module.tracked_min_output.copy_(tracked_min_output) + module.tracked_max_output.copy_(tracked_max_output) + + if quant_start_step > int(self.bound_model.steps): + return output scale, zero_point = update_quantization_param( - output_bits, module.tracked_min_output, module.tracked_max_output) - module.scale.copy_(scale) - module.zero_point.copy_(zero_point) + bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme) + module.output_scale.copy_(scale) + module.output_zero_point.copy_(zero_point) - out = self._quantize(output_bits, module, output) - out = self._dequantize(module, out) - return out + output = self._quantize(output, scale, zero_point, qmin, qmax) + output = self._dequantize(output, scale, zero_point) + return output def load_calibration_config(self, calibration_config): modules_to_compress = self.get_modules_to_compress() @@ -581,12 +683,12 @@ def load_calibration_config(self, calibration_config): assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match" if hasattr(module, 'input_bits'): assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match" - module.tracked_min_input.data = torch.Tensor([calibration_config[name]['tracked_min_input']]) - module.tracked_max_input.data = torch.Tensor([calibration_config[name]['tracked_max_input']]) + module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']]) + module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']]) if hasattr(module, 'output_bits'): assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match" - module.tracked_min_output.data = torch.Tensor([calibration_config[name]['tracked_min_output']]) - module.tracked_max_output.data = torch.Tensor([calibration_config[name]['tracked_max_output']]) + module.tracked_min_output.data = 
torch.tensor([calibration_config[name]['tracked_min_output']]) + module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']]) def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None): """ @@ -619,6 +721,8 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name] = {} if hasattr(module, 'weight_bits'): calibration_config[name]['weight_bits'] = int(module.weight_bits) + calibration_config[name]['weight_scale'] = module.weight_scale + calibration_config[name]['weight_zero_point'] = module.weight_zero_point # Recover weight/bias for batch normalization folding actual_weight = getattr(module, 'old_weight', None) @@ -759,7 +863,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ class ClipGrad(QuantGrad): @staticmethod def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax): - if quant_type == QuantType.QUANT_OUTPUT: + if quant_type == QuantType.OUTPUT: grad_output[torch.abs(tensor) > 1] = 0 return grad_output diff --git a/nni/common/version.py b/nni/common/version.py new file mode 100644 index 0000000000..b8881f48ad --- /dev/null +++ b/nni/common/version.py @@ -0,0 +1,7 @@ +import logging +try: + import torch + TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2]) +except Exception: + logging.info("PyTorch is not installed.") + TORCH_VERSION = None diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 4a4cce9bdb..7049504d72 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import copy import types import logging import torch from nni.common.graph_utils import build_module_graph +from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG +from nni.compression.pytorch.quantization.observers import RecordingObserver from . import default_layers _logger = logging.getLogger(__name__) @@ -547,7 +550,7 @@ def forward(self, *inputs): assert len(inputs) == 1, "Quantization of input only supports ops with single input." 
new_inp = self.quantizer.quant_grad( inputs[0], - QuantType.QUANT_INPUT, + QuantType.INPUT, self) inputs = (new_inp,) @@ -563,7 +566,7 @@ def forward(self, *inputs): self.quantizer.quant_grad( new_weight, - QuantType.QUANT_WEIGHT, + QuantType.WEIGHT, self, inputs[0]) result = self.module(*inputs) @@ -571,7 +574,7 @@ def forward(self, *inputs): if 'output' in self.config['quant_types']: result = self.quantizer.quant_grad( result, - QuantType.QUANT_OUTPUT, + QuantType.OUTPUT, self) return result @@ -604,10 +607,13 @@ class Quantizer(Compressor): def __init__(self, model, config_list, optimizer=None, dummy_input=None): if isinstance(model, torch.nn.DataParallel): model = model.module + model_copied = copy.deepcopy(model) self.identity_wrappers = [] self.conv_bn_patterns = {} self.find_conv_bn_patterns(model, dummy_input) super().__init__(model, config_list, optimizer) + self.all_shapes = {} + self.record_shape(model_copied, dummy_input) self.quant_grad = QuantGrad.apply if self.optimizer is not None: self.patch_optimizer(self.step_with_optimizer) @@ -845,25 +851,54 @@ def find_conv_bn_patterns(self, model, dummy_input): if successor.op_type == 'BatchNorm2d': self.conv_bn_patterns[node_group.name] = successor.name - def step_with_optimizer(self): - pass + def record_shape(self, model, dummy_input): + """ + Record input/output's shapes of each module to be quantized -class QuantType: - """ - Enum class for quantization type. - """ - QUANT_INPUT = 0 - QUANT_WEIGHT = 1 - QUANT_OUTPUT = 2 + Parameters + ---------- + model : torch.nn.Module + model to be recorded. + dummy_input : tupel of torch.tensor + inputs to the model. + """ + def _pre_forward_hook(self, inp): + # Only record the first tensor of the input + return self.pre_forward(inp[0]) + + def _post_forward_hook(self, _, out): + return self.post_forward(out) + + if dummy_input is None: + return + + all_handles = [] + all_observers = {} + modules_to_compress = self.get_modules_to_compress() + compress_names = [layer_info[0].name for layer_info in modules_to_compress] + for name, module in model.named_modules(): + if name in compress_names: + all_observers[name] = {} + all_observers[name]['input_hook'] = RecordingObserver() + all_observers[name]['output_hook'] = RecordingObserver() + module.add_module('pre_forward', all_observers[name]['input_hook']) + module.add_module('post_forward', all_observers[name]['output_hook']) + all_handles.append(module.register_forward_pre_hook(_pre_forward_hook)) + all_handles.append(module.register_forward_hook(_post_forward_hook)) + model(dummy_input) + for name, hooks in all_observers.items(): + # only support single input + input_val = hooks['input_hook'].tensor_val + input_shape = input_val[0].shape if input_val else None + output_val = hooks['output_hook'].tensor_val + output_shape = output_val[0].shape if output_val else None + shapes = [input_shape, output_shape] + self.all_shapes[name] = shapes + return -QType_Dict = { - 0: "input", - 1: "weight", - 2: "output" -} + def step_with_optimizer(self): + pass -BN_FOLD_OP = ["Conv2d"] -BN_FOLD_TAG = 'BN_FOLD_TAG' class QuantGrad(torch.autograd.Function): """ @@ -920,8 +955,8 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma grad_output : Tensor gradient of the output of quantization operation scale : Tensor - the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, - `QuantType.QUANT_OUTPUT`, you can define different behavior for different types. 
+ the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, + `QuantType.OUTPUT`, you can define different behavior for different types. zero_point : Tensor zero_point for quantizing tensor qmin : Tensor @@ -939,28 +974,39 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs): output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs) - bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type]) - qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device) - if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'): + if hasattr(wrapper.module, "layer_quant_setting"): + layer_quant_setting = wrapper.module.layer_quant_setting + qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax() + else: + # todo: when dtype/scheme customization is ready for all quantizers, remove this + bits = QuantGrad.get_bits_length(wrapper.config, quant_type) + qmin, qmax = 0, (1 << bits) - 1 + + scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name() + if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name): + scale = getattr(wrapper.module, scale_name) + zero_point = getattr(wrapper.module, zero_point_name) + # todo: remove this when other quantizers use different scale & zero point for input/weight/output + elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'): scale = wrapper.module.scale zero_point = wrapper.module.zero_point else: scale, zero_point = None, None - ctx.save_for_backward(tensor) # Only tensors have gradients flowing back needs to be saved by save_for_backward. # Others should directly assign to ctx. 
- ctx.scale = scale - ctx.zero_point = zero_point + ctx.save_for_backward(tensor) ctx.quant_type = quant_type ctx.qmin, ctx.qmax = qmin, qmax + ctx.scale = scale + ctx.zero_point = zero_point return output @classmethod def backward(cls, ctx, grad_output): tensor = ctx.saved_variables[0] scale, zero_point = ctx.scale, ctx.zero_point - qmin, qmax = ctx.qmin, ctx.qmax quant_type = ctx.quant_type + qmin, qmax = ctx.qmin, ctx.qmax output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax) return output, None, None, None @@ -977,11 +1023,11 @@ def _check_bias(module): return False def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): - if quant_type == QuantType.QUANT_INPUT: + if quant_type == QuantType.INPUT: output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs) - elif quant_type == QuantType.QUANT_WEIGHT: + elif quant_type == QuantType.WEIGHT: output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) - elif quant_type == QuantType.QUANT_OUTPUT: + elif quant_type == QuantType.OUTPUT: output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) else: raise ValueError("unrecognized QuantType.") diff --git a/nni/compression/pytorch/quantization/literal.py b/nni/compression/pytorch/quantization/literal.py new file mode 100644 index 0000000000..eaad1dcf25 --- /dev/null +++ b/nni/compression/pytorch/quantization/literal.py @@ -0,0 +1,65 @@ +from enum import Enum, EnumMeta + + +class _QuantLiteralEnumMeta(EnumMeta): + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True + + +class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta): + pass + + +class QuantScheme(str, _QuantLiteralEnum): + PER_TENSOR_AFFINE = 'per_tensor_affine' + PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric' + PER_CHANNEL_AFFINE = 'per_channel_affine' + PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric' + + +PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC] + + +class QuantDtype(str, _QuantLiteralEnum): + UINT = 'uint' + INT = 'int' + + +class QuantType(str, _QuantLiteralEnum): + INPUT = 'input' + WEIGHT = 'weight' + OUTPUT = 'output' + + def type_to_scale_zero_point_name(self): + if self == QuantType.INPUT: + return 'input_scale', 'input_zero_point' + elif self == QuantType.WEIGHT: + return 'weight_scale', 'weight_zero_point' + elif self == QuantType.OUTPUT: + return 'output_scale', 'output_zero_point' + else: + raise TypeError + + +# Just show each attribute's name, no practical effect +class QuantConfigLiteral(str, _QuantLiteralEnum): + QUANT_SETTINGS = 'quant_settings' + QUANT_SCHEME = 'quant_scheme' + QUANT_DTYPE = 'quant_dtype' + BITS = 'bits' + QMIN = 'qmin' + QMAX = 'qmax' + INPUT_SCALE = 'input_scale' + INPUT_ZERO_POINT = 'input_zero_point' + OUTPUT_SCALE = 'output_scale' + OUTPUT_ZERO_POINT = 'output_zero_point' + WEIGHT_SCALE = 'weight_scale' + WEIGHT_ZERO_POINT = 'weight_zero_point' + + +BN_FOLD_OP = ["Conv2d"] +BN_FOLD_TAG = 'BN_FOLD_TAG' diff --git a/nni/compression/pytorch/quantization/observers.py b/nni/compression/pytorch/quantization/observers.py new file mode 100644 index 0000000000..bd7b2bc288 --- /dev/null +++ b/nni/compression/pytorch/quantization/observers.py @@ -0,0 +1,15 @@ +from torch.quantization import default_weight_observer, default_histogram_observer +from torch.quantization import RecordingObserver as _RecordingObserver + +__all__ = ["default_weight_observer", "default_histogram_observer", 
"RecordingObserver"] + + +class RecordingObserver(_RecordingObserver): + """ + A extended version of PyTorch's RecordingObserver, used to record gpu tensor + """ + + def forward(self, x): + val = x.cpu() + super().forward(val) + return x diff --git a/nni/compression/pytorch/quantization/settings.py b/nni/compression/pytorch/quantization/settings.py new file mode 100644 index 0000000000..b4206e239a --- /dev/null +++ b/nni/compression/pytorch/quantization/settings.py @@ -0,0 +1,118 @@ +from typing import Any, Optional + +from .literal import QuantDtype, QuantType, QuantScheme +from .utils import calculate_qmin_qmax, get_bits_length + + +# default settings for quantization module +quant_default_settings = { + QuantType.WEIGHT: { + 'quant_scheme': QuantScheme.PER_TENSOR_AFFINE, + 'quant_dtype': QuantDtype.UINT, + }, + QuantType.INPUT: { + 'quant_scheme': QuantScheme.PER_TENSOR_AFFINE, + 'quant_dtype': QuantDtype.UINT + }, + QuantType.OUTPUT: { + 'quant_scheme': QuantScheme.PER_TENSOR_AFFINE, + 'quant_dtype': QuantDtype.UINT + } +} + + +class TensorQuantSetting(object): + def __init__(self, **kwargs): + self._fields = {} + for k, v in kwargs.items(): + self._fields[k] = v + + def __setattr__(self, name: str, val: Any) -> None: + if name.startswith("_"): + super().__setattr__(name, val) + else: + self._fields[name] = val + + def __getattr__(self, name): + if name == "_fields" or name not in self._fields: + raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name)) + return self._fields[name] + + def get_qmin_qmax(self): + assert 'qmin' in self._fields and 'qmax' in self._fields, \ + "Can not found qmin & qmax in TensorQuantSetting" + return self._fields['qmin'], self._fields['qmax'] + + +class LayerQuantSetting(object): + def __init__(self, config): + self.input: Optional[TensorQuantSetting] = None + self.weight: Optional[TensorQuantSetting] = None + self.output: Optional[TensorQuantSetting] = None + self._extra_layer_setting = {} + + for quant_type in QuantType: + if quant_type in config.get("quant_types", []): + setting = TensorQuantSetting() + + quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme') + setting.quant_scheme = quant_scheme + quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype') + setting.quant_dtype = quant_dtype + + bits = get_bits_length(config, quant_type) + qmin, qmax = calculate_qmin_qmax(bits, quant_dtype) + setting.bits = bits + setting.qmin = qmin + setting.qmax = qmax + setattr(self, quant_type, setting) + + def __setattr__(self, name: str, val: Any) -> None: + if name.startswith("_") or name in QuantType: + super().__setattr__(name, val) + else: + self._extra_layer_setting[name] = val + + def __getattr__(self, name): + if name == "_extra_layer_setting" or name not in self._extra_layer_setting: + raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name)) + return self._extra_layer_setting[name] + + @staticmethod + def parse_optional_config(config, quant_type, target): + def get_config(config, quant_type, target): + if not config.get(target): + return None + + if isinstance(config[target], dict): + return config[target].get(quant_type) + else: + return config[target] + + default_val = quant_default_settings[quant_type].get(target, None) + config_val = get_config(config, quant_type, target) + val = config_val if config_val else default_val + return val + + +def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None): + # todo: remove this if we convert string config to enum type. 
+ if isinstance(quant_type, str): + assert quant_type in QuantType, "Wrong quant_type" + if isinstance(new_scheme, str): + assert new_scheme in QuantScheme, "Wrong quant_scheme" + if isinstance(new_dtype, str): + assert new_dtype in QuantDtype, "Wrong quant_dtype" + + # TODO: It is not a good idea to directly modify global settings. A better choice is + # making this function an attribute function of Quantizer and call this function after + # the quantizer is initialized. However, within current framework of quantization, if + # we want to modify the dtype & scheme when the quantizer is initialized, we must do + # some other things (like changing the shapes of scales and zero_points and other quantization + # information in the subclass). + global quant_default_settings + if new_scheme is not None: + quant_default_settings[quant_type]['quant_scheme'] = new_scheme + if new_dtype is not None: + quant_default_settings[quant_type]['quant_dtype'] = new_dtype + return diff --git a/nni/compression/pytorch/quantization/utils.py b/nni/compression/pytorch/quantization/utils.py new file mode 100644 index 0000000000..a5131d176f --- /dev/null +++ b/nni/compression/pytorch/quantization/utils.py @@ -0,0 +1,83 @@ +import torch + +from nni.common.version import TORCH_VERSION + +from .literal import QuantDtype, QuantScheme, QuantType + + +def calculate_qmin_qmax(bits, dtype): + if dtype == QuantDtype.INT: + qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1 + elif dtype == QuantDtype.UINT: + qmin, qmax = 0, 2 ** bits - 1 + else: + raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.") + return qmin, qmax + + +def get_bits_length(config, quant_type): + if isinstance(config["quant_bits"], int): + return config["quant_bits"] + else: + return config["quant_bits"].get(quant_type) + + +def get_target_dim(quant_type, quant_scheme): + # for weight: c_out x c_in x (h) * (w) + # for feature maps: batch * channel * (t) * h * w + # other type is not supported for now + default_idx = 0 if quant_type == QuantType.WEIGHT else 1 + if is_per_channel(quant_scheme): + target_dim = default_idx + else: + target_dim = None + return target_dim + + +def get_min_max_value(x, quant_type, quant_scheme): + + target_dim = get_target_dim(quant_type, quant_scheme) + if target_dim is None: + return torch.min(x), torch.max(x) + + indices = list(range(len(x.shape))) + assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor" + del indices[target_dim] + + if TORCH_VERSION > (1, 6): + min_val = torch.amin(x, indices, keepdims=True) + max_val = torch.amax(x, indices, keepdims=True) + else: + min_val = max_val = x + for ind in indices: + min_val = torch.min(min_val, dim=ind, keepdim=True)[0] + max_val = torch.max(max_val, dim=ind, keepdim=True)[0] + return min_val, max_val + + +def get_mean_value(x, target_dim=None): + if target_dim is None: + return torch.mean(x) + + indices = list(range(len(x.shape))) + assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor" + del indices[target_dim] + + mean_val = torch.mean(x, dim=indices, keepdim=True) + return mean_val + + +def is_per_channel(quant_scheme): + if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]: + return True + else: + return False + + +def get_quant_shape(shape, quant_type, quant_scheme): + default_idx = 0 if quant_type == QuantType.WEIGHT else 1 + if is_per_channel(quant_scheme): + quant_shape = [1 if idx != default_idx else 
s for idx, s in enumerate(shape)] + else: + quant_shape = [] + return quant_shape diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index 61611596e5..daad2af930 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -9,6 +9,7 @@ import schema import nni.algorithms.compression.pytorch.pruning as torch_pruner import nni.algorithms.compression.pytorch.quantization as torch_quantizer +from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value import math @@ -50,7 +51,8 @@ def test_torch_quantizer_modules_detection(self): model.relu = torch.nn.ReLU() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer) + dummy = torch.randn(1, 1, 28, 28) + quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy) quantizer.compress() modules_to_compress = quantizer.get_modules_to_compress() modules_to_compress_name = [t[0].name for t in modules_to_compress] @@ -332,6 +334,130 @@ def test_torch_quantizer_weight_type(self): self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter)) self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter)) + def test_quantization_dtype_scheme(self): + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 2, 3, 1) + self.bn1 = torch.nn.BatchNorm2d(2) + + def forward(self, x): + x = self.bn1(self.conv1(x)) + return x + dtypes = ['int', 'uint'] + qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric'] + for dtype in dtypes: + for qscheme in qschemes: + config_list = [{ + 'quant_types': ['weight', 'input'], + 'quant_bits': 8, + 'op_types': ['Conv2d'], + 'quant_dtype': dtype, + 'quant_scheme': qscheme + }] + model = TestModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + # only QAT_quantizer is supported for now + dummy = torch.randn(1, 1, 4, 4) + quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy) + + # test layer setting + for layer, config in quantizer.modules_to_compress: + module = layer.module + name = layer.name + layer_setting = module.layer_quant_setting + qmin, qmax = calculate_qmin_qmax(8, dtype) + all_quant_types = ['input', 'weight'] + for quant_type in all_quant_types: + # check for settings + tensor_setting = getattr(layer_setting, quant_type) + self.assertTrue(tensor_setting is not None) + self.assertTrue(tensor_setting.quant_scheme == qscheme) + self.assertTrue(tensor_setting.quant_dtype == dtype) + self.assertTrue(tensor_setting.qmin == qmin) + self.assertTrue(tensor_setting.qmax == qmax) + + input_shape, output_shape = quantizer.all_shapes[name] + + shape = input_shape if quant_type == 'input' else module.weight.shape + quant_shape = get_quant_shape(shape, quant_type, qscheme) + scale_name = quant_type + '_scale' + zero_point_name = quant_type + '_zero_point' + scale = getattr(module, scale_name) + zero_point = getattr(module, zero_point_name) + self.assertTrue(list(scale.shape) == quant_shape) + self.assertTrue(list(zero_point.shape) == quant_shape) + + weight = torch.arange(start=1, end=19).view(2, 1, 3, 3) + if qscheme == 'per_channel_symmetric': + if dtype == 'int': + target_scale = torch.tensor([9. / 127, 18. 
/ 127]).view([2, 1, 1, 1]) + target_zero_point = torch.ones([2, 1, 1, 1]) * 0 + else: + target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1]) + target_zero_point = torch.ones([2, 1, 1, 1]) * 127 + elif qscheme == 'per_tensor_symmetric': + if dtype == 'int': + target_scale = torch.tensor(18. / 127) + target_zero_point = torch.zeros([]) + else: + target_scale = torch.tensor(18. / 127.5) + target_zero_point = torch.ones([]) * 127 + elif qscheme == 'per_channel_affine': + min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1]) + if dtype == 'int': + target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1]) + target_zero_point = -127 - torch.round(min_val / target_scale) + else: + target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1]) + target_zero_point = 0 - torch.round(min_val / target_scale) + else: + if dtype == 'int': + target_scale = torch.tensor(18. / 254) + target_zero_point = -127 - torch.round(0 / target_scale) + else: + target_scale = torch.tensor(18. / 255) + target_zero_point = 0 - torch.round(0 / target_scale) + wrapper = getattr(model, name) + wrapper.module.weight = weight + quantizer.quantize_weight(wrapper) + self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale)) + self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point)) + + inp = torch.arange(start=0, end=16).view(1, 1, 4, 4) + if qscheme == 'per_channel_symmetric': + if dtype == 'int': + target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1]) + target_zero_point = torch.ones([1, 1, 1, 1]) * 0 + else: + target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1]) + target_zero_point = torch.ones([1, 1, 1, 1]) * 127 + elif qscheme == 'per_tensor_symmetric': + if dtype == 'int': + target_scale = torch.tensor(15. / 127) + target_zero_point = torch.zeros([]) + else: + target_scale = torch.tensor(15. / 127.5) + target_zero_point = torch.ones([]) * 127 + elif qscheme == 'per_channel_affine': + min_val = torch.tensor([0.]).view([1, 1, 1, 1]) + if dtype == 'int': + target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1]) + target_zero_point = -127 - torch.round(min_val / target_scale) + else: + target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1]) + target_zero_point = 0 - torch.round(min_val / target_scale) + else: + if dtype == 'int': + target_scale = torch.tensor(15. / 254) + target_zero_point = -127 - torch.round(0 / target_scale) + else: + target_scale = torch.tensor(15. 
/ 255) + target_zero_point = 0 - torch.round(0 / target_scale) + quantizer.quantize_input(inp, wrapper) + self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale)) + self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point)) + def test_torch_QAT_quantizer(self): model = TorchModel() config_list = [{ @@ -347,7 +473,8 @@ def test_torch_QAT_quantizer(self): model.relu = torch.nn.ReLU() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer) + dummy = torch.randn(1, 1, 28, 28) + quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy) quantizer.compress() # test quantize @@ -357,20 +484,20 @@ def test_torch_QAT_quantizer(self): weight = torch.tensor([[1, 2], [3, 5]]).float() model.conv2.module.weight.data = weight quantizer.quantize_weight(model.conv2, input_tensor=input) - assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps) - assert model.conv2.module.zero_point == 0 + assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps) + assert model.conv2.module.weight_zero_point == 0 quantizer.quantize_input(input, model.conv2) - self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255]))) - self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.]))) + self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255]))) + self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.))) # range including 0 weight = torch.tensor([[-1, 2], [3, 5]]).float() model.conv2.module.weight = weight quantizer.quantize_weight(model.conv2, input_tensor=input) - assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps) - assert model.conv2.module.zero_point in (42, 43) + assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps) + assert model.conv2.module.weight_zero_point in (42, 43) quantizer.quantize_input(input, model.conv2) - self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255]))) - self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.]))) + self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. 
/ 255]))) + self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.))) # test value of weight and bias after quantization weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]]) weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]]) @@ -385,15 +512,15 @@ def test_torch_QAT_quantizer(self): # test ema eps = 1e-7 x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) - out = model.relu(x) - assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps) - assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps) + model.relu(x) + self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.))) + self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2))) quantizer.step_with_optimizer() x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) - out = model.relu(x) - assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps) - assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps) + model.relu(x) + self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002))) + self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060))) def test_torch_quantizer_export(self): config_list_qat = [{ @@ -424,12 +551,15 @@ def test_torch_quantizer_export(self): }] config_set = [config_list_qat, config_list_dorefa, config_list_bnn] quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer] - + dummy = torch.randn(1, 1, 28, 28) for config, quantize_algorithm in zip(config_set, quantize_algorithm_set): model = TorchModel() model.relu = torch.nn.ReLU() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - quantizer = quantize_algorithm(model, config, optimizer) + if quantize_algorithm == torch_quantizer.QAT_Quantizer: + quantizer = quantize_algorithm(model, config, optimizer, dummy) + else: + quantizer = quantize_algorithm(model, config, optimizer) quantizer.compress() x = torch.rand((1, 1, 28, 28), requires_grad=True) @@ -461,7 +591,11 @@ def test_quantizer_load_calibration_config(self): model = TorchModel().eval() model.relu = torch.nn.ReLU() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - quantizer = quantize_algorithm(model, configure_list, optimizer) + if quantize_algorithm == torch_quantizer.QAT_Quantizer: + dummy = torch.randn(1, 1, 28, 28) + quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy) + else: + quantizer = quantize_algorithm(model, configure_list, optimizer) quantizer.compress() if calibration_config is not None: quantizer.load_calibration_config(calibration_config)
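
For reference, a minimal usage sketch of the settings introduced above, assembled from the example script and the unit test in this patch. TinyModel, the dummy-input shape, and the optimizer hyper-parameters are placeholders; only QAT_Quantizer consumes the dtype/scheme settings for now.

import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

class TinyModel(torch.nn.Module):
    # Placeholder model: a single Conv2d so that 'op_types': ['Conv2d'] matches something.
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)

    def forward(self, x):
        return self.conv1(x)

# Optionally override the global per-tensor-type defaults before constructing the quantizer.
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')

# Dtype and scheme can also be set layer-wise through the config list.
config_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'op_types': ['Conv2d'],
    'quant_dtype': 'int',
    'quant_scheme': 'per_channel_symmetric'
}]

model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
dummy_input = torch.randn(1, 1, 4, 4)  # required so the quantizer can record input/output shapes
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
quantizer.compress()

# After quantization-aware training, the calibration parameters can be exported, e.g.:
# calibration_config = quantizer.export_model('model.pth', 'calibration.pth')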