support dp multi-gpu training for QAT quantizer #4127
Conversation
Have rebased.

force-pushed from 50e858b to 102a0a4
module = layer.module
module.register_buffer("zero_point", torch.tensor([0.0]))
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
if "weight" in config.get("quant_types", []):
    layer.module.register_buffer('weight_bits', torch.zeros(1))
Maybe we can also del `layer` here.
done
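For context, a minimal standalone sketch (a plain `nn.Conv2d`, not the NNI wrapper) of why the quantization state is kept in registered buffers rather than plain Python attributes: buffers are part of the module, so DataParallel replicates them and `state_dict()` saves them.

```python
import torch
import torch.nn as nn

# Hypothetical standalone module; the real code registers these buffers on
# the wrapped layer's module inside the quantizer.
conv = nn.Conv2d(3, 8, 3)
conv.register_buffer("zero_point", torch.tensor([0.0]))
conv.register_buffer("scale", torch.tensor([1.0]))

print("scale" in dict(conv.named_buffers()))  # True -- replicated by DataParallel
print("scale" in conv.state_dict())           # True -- saved with the model
```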
@@ -479,43 +480,55 @@ def quantize_weight(self, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

# we dont update weight in evaluation stage
if quant_start_step > self.bound_model.steps:
if quant_start_step > int(self.bound_model.steps):
Why do we add this specific datatype conversion here?
It is not a good idea to directly compare a Python int with a torch.Tensor, which may lead to unpredictable errors.
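To make the concern concrete, here is a small standalone sketch (not NNI code) of how the comparison behaves with and without the int() conversion:

```python
import torch

steps = torch.tensor([5])        # e.g. a step counter kept as a one-element buffer
quant_start_step = 10

# Comparing a Python int with a Tensor yields a Tensor, not a bool.
print(quant_start_step > steps)          # tensor([True])
# Using it in `if` only works because the tensor has exactly one element;
# a multi-element tensor would raise "Boolean value of Tensor ... is ambiguous".

# Converting first makes the comparison an ordinary Python bool.
print(quant_start_step > int(steps))     # True
```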
module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
module.scale.copy_(scale)
module.zero_point.copy_(zero_point)
What is the difference between `copy_` and assigning directly?
As described in the PyTorch DataParallel documentation:
In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded.
If we assign directly, the updates to scale, zero_point, and the tracked statistics will be lost.
thx! got it
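For illustration, a minimal sketch of the behaviour quoted above (hypothetical Counter module, assumes at least one CUDA device; not the NNI wrapper):

```python
import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("steps", torch.zeros(1))

    def forward(self, x):
        # In-place update: the replica on device[0] shares storage with the
        # base module's buffer, so this change survives the forward pass.
        self.steps.copy_(self.steps + 1)
        # A direct assignment would only rebind the attribute on the replica,
        # which is thrown away after forward:
        # self.steps = self.steps + 1
        return x

model = nn.DataParallel(Counter().cuda())
model(torch.randn(4, 3).cuda())
print(model.module.steps)  # tensor([1.], device='cuda:0') with copy_;
                           # it would stay at 0. with the assignment variant
```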
tracked_min_output = update_ema(module.tracked_min_output, current_min,
    module.ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max,
    module.ema_decay)
Align the continuation lines.
done