Skip to content

Commit

Permalink
Modify quantizer type based on encoding provided when strict=False (#…
Browse files Browse the repository at this point in the history
…3141)

Signed-off-by: Sai Chaitanya Gajula <quic_gsaichai@quicinc.com>
  • Loading branch information
quic-gsaichai authored Jul 3, 2024
1 parent c4857b7 commit 283ee26
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 7 deletions.
28 changes: 22 additions & 6 deletions TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def import_param_encodings(self,
...
}
"""
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches, too-many-statements
for param_name, quantizer in self.param_quantizers.items():
if quantizer._is_encoding_frozen: # pylint: disable=protected-access
continue
Expand All @@ -523,11 +523,27 @@ def import_param_encodings(self,
if quantizer.enabled:
# pylint: disable=protected-access
if isinstance(quantizer, StaticGridPerChannelQuantizer) and len(quantizer._cppOp) != len(encoding):
raise ValueError(f"Invalid PerChannel encodings for {param_name}, the quantizer is a "
f"PerChannelQuantizer. To avoid this, disable per_channel_quantization")
if isinstance(quantizer, StaticGridPerTensorQuantizer) and len(encoding) != 1:
raise ValueError(f"Invalid PerTensor encodings for {param_name}, the quantizer is a "
f"PerTensorQuantizer. To avoid this, enable per_channel_quantization")
assert len(encoding) == 1, (f'Number of Per Channel encodings provided ({len(encoding)}) is '
f'not same as number of channels ({len(quantizer._cppOp)})')
if strict:
raise ValueError(f"Invalid PerChannel encodings for {param_name}, the quantizer is a "
f"PerChannelQuantizer. To avoid this, disable per_channel_quantization")
# Modifying PerChannel quantizer to PerTensor
_logger.warning('Replacing PerChannel Quantizer with PerTensor based on encoding provided')
quantizer = utils.get_per_tensor_quantizer_from_per_channel(quantizer)
self.param_quantizers[param_name] = quantizer
elif isinstance(quantizer, StaticGridPerTensorQuantizer) and len(encoding) != 1:
if strict:
raise ValueError(f"Invalid PerTensor encodings for {param_name}, the quantizer is a "
f"PerTensorQuantizer. To avoid this, enable per_channel_quantization")
# Modifying PerTensor quantizer to PerChannel
_logger.warning('Replacing PerTensor Quantizer with PerChannel based on encoding provided..')
quantizer = utils.get_per_channel_quantizer_from_per_tensor(quantizer, self.get_original_module())
assert len(quantizer._cppOp) == len(encoding), (f'Number of per channel encodings ({len(encoding)})'
f' should much with number of output '
f'channels ({len(quantizer._cppOp)})')
self.param_quantizers[param_name] = quantizer

if encoding[0]['dtype'] == 'int':
# Validate and set symmetric flags before computing partial encodings
validate_is_symmetric_flag(quantizer, encoding[0], strict)
Expand Down
40 changes: 39 additions & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from aimet_common.utils import profile as _profile
import aimet_common.libpymo as libpymo
from aimet_torch import elementwise_ops
from aimet_torch.tensor_quantizer import TensorQuantizer
from aimet_torch.tensor_quantizer import TensorQuantizer, StaticGridPerChannelQuantizer, StaticGridPerTensorQuantizer

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)

Expand Down Expand Up @@ -1215,6 +1215,44 @@ def _validate_is_symmetric_flag(quantizer: TensorQuantizer, encoding_dict: Dict,
raise AttributeError("Provided encoding doesn't have 'is_symmetric' flag")


def get_per_channel_quantizer_from_per_tensor(quantizer: TensorQuantizer, original_module: torch.nn.Module):
""" Get PerChannel Quantizer with same settings as given PerTensor Quantizer """
channel_axis = 0
if isinstance(original_module, (torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d)):
if len(original_module.weight.shape) > 1:
channel_axis = 1

num_channels = original_module.weight.shape[channel_axis]
use_strict_symmetric = quantizer.use_strict_symmetric
use_unsigned_symmetric = quantizer.use_unsigned_symmetric
quantizer = StaticGridPerChannelQuantizer(quantizer.bitwidth, quantizer.round_mode,
quantizer.quant_scheme,
quantizer.use_symmetric_encodings,
num_channels=num_channels,
enabled_by_default=quantizer.enabled,
ch_axis=channel_axis,
data_type=quantizer.data_type)
quantizer.use_strict_symmetric = use_strict_symmetric
quantizer.use_unsigned_symmetric = use_unsigned_symmetric
return quantizer


def get_per_tensor_quantizer_from_per_channel(quantizer: TensorQuantizer):
""" Get PerTensor Quantizer with same settings as given PerChannel Quantizer """
use_strict_symmetric = quantizer.use_strict_symmetric
use_unsigned_symmetric = quantizer.use_unsigned_symmetric
quantizer = StaticGridPerTensorQuantizer(quantizer.bitwidth, quantizer.round_mode,
quantizer.quant_scheme,
quantizer.use_symmetric_encodings,
enabled_by_default=quantizer.enabled,
data_type=quantizer.data_type)
quantizer.use_strict_symmetric = use_strict_symmetric
quantizer.use_unsigned_symmetric = use_unsigned_symmetric
return quantizer


def validate_is_symmetric_flag(quantizer: TensorQuantizer, encoding_dict: Dict, strict: bool = True):
"""
Validate 'is_symmetric' flag from encoding_dict with quantizer.use_symmetric_encodings and set the later accordingly
Expand Down
69 changes: 69 additions & 0 deletions TrainingExtensions/torch/test/python/test_quantsim_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,3 +2224,72 @@ def test_load_and_freeze_with_partial_encodings(self, sample_enc):
assert sim.model.conv1.param_quantizers['weight'].use_symmetric_encodings
else:
assert not sim.model.conv1.param_quantizers['weight'].use_symmetric_encodings

def test_load_encodings_to_allow_modifying_quantizer_type(self):
""" Test load encodings API to allow modifying quantizer type based on encoding """
model = test_models.TinyModelWithNoMathInvariantOps()
dummy_input = torch.randn([1, 3, 24, 24])

sample_act_enc = {"min": -4, "max": 4, "bitwidth": 8, "dtype": "int", "is_symmetric": "False"}
sample_param_enc = {"min": -4, "max": 4, "bitwidth": 8, "dtype": "int", "is_symmetric": "True"}

encodings = {"activation_encodings": {"conv1": {"input": {"0": sample_act_enc}},
"mul1": {"output": {"0": sample_act_enc}}},
"param_encodings": {}}

pcq_config = {
"defaults":{
"ops":{
"is_output_quantized": "True"
},
"params":{
"is_quantized": "True",
"is_symmetric": "True"
},
"strict_symmetric": "False",
"per_channel_quantization": "True"
},
"params": {},
"op_type": {},
"model_input":{
"is_input_quantized": "True"
},
"supergroups": [],
"model_output": {}
}

with tempfile.TemporaryDirectory() as tmp_dir:
pcq_config_file = os.path.join(tmp_dir, 'pcq_quantsim_config.json')
with open(pcq_config_file, 'w') as f:
json.dump(pcq_config, f)

for config_file in [None, pcq_config_file]:
if config_file is None:
# PTQ to PCQ case, initial quantizer is PTQ, but the encodings are of PCQ
encodings['param_encodings']['conv1.weight'] = [sample_param_enc for i in range(16)]
else:
# PCQ to PTQ case, initial quantizer is PCQ, but the encodings are of PTQ
encodings['param_encodings']['conv1.weight'] = [sample_param_enc]

with tempfile.TemporaryDirectory() as tmp_dir:
with open(os.path.join(tmp_dir, 'replace_quantizer_with_enc.json'), 'w') as f:
json.dump(encodings, f)

sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf, config_file=config_file)

# Checking Quantizer type before loading encodings to Quantsim
if config_file is None:
assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerTensorQuantizer)
else:
assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerChannelQuantizer)

sim.load_and_freeze_encodings(os.path.join(tmp_dir, 'replace_quantizer_with_enc.json'),
ignore_when_quantizer_disabled=True)

sim.compute_encodings(lambda m, _: m(dummy_input), None)

# Checking whether the quantizer is modifed to required type after laoding encodings
if config_file is None:
assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerChannelQuantizer)
else:
assert isinstance(sim.model.conv1.param_quantizers['weight'], StaticGridPerTensorQuantizer)

0 comments on commit 283ee26

Please sign in to comment.