diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index a8062aff3ed..0d64c5fd74c 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -2765,14 +2765,20 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False): q_model._model = ipex.quantization.convert(model._model, inplace=inplace) try: if isinstance(self.example_inputs, dict): - q_model._model = torch.jit.trace(q_model._model, example_kwarg_inputs=self.example_inputs) + q_model._model = torch.jit.trace( + q_model._model, + example_kwarg_inputs=self.example_inputs, + ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs) q_model._model = torch.jit.freeze(q_model._model.eval()) except: if isinstance(self.example_inputs, dict): q_model._model = torch.jit.trace( - q_model._model, example_kwarg_inputs=self.example_inputs, strict=False + q_model._model, + example_kwarg_inputs=self.example_inputs, + strict=False, + check_trace=False, ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False) @@ -2789,7 +2795,7 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False): except: if isinstance(self.example_inputs, dict): q_model._model = torch.jit.trace( - q_model._model, example_kwarg_inputs=self.example_inputs, strict=False + q_model._model, example_kwarg_inputs=self.example_inputs, strict=False, check_trace=False ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False) diff --git a/neural_compressor/adaptor/torch_utils/mixed_precision.py b/neural_compressor/adaptor/torch_utils/mixed_precision.py index 9539c1a5517..839ee7d88ba 100644 --- a/neural_compressor/adaptor/torch_utils/mixed_precision.py +++ b/neural_compressor/adaptor/torch_utils/mixed_precision.py @@ -34,7 +34,9 @@ def ipex_mixed_precision(model, example_inputs=None): try: mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs) except: - mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs, strict=False) + mp_model = torch.jit.trace( + mp_model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False + ) else: try: mp_model = torch.jit.trace(mp_model, example_inputs) diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index 65804468ec4..7b59b3ce3e5 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -1186,7 +1186,9 @@ def trace(self, model, dummy_input): dummy_input = move_input_to_device(dummy_input, "cpu") if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict): try: - traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(dummy_input), strict=False) + traced_model = torch.jit.trace( + model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False + ) traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) except Exception as e: logger.warning(e)