
support dp multi-gpu training for QAT quantizer #4127

Merged: 2 commits into microsoft:master on Sep 8, 2021

Conversation

chenbohua3 (Contributor)

No description provided.

chenbohua3 (Contributor, Author) commented Aug 30, 2021

I will rebase onto master once #4084 is merged.

Have rebased.

module = layer.module
module.register_buffer("zero_point", torch.tensor([0.0]))
module.register_buffer("scale", torch.tensor([1.0]))
module.register_buffer('ema_decay', torch.tensor([0.99]))
if "weight" in config.get("quant_types", []):
    layer.module.register_buffer('weight_bits', torch.zeros(1))
Contributor:

Maybe we can also del layer here.

Contributor Author:

done
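
For context, register_buffer makes these quantization parameters part of the module's state: they appear in state_dict and follow the module across device and dtype moves, which is what DataParallel relies on when it replicates the module. A small standalone illustration, not taken from the NNI code itself:

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.register_buffer("scale", torch.tensor([1.0]))

print("scale" in m.state_dict())  # True -- buffers are saved with the module
m.to(torch.float64)               # buffers follow module conversions,
print(m.scale.dtype)              # torch.float64, unlike a plain tensor attribute
```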

@@ -479,43 +480,55 @@ def quantize_weight(self, wrapper, **kwargs):
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

# we dont update weight in evaluation stage
- if quant_start_step > self.bound_model.steps:
+ if quant_start_step > int(self.bound_model.steps):
Contributor:

Why do we add a specific datatype conversion here?

Contributor Author:

It is not a good idea to directly compare a Python int with a torch.Tensor, which may lead to unpredictable errors.
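
As a quick illustration of the point (the values below are hypothetical): comparing a Python int with a Tensor produces a Tensor, not a bool, so converting the step counter with int() keeps the condition an ordinary Python comparison.

```python
import torch

steps = torch.tensor([5])   # buffer-style step counter, single element
quant_start_step = 3

print(quant_start_step > steps)       # tensor([False]) -- a Tensor, not a bool
print(quant_start_step > int(steps))  # False -- a plain Python bool
```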

- module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+ scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
+ module.scale.copy_(scale)
+ module.zero_point.copy_(zero_point)
Contributor:

What is the difference between copy_ and direct assignment?

Contributor Author:

As described in the PyTorch DataParallel documentation:

In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded.

If we assign directly, the updates to scale, zero_point, and the tracked statistics will be lost.
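
A minimal sketch of the distinction (the wrapper class and shapes here are made up purely for illustration):

```python
import torch
import torch.nn as nn

class FakeQuant(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.tensor([1.0]))

    def forward(self, x):
        new_scale = x.abs().max().reshape(1)
        # In-place copy writes into the storage that the device[0] replica
        # shares with the base module, so the update survives the forward.
        self.scale.copy_(new_scale)
        # self.scale = new_scale  # would rebind the attribute on a throwaway
        #                         # replica and the update would be lost
        return x / self.scale
```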

Contributor:

thx! got it

tracked_min_output = update_ema(module.tracked_min_output, current_min,
    module.ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max,
    module.ema_decay)
Contributor:

Please align these continuation lines.

Contributor Author:

done
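
For reference, update_ema presumably computes a standard exponential moving average of the observed activation range; a minimal sketch under that assumption (the actual NNI helper may differ in details):

```python
import torch

def update_ema(tracked, current, decay):
    # new = decay * old + (1 - decay) * current
    return decay * tracked + (1 - decay) * current

tracked_min = torch.tensor([0.0])
print(update_ema(tracked_min, torch.tensor([-0.8]), torch.tensor([0.99])))  # tensor([-0.0080])
```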

QuanluZhang merged commit 396ae65 into microsoft:master on Sep 8, 2021.