support dtype&scheme customization for QAT quantizer #4137
Conversation
Have rebased
22c40e8 to e908235
@@ -155,7 +155,7 @@
grad_output : Tensor
    量化操作输出的梯度
quant_type : QuantType
Please remove all modifications of the Chinese documentation, since the Chinese docs are generated automatically in a separate pipeline.
done
rmax = torch.max(rmax, torch.zeros_like(rmax))
zero_point = torch.zeros_like(rmin)

# todo: there is no need to calculate qmin and qmax again
Agree. Maybe we can put them into the wrapper.
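A minimal sketch of what that caching could look like, assuming a wrapper object with hypothetical weight_qmin/weight_qmax attributes; these names and helpers are illustrative, not the actual NNI API.

```python
import torch

# Hypothetical sketch: compute qmin/qmax once when the wrapper is built instead
# of re-deriving them from the bit width on every forward pass.
def cache_weight_qrange(wrapper, weight_bits):
    # Unsigned range for weight_bits bits, e.g. 8 bits -> [0, 255].
    wrapper.weight_qmin = 0
    wrapper.weight_qmax = (1 << weight_bits) - 1

def fake_quantize_with_cached_range(wrapper, weight, scale, zero_point):
    # Reuse the cached range rather than recomputing it from the bit width.
    q = torch.clamp(torch.round(weight / scale + zero_point),
                    wrapper.weight_qmin, wrapper.weight_qmax)
    return (q - zero_point) * scale
```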
input_shape, output_shape = self.all_shapes[name]
layer_quant_setting = LayerQuantSetting(config)
layer_quant_setting.ema_decay = 0.99
quant_start_step = config.get('quant_start_step', 0)
Since layer_quant_setting = LayerQuantSetting(config) already consumes config, should we read quant_start_step directly during LayerQuantSetting initialization?
Now quant_start_step is a QAT_Quantizer-specific parameter, so I think it is better to set it after LayerQuantSetting initialization.
@J-shang what do you think about it?
I prefer to get it from LayerQuantSetting._extra_layer_setting if possible.
Agree, I think it can be put into LayerQuantSetting, because this class can be a universal class for all quantizers (maybe as part of the quantization v2 design) and quant_start_step is a universal config parameter in the original quantization design; all universal parameters should go into this class except those that change during training. For this PR, I think the current implementation is enough.
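If quant_start_step were folded into the setting object as discussed, a minimal sketch could look like the following; this LayerQuantSetting is a stand-in for illustration, not NNI's actual class.

```python
# Minimal sketch, assuming a stand-in LayerQuantSetting that reads the
# universal parameters straight from the config at construction time.
class LayerQuantSetting:
    def __init__(self, config):
        # Universal, non-trainable parameters taken directly from the config.
        self.quant_start_step = config.get('quant_start_step', 0)
        self.ema_decay = config.get('ema_decay', 0.99)

# Usage: the quantizer would no longer need a separate config.get call.
setting = LayerQuantSetting({'quant_start_step': 100})
assert setting.quant_start_step == 100
```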
if not wrapper.training:
    scale, zero_point = module.weight_scale, module.weight_zero_point
    weight = self._quantize(weight, scale, zero_point, qmin, qmax)
We only need to quantize the weight at the first inference epoch. I suggest avoiding the unnecessary computation in some way, such as adding a specific tag here.
I am afraid we must quantize the weight in each iteration, because we cannot update the weight in place; if we use DataParallel, the update would be lost after each iteration.
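A rough sketch of the constraint being described, with hypothetical helper names (old_weight as the full-precision parameter is taken from the surrounding diff context; the rest is illustrative, not the NNI implementation):

```python
import torch

def fake_quantize(t, scale, zero_point, qmin, qmax):
    q = torch.clamp(torch.round(t / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

def apply_quantized_weight(module, scale, zero_point, qmin, qmax):
    # The quantized copy is reassigned (not copied in place), so with
    # torch.nn.DataParallel it only lives inside the per-device replica for a
    # single forward pass and therefore must be recomputed every iteration.
    module.weight = fake_quantize(module.old_weight, scale, zero_point, qmin, qmax)
```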
module.tracked_min_input.copy_(tracked_min_input)
module.tracked_max_input.copy_(tracked_max_input)

if quant_start_step > int(self.bound_model.steps):
Why delete the original logic that helps initialize tracked_min_input and tracked_max_input here? According to previous test results, keeping it performs better or converges faster than removing it.
I see; it will be kept whether or not quant_start_step > int(self.bound_model.steps).
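For clarity, a sketch of the control flow being agreed on here: the EMA-tracked range is always updated, and only the quantization step is skipped before quant_start_step. update_ema matches the helper visible in the diff below; the other names are illustrative.

```python
import torch

def update_ema(biased_ema, value, decay):
    return biased_ema * decay + (1 - decay) * value

def track_input_range(module, inputs, steps, quant_start_step, ema_decay):
    current_min, current_max = torch.min(inputs), torch.max(inputs)
    module.tracked_min_input.copy_(update_ema(module.tracked_min_input, current_min, ema_decay))
    module.tracked_max_input.copy_(update_ema(module.tracked_max_input, current_max, ema_decay))
    if quant_start_step > steps:
        # The tracked range keeps warming up, but quantization has not started yet.
        return inputs
    # ... derive scale/zero_point from the tracked range and quantize ...
    return inputs
```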
if quant_start_step > int(self.bound_model.steps):
    return inputs

tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
It seems lines 595-598 repeat what lines 587-590 do.
have removed them
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
        if isinstance(model, torch.nn.DataParallel):
            model = model.module
        model_copied = copy.deepcopy(model)
Because we use the quantizer-wrapped model to determine which layers' shapes should be recorded, and some hooks would be registered in the model during shape recording, I think it is better to do the shape recording on a copied model.
Get it.
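A hedged sketch of that idea: run the dummy input through a deep copy so the forward hooks used for shape recording never end up registered on the model that gets wrapped. The function and dictionary names are illustrative, not the exact NNI implementation.

```python
import copy
import torch

def record_shapes(model, dummy_input, layer_names):
    model_copied = copy.deepcopy(model)
    all_shapes, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Record (input shape, output shape) for this layer.
            all_shapes[name] = (list(inputs[0].shape), list(output.shape))
        return hook

    for name, module in model_copied.named_modules():
        if name in layer_names:
            handles.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model_copied(dummy_input)
    for h in handles:
        h.remove()  # hooks only ever existed on the throwaway copy
    return all_shapes
```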
@@ -793,25 +799,54 @@ def find_conv_bn_patterns(self, model, dummy_input):
            if successor.op_type == 'BatchNorm2d':
                self.conv_bn_patterns[node_group.name] = successor.name

    def step_with_optimizer(self):
        pass

    def record_shape(self, model, dummy_input):
It seems to be a universal function that is not specific to the quantizer. Maybe we should put it somewhere else, like utils.py?
There are compressor-specific util functions (like get_modules_to_compress), and currently this one is only used by quantization. Making it a method of the quantizer may be a good choice. :)
Get it.
bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
qmin, qmax = 0, (1 << bits) - 1

scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
It is strange that we get scale_name and zero_point_name this way while we register them in the module directly using names like 'weight_scale' and 'weight_zero_point'. So what is the purpose of type_to_scale_zero_point_name()?
Different quant_types correspond to different scale & zero_point names, e.g. weight_scale, input_scale and output_scale. In order to get them, we must map the quantization types to the scale & zero_point names; type_to_scale_zero_point_name is just for code simplicity. Do you have better ideas about this?
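A minimal sketch of the mapping described in the reply; the enum values below are assumptions for illustration, not necessarily NNI's exact definitions.

```python
from enum import Enum

class QuantType(Enum):
    INPUT = 'input'
    WEIGHT = 'weight'
    OUTPUT = 'output'

    def type_to_scale_zero_point_name(self):
        # Each quantization type resolves to the buffer names registered on
        # the wrapped module.
        return f'{self.value}_scale', f'{self.value}_zero_point'

# e.g. QuantType.WEIGHT.type_to_scale_zero_point_name()
# returns ('weight_scale', 'weight_zero_point')
```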
        self._fields[k] = v

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
Is this just for _fields?
yes
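A self-contained sketch of that pattern, assuming a settings class whose only underscore-prefixed attribute is the private _fields dict; the class name and field values here are illustrative.

```python
from typing import Any

class TensorQuantSettingSketch:
    def __init__(self):
        # Goes through __setattr__ below; the leading underscore makes it a
        # real instance attribute instead of a _fields entry.
        self._fields = {}

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            # Every public attribute is routed into the _fields dict.
            self._fields[name] = val

    def __getattr__(self, name: str) -> Any:
        try:
            return self._fields[name]
        except KeyError:
            raise AttributeError(name)

s = TensorQuantSettingSketch()
s.quant_scheme = 'per_tensor_affine'
assert s._fields['quant_scheme'] == 'per_tensor_affine'
```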
please fix the conflict
4628fce to 05ab608
have resolved the conflict
self.bound_model.to(device)

def _del_simulated_attr(self, module):
    """
    delete redundant parameters in quantize module
    """
    del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
                     'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
                     'output_bits', 'BN_FOLD_TAG', 'input_bits']
    'tracked_min_input', 'tracked_max_input', 'weight_bits', 'output_bits', 'BN_FOLD_TAG',
Maybe we can delete 'input_bits', 'weight_bits' and 'output_bits' here, since they have been removed from the module parameters.
done
weight_bits = get_bits_length(config, 'weight')
layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme)
module.register_buffer('weight_scale', torch.zeros(quant_shape))
Why keep 'scale' and 'zero_point' in the module parameters while removing bits? To some degree, they are at the same level and should be kept in the same place.
scale and zero_point are tensors that will be used in the autograd graph while bits will not, so I put these "non-autograd" attributes in the layer/tensor settings.
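A hedged sketch of how the quant scheme could determine the scale/zero_point buffer shape used above; the helper name echoes the diff, but the simplified signature, channel-axis convention and scheme strings are assumptions.

```python
def get_quant_shape(shape, quant_scheme):
    # Per-channel schemes keep one scale/zero_point per output channel
    # (assumed to be axis 0, as for Conv2d/Linear weights); per-tensor schemes
    # keep a single scalar.
    if quant_scheme in ('per_channel_affine', 'per_channel_symmetric'):
        return [shape[0]] + [1] * (len(shape) - 1)
    return [1]

# e.g. a Conv2d weight of shape [64, 3, 3, 3]:
#   per-tensor  -> [1]
#   per-channel -> [64, 1, 1, 1]
```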
@@ -347,7 +348,8 @@ def test_torch_QAT_quantizer(self):
    model.relu = torch.nn.ReLU()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
    dummy = torch.randn(1, 1, 28, 28)
    quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
We need to use a config_list like [{..., 'quant_scheme': ..., 'quant_dtype': ...}] here to test dtype and scheme.
I added a stand-alone UT to test scheme and dtype.
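For reference, a config_list of that shape might look like the sketch below; the specific quant_dtype and quant_scheme values are illustrative assumptions and should be checked against the QAT quantizer documentation.

```python
# Illustrative config_list for exercising dtype/scheme customization in a test.
config_list = [{
    'quant_types': ['weight', 'input', 'output'],
    'quant_bits': {'weight': 8, 'input': 8, 'output': 8},
    'quant_dtype': 'int',                       # assumed value
    'quant_scheme': 'per_channel_symmetric',    # assumed value
    'op_types': ['Conv2d', 'Linear']
}]
```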

# Just show each attribute's name, no practical effect
class QuantConfigLiteral(str, _QuantLiteralEnum):
It seems not to be used; why keep it?
To tell the developer the name of each attribute. :)
    qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
elif dtype == QuantDtype.UINT:
    qmin, qmax = 0, 2 ** bits - 1
else:
I suggest raising a TypeError here.
done
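Putting the diff and the suggestion together, the range calculation could read roughly as follows; the function name and the stand-in QuantDtype enum are assumptions for illustration.

```python
from enum import Enum

class QuantDtype(Enum):
    INT = 'int'
    UINT = 'uint'

def calculate_qmin_qmax(bits, dtype):
    if dtype == QuantDtype.INT:
        # Symmetric signed range, e.g. 8 bits -> [-127, 127].
        return -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
    elif dtype == QuantDtype.UINT:
        # Unsigned range, e.g. 8 bits -> [0, 255].
        return 0, 2 ** bits - 1
    else:
        # The suggested failure mode for unknown dtypes.
        raise TypeError("Unsupported quantization dtype: {}".format(dtype))
```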

return scale, nudged_zero_point
zero_point = torch.clamp(zero_point, qmin, qmax)
If the scheme is affine, should we clamp zero_point between qmin and qmax?
Yes, following the code in the PyTorch repo.
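A sketch of an affine scale/zero_point computation with that clamp, following the common fake-quantization recipe rather than NNI's or PyTorch's exact code:

```python
import torch

def compute_affine_qparams(rmin, rmax, qmin, qmax):
    # Make sure zero is representable, then derive the scale from the range.
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    scale = (rmax - rmin) / float(qmax - qmin)
    scale = torch.clamp(scale, min=1e-8)  # guard against a degenerate range
    zero_point = qmin - torch.round(rmin / scale)
    # Clamp so the zero point stays inside the representable integer range,
    # as in the affine scheme discussed above.
    zero_point = torch.clamp(zero_point, qmin, qmax)
    return scale, zero_point
```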
# The weight cannot be modified in place, so when using torch.nn.DataParallel, this update
# will be lost after each forward pass. However, the update takes effect on each
# replicated module during each forward pass, which makes the quantized weight
# be used correctly.
Why add this comment? What kind of modification counts as an in-place modification?
Something like wrapper.module.weight.copy_(weight); this is not allowed in PyTorch.

# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
ema_decay = layer_quant_setting.ema_decay
Is ema_decay only for the input and output tensors?
yes
nni/common/version.py
Outdated
@@ -0,0 +1,3 @@
import torch

TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
Does this require torch to be installed, even when only using HPO?
Added some guard logic.
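A hedged sketch of the kind of guard being referred to, so importing this module does not fail when torch is absent (for example, HPO-only installs); the fallback value is an assumption, not necessarily what NNI uses.

```python
try:
    import torch
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except (ImportError, ValueError):
    # torch is not installed, or its version string is non-numeric;
    # callers should treat None as "version unknown".
    TORCH_VERSION = None
```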
test/ut/sdk/test_compressor_torch.py
Outdated
scale = getattr(module, scale_name)
zero_point = getattr(module, zero_point_name)
self.assertTrue(list(scale.shape) == quant_shape)
self.assertTrue(list(zero_point.shape) == quant_shape)
This test is weak; it would be better to add more tests, for example, testing whether a tensor quantized with different dtype/scheme combinations has the expected values.
Added some value checks for scales & zero_points
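A hypothetical sketch of the kind of value check being asked for, quantizing a small tensor with known parameters and asserting the exact integer levels; this is illustrative, not the actual NNI unit test.

```python
import unittest
import torch

class TestQuantValues(unittest.TestCase):
    def test_affine_uint8_values(self):
        x = torch.tensor([0.0, 0.5, 1.0])
        scale, zero_point, qmin, qmax = 1.0 / 255, 0.0, 0, 255
        # Fake-quantize to integer levels and compare with hand-computed values.
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        self.assertTrue(torch.equal(q, torch.tensor([0.0, 128.0, 255.0])))
```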
# TODO: may relax this limitation?
assert name in self.all_shapes, "Could not find shapes for layer {}".format(name)
input_shape, output_shape = self.all_shapes[name]
layer_quant_setting = LayerQuantSetting(config)
We should properly support other ops, such as Linear, or we can report a warning message for unsupported ops.
QAT_Quantizer now only quantizes Conv2d/Linear/ReLU/ReLU6, and when quantizing the input/output of a Linear layer in per-channel style, a rank==2 check will be executed.
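An illustrative sketch of those two safeguards, warning on unsupported layer types and enforcing rank-2 activations for per-channel quantization of Linear inputs/outputs; the constant and checks are assumptions about how this could look, not the exact NNI code.

```python
import logging

logger = logging.getLogger(__name__)

SUPPORTED_OP_TYPES = {'Conv2d', 'Linear', 'ReLU', 'ReLU6'}

def check_layer_supported(op_type, tensor_shape, per_channel):
    if op_type not in SUPPORTED_OP_TYPES:
        # Warn and skip rather than silently ignoring the layer.
        logger.warning("op type %s is not supported by QAT_Quantizer, skipping", op_type)
        return False
    if per_channel and op_type == 'Linear' and len(tensor_shape) != 2:
        raise ValueError("per-channel quantization of Linear input/output expects a rank-2 tensor")
    return True
```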
@chenbohua3 this PR looks good; please update the docs (e.g., qat_quantizer) accordingly.