This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support dp multi-gpu training for QAT quantizer #4127

Merged
merged 2 commits on Sep 8, 2021
80 changes: 50 additions & 30 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -373,21 +373,22 @@ def __init__(self, model, config_list, optimizer, dummy_input=None):
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
self.bound_model.register_buffer("steps", torch.Tensor([1]))
self.bound_model.register_buffer("steps", torch.tensor(1))
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
module = layer.module
module.register_buffer("zero_point", torch.tensor([0.0]))
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bits', torch.zeros(1))
module.register_buffer('weight_bits', torch.zeros(1))
if "input" in config.get("quant_types", []):
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
layer.module.register_buffer('input_bits', torch.zeros(1))
module.register_buffer('input_bits', torch.zeros(1))
module.register_buffer('tracked_min_input', torch.zeros(1))
module.register_buffer('tracked_max_input', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('output_bits', torch.zeros(1))
layer.module.register_buffer('tracked_min_output', torch.zeros(1))
layer.module.register_buffer('tracked_max_output', torch.zeros(1))
module.register_buffer('output_bits', torch.zeros(1))
module.register_buffer('tracked_min_output', torch.zeros(1))
module.register_buffer('tracked_max_output', torch.zeros(1))
self.bound_model.to(device)

def _del_simulated_attr(self, module):
@@ -479,43 +480,55 @@ def quantize_weight(self, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

# we dont update weight in evaluation stage
if quant_start_step > self.bound_model.steps:
if quant_start_step > int(self.bound_model.steps):
Contributor: Why do we add a specific datatype conversion here?

Contributor (Author): It is not a good idea to directly compare a Python int with a torch.Tensor, which may lead to unpredictable errors.
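A small standalone illustration of the pitfall (not part of the diff; the values below are made up):

```python
import torch

steps = torch.tensor(1)        # the "steps" buffer as registered on the model
quant_start_step = 100         # a plain Python int from the config

# Comparing a Python int with a Tensor returns a Tensor, not a bool:
print(quant_start_step > steps)        # tensor(True)

# Converting to int first keeps the comparison in plain Python, so the branch
# behaves the same regardless of the buffer's device or dtype.
print(quant_start_step > int(steps))   # True
```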

return weight

if not wrapper.training:
return weight

# quantize weight
rmin, rmax = torch.min(weight), torch.max(weight)
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
Contributor: What is the difference between copy_ and direct assignment?

Contributor (Author): As described in the torch.nn.DataParallel documentation:

In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded.

If we assign directly, the updates to scale, zero_point, and the tracked min/max statistics will be lost.

Contributor: Thanks! Got it.
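A minimal sketch of the behaviour quoted above (the Counter module is made up for illustration; running it assumes CUDA devices are visible):

```python
import torch
import torch.nn as nn

class Counter(nn.Module):
    """Toy module with a buffer that is updated inside forward."""
    def __init__(self):
        super().__init__()
        self.register_buffer("steps", torch.tensor(0))

    def forward(self, x):
        # In-place update: the replica on device[0] shares storage with the
        # base module, so this change survives the forward pass.
        self.steps.add_(1)
        # A plain re-assignment such as `self.steps = self.steps + 1` would
        # only touch the short-lived replica and be silently discarded.
        return x

model = nn.DataParallel(Counter().cuda())   # assumes at least one visible GPU
model(torch.randn(8, 4).cuda())
print(model.module.steps)                   # the in-place update was kept
```

The same reasoning is why scale, zero_point, and the tracked min/max buffers in this diff are updated with copy_() instead of being reassigned.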

weight = self._quantize(weight_bits, module, weight)
weight = self._dequantize(module, weight)
module.weight_bits = torch.Tensor([weight_bits])
# Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
# will be lost after each forward process. However, this update takes effect on each
# replicated module during each forward process, which will make the quantized weight
# be used correctly.
wrapper.module.weight = weight
return weight

def quantize_input(self, inputs, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input_bits = get_bits_length(config, 'input')
module.input_bits = torch.Tensor([input_bits])

module.input_bits = torch.tensor([input_bits])
quant_start_step = config.get('quant_start_step', 0)
assert input_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)
return inputs

# we dont update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(inputs), torch.max(inputs)
module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
module.ema_decay)
module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)

scale, zero_point = update_quantization_param(
input_bits, module.tracked_min_input, module.tracked_max_input)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)

inp = self._quantize(input_bits, module, inputs)
inp = self._dequantize(module, inp)
return inp
@@ -528,19 +541,26 @@ def quantize_output(self, output, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
if quant_start_step > int(self.bound_model.steps):
current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_output.copy_(current_min)
module.tracked_max_output.copy_(current_max)
return output

# we dont update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
module.ema_decay)
module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(
tracked_min_output = update_ema(module.tracked_min_output, current_min,
module.ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max,
module.ema_decay)
module.tracked_min_output.copy_(tracked_min_output)
module.tracked_max_output.copy_(tracked_max_output)

scale, zero_point = update_quantization_param(
output_bits, module.tracked_min_output, module.tracked_max_output)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)

out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
@@ -645,7 +665,7 @@ def step_with_optimizer(self):
"""
override `compressor` `step` method, quantization only happens after certain number of steps
"""
self.bound_model.steps += 1
self.bound_model.steps.add_(1)


class DoReFaQuantizer(Quantizer):
15 changes: 13 additions & 2 deletions nni/compression/pytorch/compressor.py
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
"""

def __init__(self, model, config_list, optimizer=None, dummy_input=None):
if isinstance(model, torch.nn.DataParallel):
model = model.module
self.identity_wrappers = []
self.conv_bn_patterns = {}
self.find_conv_bn_patterns(model, dummy_input)
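With this unwrap in place, a DataParallel-wrapped model can be handed to the quantizer directly. A hypothetical usage sketch (the toy network, optimizer, and config values are illustrative, not from this PR; the import path follows this repository's layout):

```python
import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# Toy network; any model with Conv2d/Linear layers would do.
net = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),             # assumes 32x32 inputs
)
net = nn.DataParallel(net.cuda())            # wrap for multi-GPU training

config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d', 'Linear'],
}]
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# The quantizer detects the DataParallel container and works on net.module.
quantizer = QAT_Quantizer(net, config_list, optimizer)
quantizer.compress()
```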
@@ -892,12 +894,21 @@ def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
zero_point = wrapper.module.zero_point
else:
scale, zero_point = None, None
ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
ctx.save_for_backward(tensor)
# Only tensors have gradients flowing back needs to be saved by save_for_backward.
# Others should directly assign to ctx.
ctx.scale = scale
ctx.zero_point = zero_point
ctx.quant_type = quant_type
ctx.qmin, ctx.qmax = qmin, qmax
return output

@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
tensor = ctx.saved_variables[0]
scale, zero_point = ctx.scale, ctx.zero_point
qmin, qmax = ctx.qmin, ctx.qmax
quant_type = ctx.quant_type
output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
return output, None, None, None
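For reference, a self-contained toy autograd Function following the same convention (a made-up straight-through clamp, not NNI's QuantGrad): tensors that need gradients go through save_for_backward, while plain Python metadata is attached to ctx directly.

```python
import torch

class ClampSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, qmin, qmax):
        ctx.save_for_backward(tensor)       # tensor takes part in autograd
        ctx.qmin, ctx.qmax = qmin, qmax     # plain numbers: assign to ctx directly
        return tensor.clamp(qmin, qmax)

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # Straight-through estimate: pass gradients only inside the clamp range.
        inside = (tensor >= ctx.qmin) & (tensor <= ctx.qmax)
        return grad_output * inside.to(grad_output.dtype), None, None

x = torch.randn(4, requires_grad=True)
ClampSTE.apply(x, -1.0, 1.0).sum().backward()
print(x.grad)   # 1.0 where |x| <= 1, else 0.0
```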
