support dp multi-gpu training for QAT quantizer #4127
```diff
@@ -373,21 +373,22 @@ def __init__(self, model, config_list, optimizer, dummy_input=None):
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device
-        self.bound_model.register_buffer("steps", torch.Tensor([1]))
+        self.bound_model.register_buffer("steps", torch.tensor(1))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
-            layer.module.register_buffer("scale", torch.Tensor([1.0]))
-            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
+            module = layer.module
+            module.register_buffer("zero_point", torch.tensor([0.0]))
+            module.register_buffer("scale", torch.tensor([1.0]))
+            module.register_buffer('ema_decay', torch.tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                module.register_buffer('weight_bits', torch.zeros(1))
             if "input" in config.get("quant_types", []):
-                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
-                layer.module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_input', torch.zeros(1))
+                module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
-                layer.module.register_buffer('output_bits', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
+                module.register_buffer('output_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_output', torch.zeros(1))
+                module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
```
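Two details in this hunk are easy to overlook. First, `torch.tensor(1)` (lowercase) infers dtype and builds a 0-dim `int64` tensor, whereas the old `torch.Tensor([1])` was a 1-element `float32` tensor, so the `steps` counter is now a true integer. Second, keeping this state in buffers registered via `register_buffer` is what lets `DataParallel` broadcast it to every replica and `state_dict` persist it. A quick check of the dtype point:

```python
import torch

# torch.Tensor is the legacy float-only constructor; torch.tensor infers dtype.
old_steps = torch.Tensor([1])   # 1-element float32 tensor
new_steps = torch.tensor(1)     # 0-dim int64 tensor
print(old_steps.dtype, old_steps.shape)  # torch.float32 torch.Size([1])
print(new_steps.dtype, new_steps.shape)  # torch.int64 torch.Size([])
```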
```diff
@@ -479,43 +480,55 @@ def quantize_weight(self, wrapper, **kwargs):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"

         # we dont update weight in evaluation stage
-        if quant_start_step > self.bound_model.steps:
+        if quant_start_step > int(self.bound_model.steps):
             return weight

         if not wrapper.training:
             return weight

         # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
-        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)
```
An inline review thread on the `copy_` calls:

**Review comment:** what is the difference between […]?

**Reply:** As described in the PyTorch `DataParallel` documentation:

> In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded.

If we assign directly, the updates to `scale`, `zero_point`, and the tracked min/max information will be lost.

**Reply:** thx! got it
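To illustrate the point above, here is a minimal toy module (not NNI code) whose buffer is updated in place during `forward`. Under `DataParallel`, the in-place update survives because the device[0] replica shares buffer storage with the base module; rebinding the attribute would only touch a throwaway replica:

```python
import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("calls", torch.tensor(0))

    def forward(self, x):
        # In-place update: with DataParallel, the replica on device[0] shares
        # storage with the base module, so this survives the forward pass.
        self.calls.add_(1)
        # By contrast, `self.calls = self.calls + 1` would rebind the attribute
        # on a replica that is destroyed after forward, and the update is lost.
        return x

model = Counter()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()
    x = torch.randn(4, 8).cuda()
else:
    x = torch.randn(4, 8)
model(x)
base = model.module if isinstance(model, nn.DataParallel) else model
print(base.calls)  # tensor(1): the base module saw the in-place update
```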
The hunk continues:

```diff
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
         module.weight_bits = torch.Tensor([weight_bits])
+        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
+        # will be lost after each forward process. However, this update takes effect on each
+        # replicated module during each forward process, which will make the quantized weight
+        # be used correctly.
         wrapper.module.weight = weight
         return weight

     def quantize_input(self, inputs, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
         input_bits = get_bits_length(config, 'input')
-        module.input_bits = torch.Tensor([input_bits])
-
+        module.input_bits = torch.tensor([input_bits])
         quant_start_step = config.get('quant_start_step', 0)
         assert input_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(inputs), torch.max(inputs)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
             return inputs

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(inputs), torch.max(inputs)
-            module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
-                                                  module.ema_decay)
-            module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
-                                                  module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(
+            current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
+            current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
+
+        scale, zero_point = update_quantization_param(
             input_bits, module.tracked_min_input, module.tracked_max_input)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)

         inp = self._quantize(input_bits, module, inputs)
         inp = self._dequantize(module, inp)
         return inp
```
```diff
@@ -528,19 +541,26 @@ def quantize_output(self, output, wrapper, **kwargs):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_output.copy_(current_min)
+            module.tracked_max_output.copy_(current_max)
             return output

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
-                                                   module.ema_decay)
-            module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
-                                                   module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(
+            tracked_min_output = update_ema(module.tracked_min_output, current_min,
+                                            module.ema_decay)
+            tracked_max_output = update_ema(module.tracked_max_output, current_max,
+                                            module.ema_decay)
+            module.tracked_min_output.copy_(tracked_min_output)
+            module.tracked_max_output.copy_(tracked_max_output)
+
+        scale, zero_point = update_quantization_param(
             output_bits, module.tracked_min_output, module.tracked_max_output)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)

         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
```
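Both `quantize_input` and `quantize_output` funnel the EMA-tracked range into `update_quantization_param`. These helpers are defined elsewhere in NNI; as a rough sketch of what they plausibly compute (standard EMA range tracking plus the asymmetric affine scheme of Jacob et al., 2018 — the real implementations may differ in clamping and rounding details):

```python
import torch

# Hypothetical sketches of the helpers used above, for orientation only;
# NNI's actual implementations may differ.

def update_ema(biased_ema, value, decay):
    # Exponential moving average of the observed range statistic.
    return biased_ema * decay + (1 - decay) * value

def update_quantization_param(bits, rmin, rmax):
    # Asymmetric affine quantization: map [rmin, rmax] onto [qmin, qmax].
    qmin, qmax = 0, (1 << bits) - 1
    # Ensure zero is exactly representable, as in Jacob et al. (2018).
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = qmin - torch.round(rmin / scale)
    zero_point = torch.clamp(zero_point, qmin, qmax)
    return scale, zero_point
```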
```diff
@@ -645,7 +665,7 @@ def step_with_optimizer(self):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.bound_model.steps += 1
+        self.bound_model.steps.add_(1)


 class DoReFaQuantizer(Quantizer):
```
An inline review thread on the `int(self.bound_model.steps)` conversions:

**Review comment:** Why do we add a specific datatype conversion here?

**Reply:** It is not a good idea to directly compare a Python `int` with a `torch.Tensor`, which may lead to unpredictable errors.
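A small illustration (not NNI code) of the point: comparing a Python `int` with a tensor yields a tensor, not a `bool`, while `int(...)` turns the branch into an ordinary Python comparison regardless of which device the buffer lives on:

```python
import torch

steps = torch.tensor(1)  # mirrors the `steps` buffer registered above

# The raw comparison returns a 0-dim bool tensor, not a Python bool:
print(5 > steps)        # tensor(True)

# `if` happens to work on a single-element tensor, but converting first
# makes the intent explicit and avoids surprises (e.g. multi-element
# tensors raise, and CUDA tensors force a device synchronization):
print(5 > int(steps))   # True
```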