diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index a892335ee8..f9039fa790 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -806,7 +806,6 @@ def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.linear_gelu = LinearGelu(module.dense)
-        del self.__dict__["_modules"]["dense"]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_gelu(hidden_states)
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index cb60541b1d..7928492b32 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -54,6 +54,7 @@
 from ...exporters.ipex.cache_utils import IPEXPagedCache
 from ...exporters.ipex.model_config import ipex_onnx_config
 from ...exporters.ipex.model_patcher import (
+    _IPEX_EXPORTED_GENERATION_TASKS,
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _patch_model,
 )
@@ -73,7 +74,7 @@ def _is_patched_with_ipex(model, task, use_cache: bool = True):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         return False
 
-    if not use_cache:
+    if not use_cache and task in _IPEX_EXPORTED_GENERATION_TASKS:
         return False
 
     return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
@@ -299,6 +300,10 @@ def model_dtype(self):
         )
         return self._dtype
 
+    @property
+    def add_patch(self) -> bool:
+        return self._add_patch
+
     def to(self, device: Union[torch.device, str]):
         self._device = device if isinstance(device, torch.device) else torch.device(device)
         self.model.to(self._device)
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 3f366d3faf..459d1c9b15 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -74,12 +74,15 @@ class IPEXModelTest(unittest.TestCase):
         "squeezebert",
         "xlm",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("bert",)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
+        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
+            self.assertTrue(ipex_model.add_patch)
         device = ipex_model.device
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device)
@@ -317,6 +320,8 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         if IS_XPU:
             dtype = torch.float16
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache, torch_dtype=dtype)
+        if use_cache:
+            self.assertTrue(model.add_patch)
         device = model.device
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
         self.assertEqual(model.use_cache, use_cache)
@@ -433,12 +438,15 @@ class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase):
         "resnet",
         "vit",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("vit",)
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
+        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
+            self.assertTrue(ipex_model.add_patch)
         device = ipex_model.device
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device)
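
For reviewers, a minimal usage sketch of the `add_patch` property introduced above; this is not part of the patch, and the checkpoint name is illustrative (any supported causal-LM model id would do):

# Sketch assuming optimum-intel with this change applied; "gpt2" is a
# placeholder model id, not one used by the tests in this diff.
from optimum.intel import IPEXModelForCausalLM

# With use_cache=True, a supported model type should be patched at load
# time, so the new `add_patch` property is expected to report True.
model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
print(model.add_patch)  # True for patched model types, False otherwise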