
Commit

fix bert and vit patch (#1022)
* fix bert and vit patch
* fix vit and bert save


Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng authored Nov 25, 2024
1 parent 388265f commit ad9b795
Showing 3 changed files with 14 additions and 2 deletions.
optimum/exporters/ipex/modeling_utils.py (1 change: 0 additions & 1 deletion)
@@ -806,7 +806,6 @@ def __init__(self, module, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.linear_gelu = LinearGelu(module.dense)
-        del self.__dict__["_modules"]["dense"]

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_gelu(hidden_states)
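
Note: the removed line had deleted the original `dense` submodule right after fusing it into `LinearGelu`. Deleting it also removed its parameters from the module's `state_dict()`, which is what broke saving for BERT and ViT. A minimal sketch of the failure mode, using stand-in classes rather than the library's actual ones:

import torch
from torch import nn

class FusedIntermediate(nn.Module):
    """Stand-in for the patched intermediate block (illustrative only)."""

    def __init__(self, module):
        super().__init__()
        # Keep the original layer registered so "dense.weight" stays in
        # state_dict(); the removed `del self.__dict__["_modules"]["dense"]`
        # would have erased it and made save_pretrained() drop the weights.
        self.dense = module.dense

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return nn.functional.gelu(self.dense(hidden_states))

original = nn.Module()
original.dense = nn.Linear(4, 4)
patched = FusedIntermediate(original)
assert "dense.weight" in patched.state_dict()  # weights survive saving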
optimum/intel/ipex/modeling_base.py (7 changes: 6 additions & 1 deletion)
@@ -54,6 +54,7 @@
 from ...exporters.ipex.cache_utils import IPEXPagedCache
 from ...exporters.ipex.model_config import ipex_onnx_config
 from ...exporters.ipex.model_patcher import (
+    _IPEX_EXPORTED_GENERATION_TASKS,
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _patch_model,
 )
@@ -73,7 +74,7 @@
 def _is_patched_with_ipex(model, task, use_cache: bool = True):
     if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
         return False
-    if not use_cache:
+    if not use_cache and task in _IPEX_EXPORTED_GENERATION_TASKS:
         return False
     return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
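
With this change, `use_cache=False` only blocks patching for generation tasks; encoder-only tasks such as fill-mask or image classification are patched regardless of the cache flag. A minimal sketch of the new gating, with illustrative stand-in values for the two constants:

# Stand-in values; the library's real constants differ.
_IPEX_EXPORTED_GENERATION_TASKS = {"text-generation"}
_IPEX_SUPPORT_MODEL_TYPES = {"bert", "vit", "llama"}

def is_patched(model_type: str, task: str, use_cache: bool = True) -> bool:
    # use_cache is only relevant when the task actually generates tokens.
    if not use_cache and task in _IPEX_EXPORTED_GENERATION_TASKS:
        return False
    return model_type in _IPEX_SUPPORT_MODEL_TYPES

assert is_patched("bert", "fill-mask", use_cache=False)              # now patched
assert not is_patched("llama", "text-generation", use_cache=False)   # still skipped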

@@ -299,6 +300,10 @@ def model_dtype(self):
         )
         return self._dtype

+    @property
+    def add_patch(self) -> bool:
+        return self._add_patch
+
     def to(self, device: Union[torch.device, str]):
         self._device = device if isinstance(device, torch.device) else torch.device(device)
         self.model.to(self._device)
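
The new read-only `add_patch` property exposes whether the loaded model was patched; the tests below assert on it. A short usage sketch, assuming `IPEXModelForCausalLM` is importable from `optimum.intel` as in the test suite (the checkpoint name is a placeholder):

from optimum.intel import IPEXModelForCausalLM

model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
print(model.add_patch)  # True only when the model type/task combination was patched
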
tests/ipex/test_modeling.py (8 changes: 8 additions & 0 deletions)
@@ -74,12 +74,15 @@ class IPEXModelTest(unittest.TestCase):
         "squeezebert",
         "xlm",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("bert",)

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
+        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
+            self.assertTrue(ipex_model.add_patch)
         device = ipex_model.device
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device)
@@ -317,6 +320,8 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
         if IS_XPU:
             dtype = torch.float16
         model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache, torch_dtype=dtype)
+        if use_cache:
+            self.assertTrue(model.add_patch)
         device = model.device
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
         self.assertEqual(model.use_cache, use_cache)
@@ -433,12 +438,15 @@ class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase):
         "resnet",
         "vit",
     )
+    IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("vit",)

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_compare_to_transformers(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         set_seed(SEED)
         ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True)
+        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
+            self.assertTrue(ipex_model.add_patch)
         device = ipex_model.device
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id).to(device)
