@@ -122,13 +122,49 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module:
122
122
else : # pragma: no cover
123
123
model_class ._keys_to_ignore_on_load_missing .extend (missing_keys_to_ignore_on_load )
124
124
125
+ if not os .path .isdir (model_name_or_path ) and not os .path .isfile (model_name_or_path ): # pragma: no cover
126
+ from transformers .utils import cached_file
127
+
128
+ try :
129
+ # Load from URL or cache if already cached
130
+ resolved_weights_file = cached_file (
131
+ model_name_or_path ,
132
+ filename = WEIGHTS_NAME ,
133
+ cache_dir = cache_dir ,
134
+ force_download = force_download ,
135
+ resume_download = resume_download ,
136
+ use_auth_token = use_auth_token ,
137
+ )
138
+ except EnvironmentError as err : # pragma: no cover
139
+ logger .error (err )
140
+ msg = (
141
+ f"Can't load weights for '{ model_name_or_path } '. Make sure that:\n \n "
142
+ f"- '{ model_name_or_path } ' is a correct model identifier "
143
+ f"listed on 'https://huggingface.co/models'\n (make sure "
144
+ f"'{ model_name_or_path } ' is not a path to a local directory with "
145
+ f"something else, in that case)\n \n - or '{ model_name_or_path } ' is "
146
+ f"the correct path to a directory containing a file "
147
+ f"named one of { WEIGHTS_NAME } \n \n "
148
+ )
149
+ if revision is not None :
150
+ msg += (
151
+ f"- or '{ revision } ' is a valid git identifier "
152
+ f"(branch name, a tag name, or a commit id) that "
153
+ f"exists for this model name as listed on its model "
154
+ f"page on 'https://huggingface.co/models'\n \n "
155
+ )
156
+ raise EnvironmentError (msg )
157
+ else :
158
+ resolved_weights_file = os .path .join (model_name_or_path , WEIGHTS_NAME )
159
+ state_dict = torch .load (resolved_weights_file , {})
125
160
model = model_class .from_pretrained (
126
161
model_name_or_path ,
127
162
cache_dir = cache_dir ,
128
163
force_download = force_download ,
129
164
resume_download = resume_download ,
130
165
use_auth_token = use_auth_token ,
131
166
revision = revision ,
167
+ state_dict = state_dict ,
132
168
** kwargs ,
133
169
)
134
170
0 commit comments