diff --git a/d2go/quantization/modeling.py b/d2go/quantization/modeling.py index c186db26..37272eef 100644 --- a/d2go/quantization/modeling.py +++ b/d2go/quantization/modeling.py @@ -5,7 +5,7 @@ import copy import logging import math -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import detectron2.utils.comm as comm import torch @@ -471,6 +471,7 @@ def setup_qat_model( enable_fake_quant: bool = False, enable_observer: bool = False, enable_learnable_observer: bool = False, + example_input: Optional[Any] = None, ): assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in [ "default", @@ -490,7 +491,7 @@ def setup_qat_model( model_fp32_state_dict = model_fp32.state_dict() # prepare model for qat - model = prepare_fake_quant_model(cfg, model_fp32, True) + model = prepare_fake_quant_model(cfg, model_fp32, True, example_input=example_input) # make sure the proper qconfig are used in the model learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model) @@ -554,9 +555,23 @@ def _setup_non_qat_to_qat_state_dict_map( ) ) - assert len(new_state_dict_non_observer_keys_not_ignored) == len( + if not len(new_state_dict_non_observer_keys_not_ignored) == len( original_state_dict_shapes - ), f"keys in state dict of original and new qat model {len(new_state_dict_non_observer_keys_not_ignored)} vs {len(original_state_dict_shapes)}" + ): + a = set(new_state_dict_non_observer_keys_not_ignored) + b = set(original_state_dict_shapes.keys()) + a_diff_b = a.difference(b) + b_diff_a = b.difference(a) + logger.info("unique keys in qat model state dict") + for key in a_diff_b: + logger.info(f"{key}") + logger.info("unique keys in original model state dict") + for key in b_diff_a: + logger.info(f"{key}") + + raise RuntimeError( + f"an inconsistent number of keys in state dict of new qat and original model: {len(a)} vs {len(b)}" + ) if is_eager_mode: for n_k, o_k in zip(