
support dp multi-gpu training for QAT quantizer
chenbohua3 committed Aug 31, 2021
1 parent d204d8b commit 102a0a4
Showing 2 changed files with 62 additions and 31 deletions.
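
For context, a hedged sketch of how the feature is meant to be used: compress the model with the QAT quantizer, then wrap it in torch.nn.DataParallel for multi-GPU training. The toy network, learning rate, and config values below are illustrative and not taken from this commit; the QAT_Quantizer import path and config_list keys follow the NNI API of this era.

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(),
                          nn.Linear(16 * 30 * 30, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # 8-bit weights and outputs for all Conv2d/Linear layers (illustrative config).
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear'],
    }]

    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()

    # With this commit the compressed model can be replicated by DataParallel:
    # quantization state lives in registered buffers that are broadcast to every
    # replica on each forward pass, and updates are made in place.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model.cuda())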
78 changes: 49 additions & 29 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -373,21 +373,22 @@ def __init__(self, model, config_list, optimizer, dummy_input=None):
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
self.bound_model.register_buffer("steps", torch.Tensor([1]))
self.bound_model.register_buffer("steps", torch.tensor(1))
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
module = layer.module
module.register_buffer("zero_point", torch.tensor([0.0]))
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bits', torch.zeros(1))
if "input" in config.get("quant_types", []):
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
layer.module.register_buffer('input_bits', torch.zeros(1))
module.register_buffer('input_bits', torch.zeros(1))
module.register_buffer('tracked_min_input', torch.zeros(1))
module.register_buffer('tracked_max_input', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('output_bits', torch.zeros(1))
layer.module.register_buffer('tracked_min_output', torch.zeros(1))
layer.module.register_buffer('tracked_max_output', torch.zeros(1))
module.register_buffer('output_bits', torch.zeros(1))
module.register_buffer('tracked_min_output', torch.zeros(1))
module.register_buffer('tracked_max_output', torch.zeros(1))
self.bound_model.to(device)

def _del_simulated_attr(self, module):
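
A note on the pattern above: every piece of per-layer quantization state (scale, zero_point, ema_decay, tracked ranges, bit widths) is registered with register_buffer rather than stored as a plain attribute. Buffers travel with the module: they follow .to()/.cuda(), are saved in the state_dict, and are broadcast to each replica that DataParallel creates on every forward pass. A minimal sketch of the difference, with illustrative names:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 2)
    m.register_buffer('scale', torch.tensor([1.0]))  # tracked by the module
    m.plain_scale = torch.tensor([1.0])              # just a Python attribute

    print('scale' in m.state_dict())        # True  -> checkpointed and replicated
    print('plain_scale' in m.state_dict())  # False -> invisible to PyTorch

    m.double()
    print(m.scale.dtype)        # torch.float64 -> buffers follow .to()/.double()
    print(m.plain_scale.dtype)  # torch.float32 -> plain attributes do not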
@@ -479,43 +480,55 @@ def quantize_weight(self, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

# we don't update weight in evaluation stage
if quant_start_step > self.bound_model.steps:
if quant_start_step > int(self.bound_model.steps):
return weight

if not wrapper.training:
return weight

# quantize weight
rmin, rmax = torch.min(weight), torch.max(weight)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bits = torch.Tensor([weight_bits])
# Weight cannot be modified in place, so when torch.nn.DataParallel is used this
# update is lost after each forward pass. However, the update does take effect on
# each replicated module during that forward pass, so the quantized weight is
# used correctly.
wrapper.module.weight = weight
return weight

def quantize_input(self, inputs, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input_bits = get_bits_length(config, 'input')
module.input_bits = torch.Tensor([input_bits])

module.input_bits = torch.tensor([input_bits])
quant_start_step = config.get('quant_start_step', 0)
assert input_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)
return inputs

# we don't update input quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
module.ema_decay)
module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)

scale, zero_point = update_quantization_param(
input_bits, module.tracked_min_input, module.tracked_max_input)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)

inp = self._quantize(input_bits, module, inputs)
inp = self._dequantize(module, inp)
return inp
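
The updates above (and the DataParallel comment added in quantize_weight) rely on writing results into the registered buffers with copy_ rather than rebinding attributes such as module.scale = .... Under torch.nn.DataParallel the forward pass runs on temporary replicas, so an attribute reassignment is discarded with the replica; PyTorch guarantees, however, that the replica on device[0] shares parameter and buffer storage with the base module, so in-place copy_ updates persist. A small illustrative sketch (it needs at least two GPUs to actually exercise DataParallel):

    import torch
    import torch.nn as nn

    class RangeObserver(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer('tracked_max', torch.zeros(1))

        def forward(self, x):
            # In-place update: the device[0] replica shares this buffer with the
            # base module, so the new value survives the forward pass.
            self.tracked_max.copy_(torch.maximum(self.tracked_max, x.detach().max()))
            # A reassignment such as `self.tracked_max = x.max()` would only
            # rebind the attribute on the throw-away replica and be lost.
            return x

    if torch.cuda.device_count() >= 2:
        model = nn.DataParallel(RangeObserver().cuda())
        model(torch.randn(8, 4, device='cuda'))
        print(model.module.tracked_max)  # updated by the device[0] replica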
@@ -528,19 +541,26 @@ def quantize_output(self, output, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_output.copy_(current_min)
module.tracked_max_output.copy_(current_max)
return output

# we don't update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
module.ema_decay)
module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
tracked_min_output = update_ema(module.tracked_min_output, current_min,
module.ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
module.tracked_min_output.copy_(tracked_min_output)
module.tracked_max_output.copy_(tracked_max_output)

scale, zero_point = update_quantization_param(
output_bits, module.tracked_min_output, module.tracked_max_output)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)

out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
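
quantize_output mirrors quantize_input: the running output range is tracked with an exponential moving average, scale and zero_point are recomputed from that range, and the tensor is fake-quantized. update_ema and update_quantization_param are helpers defined elsewhere in this module; the sketch below shows the conventional formulas such helpers implement, as an illustrative re-implementation rather than the NNI source:

    def ema(tracked, current, decay):
        # Standard exponential moving average of an observed statistic.
        return tracked * decay + (1 - decay) * current

    def affine_qparams(bits, rmin, rmax):
        # Asymmetric (affine) quantization parameters for an observed range.
        qmin, qmax = 0, (1 << bits) - 1
        rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)    # range must include zero
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = int(round(qmin - rmin / scale))
        zero_point = max(qmin, min(qmax, zero_point))  # keep it on the integer grid
        return scale, zero_point

    print(ema(0.0, 3.0, 0.99))           # -> 0.03, the tracked value moves slowly
    print(affine_qparams(8, -1.0, 3.0))  # -> (~0.0157, 64)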
@@ -645,7 +665,7 @@ def step_with_optimizer(self):
"""
Override the `step` method of `Compressor`; quantization only happens after a certain number of steps.
"""
self.bound_model.steps += 1
self.bound_model.steps.add_(1)


class DoReFaQuantizer(Quantizer):
15 changes: 13 additions & 2 deletions nni/compression/pytorch/compressor.py
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
"""

def __init__(self, model, config_list, optimizer=None, dummy_input=None):
if isinstance(model, torch.nn.DataParallel):
model = model.module
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
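
With the added isinstance check, the quantizer also accepts a model that is already wrapped in torch.nn.DataParallel: it unwraps to the underlying module so that its wrappers and buffers are installed on the module DataParallel actually replicates. A tiny illustration of the unwrap:

    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)
    model = wrapped.module if isinstance(wrapped, nn.DataParallel) else wrapped
    assert model is net  # quantization state ends up on the real module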
@@ -892,12 +894,21 @@ def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
zero_point = wrapper.module.zero_point
else:
scale, zero_point = None, None
ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
ctx.save_for_backward(tensor)
# Only tensors that have gradients flowing back need to be saved with save_for_backward;
# everything else should be assigned directly to ctx.
ctx.scale = scale
ctx.zero_point = zero_point
ctx.quant_type = quant_type
ctx.qmin, ctx.qmax = qmin, qmax
return output

@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
tensor = ctx.saved_variables[0]
scale, zero_point = ctx.scale, ctx.zero_point
qmin, qmax = ctx.qmin, ctx.qmax
quant_type = ctx.quant_type
output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
return output, None, None, None
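
The backward pass reads the input tensor from ctx.saved_variables and everything else from plain ctx attributes, matching the comment added in forward. A self-contained sketch of the same pattern, using an illustrative straight-through rounding function rather than NNI's QuantGrad:

    import torch

    class RoundSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor, scale, qmin, qmax):
            # Only the tensor that needs a gradient goes through save_for_backward;
            # non-differentiable metadata is attached directly to ctx.
            ctx.save_for_backward(tensor)
            ctx.scale, ctx.qmin, ctx.qmax = scale, qmin, qmax
            return torch.clamp(torch.round(tensor / scale), qmin, qmax) * scale

        @staticmethod
        def backward(ctx, grad_output):
            tensor, = ctx.saved_tensors
            q = tensor / ctx.scale
            # Straight-through estimator: pass the gradient where the value was
            # not clipped, and zero it elsewhere.
            mask = (q >= ctx.qmin) & (q <= ctx.qmax)
            return grad_output * mask.to(grad_output.dtype), None, None, None

    x = torch.randn(4, requires_grad=True)
    y = RoundSTE.apply(x, 0.1, -128, 127)
    y.sum().backward()
    print(x.grad)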

