Skip to content

Commit 6133f4e

Browse files
authored
adapt transformers 4.37 loading (#1606)
Signed-off-by: changwangss <chang1.wang@intel.com> Signed-off-by: chensuyue <suyue.chen@intel.com>
1 parent 3882e9c commit 6133f4e

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

.azure-pipelines/scripts/ut/env_setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ fi
8484
# install special test env requirements
8585
# common deps
8686
pip install cmake
87-
pip install transformers==4.36.2
87+
pip install transformers
8888

8989
if [[ $(echo "${test_case}" | grep -c "others") != 0 ]];then
9090
pip install tf_slim xgboost accelerate==0.21.0 peft

.azure-pipelines/scripts/ut/run_itrex.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ bash /intel-extension-for-transformers/.github/workflows/script/install_binary.s
1212

1313
# prepare test env
1414
# tmp install transformers for incompatible issue
15-
pip install transformers==4.36.2
15+
pip install transformers
1616
pip install -r /intel-extension-for-transformers/tests/requirements.txt
1717
LOG_DIR=/neural-compressor/log_dir
1818
mkdir -p ${LOG_DIR}

neural_compressor/utils/load_huggingface.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,49 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
122122
else: # pragma: no cover
123123
model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
124124

125+
if not os.path.isdir(model_name_or_path) and not os.path.isfile(model_name_or_path): # pragma: no cover
126+
from transformers.utils import cached_file
127+
128+
try:
129+
# Load from URL or cache if already cached
130+
resolved_weights_file = cached_file(
131+
model_name_or_path,
132+
filename=WEIGHTS_NAME,
133+
cache_dir=cache_dir,
134+
force_download=force_download,
135+
resume_download=resume_download,
136+
use_auth_token=use_auth_token,
137+
)
138+
except EnvironmentError as err: # pragma: no cover
139+
logger.error(err)
140+
msg = (
141+
f"Can't load weights for '{model_name_or_path}'. Make sure that:\n\n"
142+
f"- '{model_name_or_path}' is a correct model identifier "
143+
f"listed on 'https://huggingface.co/models'\n (make sure "
144+
f"'{model_name_or_path}' is not a path to a local directory with "
145+
f"something else, in that case)\n\n- or '{model_name_or_path}' is "
146+
f"the correct path to a directory containing a file "
147+
f"named one of {WEIGHTS_NAME}\n\n"
148+
)
149+
if revision is not None:
150+
msg += (
151+
f"- or '{revision}' is a valid git identifier "
152+
f"(branch name, a tag name, or a commit id) that "
153+
f"exists for this model name as listed on its model "
154+
f"page on 'https://huggingface.co/models'\n\n"
155+
)
156+
raise EnvironmentError(msg)
157+
else:
158+
resolved_weights_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
159+
state_dict = torch.load(resolved_weights_file, {})
125160
model = model_class.from_pretrained(
126161
model_name_or_path,
127162
cache_dir=cache_dir,
128163
force_download=force_download,
129164
resume_download=resume_download,
130165
use_auth_token=use_auth_token,
131166
revision=revision,
167+
state_dict=state_dict,
132168
**kwargs,
133169
)
134170

0 commit comments

Comments
 (0)