Explicit arguments in from_pretrained #24306

Merged · 19 commits · Jun 21, 2023
47 changes: 45 additions & 2 deletions src/transformers/configuration_utils.py
@@ -467,7 +467,43 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
def _set_token_in_kwargs(self, kwargs, token=None):
"""Temporary method to deal with `token` and `use_auth_token`.

This method avoids applying the same changes in all model config classes that overwrite `from_pretrained`.

Need to clean up `use_auth_token` in a follow-up PR.
"""
# Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
if token is None:
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token
Comment on lines +470 to +494

Collaborator (Author):

We have quite a few model classes (e.g. the CLIP-like family) whose config classes define their own `from_pretrained`.

This private `_set_token_in_kwargs` method makes it easier to deal with `token` and `use_auth_token` in those classes.

We will see what the best approach is in a follow-up PR, when we give those customized `from_pretrained` methods explicit arguments as well.

Collaborator (Author):

@sgugger Now is a good time, if you are motivated after 🦷.

Collaborator:

Ugh... Annoying! This fix works in the meantime.
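As a reading aid only (not part of the diff), here is a minimal sketch of how a config class that overrides `from_pretrained` could reuse the new helper; the subclass name, `model_type`, and docstring are hypothetical:

```python
import os
from typing import Optional, Union

from transformers import PretrainedConfig


class MyCompositeConfig(PretrainedConfig):
    """Hypothetical config class that, like CLIP, defines its own `from_pretrained`."""

    model_type = "my-composite"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        token: Optional[Union[str, bool]] = None,
        **kwargs,
    ) -> "PretrainedConfig":
        # Fold `token` / the deprecated `use_auth_token` into kwargs once, instead of
        # duplicating the deprecation handling in every overriding class.
        cls._set_token_in_kwargs(kwargs, token)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)
```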


@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
) -> "PretrainedConfig":
r"""
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

@@ -493,7 +529,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*):
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -544,6 +580,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
assert config.output_attentions == True
assert unused_kwargs == {"foo": False}
```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision

cls._set_token_in_kwargs(kwargs, token)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
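For context, a minimal usage sketch (not part of the diff) of the now-explicit arguments on the config side; the checkpoint name and cache path are illustrative:

```python
from transformers import BertConfig

# The common Hub arguments are now explicit parameters instead of being hidden in **kwargs.
config = BertConfig.from_pretrained(
    "bert-base-uncased",   # any config repo on the Hub works; chosen for illustration
    cache_dir="./hf-cache",
    force_download=False,
    local_files_only=False,
    token=None,            # a str token, or True to use the stored `huggingface-cli login` token
    revision="main",
)

# The deprecated `use_auth_token` kwarg still works but emits a FutureWarning,
# and combining it with `token` raises a ValueError.
```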
34 changes: 31 additions & 3 deletions src/transformers/feature_extraction_utils.py
@@ -19,6 +19,7 @@
import copy
import json
import os
import warnings
from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

@@ -255,8 +256,15 @@ def _set_processor_class(self, processor_class: str):

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> PreTrainedFeatureExtractor:
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r"""
Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
derived class of [`SequenceFeatureExtractor`].
@@ -285,7 +293,7 @@ def from_pretrained(
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*):
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -335,6 +343,26 @@ def from_pretrained(
assert feature_extractor.return_attention_mask is False
assert unused_kwargs == {"foo": False}
```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision

use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token

feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)

return cls.from_dict(feature_extractor_dict, **kwargs)
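Likewise for feature extractors, a short sketch (not part of the diff) of the deprecation path; the checkpoint name and the `hf_xxx` token are placeholders:

```python
from transformers import Wav2Vec2FeatureExtractor

# New style: pass `token` (None, a string, or True) explicitly.
extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base-960h",
    revision="main",
    token=None,
)

# Passing both the old and the new argument is rejected before any download happens.
try:
    Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-base-960h",
        token="hf_xxx",
        use_auth_token="hf_xxx",
    )
except ValueError as err:
    print(err)  # "`token` and `use_auth_token` are both specified. ..."
```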
24 changes: 18 additions & 6 deletions src/transformers/generation/configuration_utils.py
@@ -17,6 +17,7 @@
import copy
import json
import os
import warnings
from typing import Any, Dict, Optional, Union

from .. import __version__
@@ -382,6 +383,11 @@ def from_pretrained(
cls,
pretrained_model_name: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
) -> "GenerationConfig":
r"""
@@ -410,7 +416,7 @@ def from_pretrained(
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*):
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -470,18 +476,24 @@ def from_pretrained(
```"""
config_file_name = config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME

cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
@@ -509,7 +521,7 @@ def from_pretrained(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
use_auth_token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
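A corresponding sketch for `GenerationConfig.from_pretrained` (not part of the diff); it assumes the chosen repo hosts a `generation_config.json`, and the checkpoint name is illustrative:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig.from_pretrained(
    "gpt2",                                      # illustrative checkpoint
    config_file_name="generation_config.json",   # the default file name, spelled out here
    cache_dir="./hf-cache",
    force_download=False,
    local_files_only=False,
    token=None,
    revision="main",
)
print(generation_config)
```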
34 changes: 32 additions & 2 deletions src/transformers/image_processing_utils.py
@@ -16,6 +16,7 @@
import copy
import json
import os
import warnings
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
@@ -81,7 +82,16 @@ def _set_processor_class(self, processor_class: str):
self._processor_class = processor_class

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r"""
Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.

@@ -109,7 +119,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or `bool`, *optional*):
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -162,6 +172,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
assert image_processor.do_normalize is False
assert unused_kwargs == {"foo": False}
```"""
kwargs["cache_dir"] = cache_dir
kwargs["force_download"] = force_download
kwargs["local_files_only"] = local_files_only
kwargs["revision"] = revision

use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

if token is not None:
# change to `token` in a follow-up PR
kwargs["use_auth_token"] = token

image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)

return cls.from_dict(image_processor_dict, **kwargs)
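And the same pattern for image processors (a sketch only, not part of the diff), mirroring the `return_unused_kwargs` example in the docstring above; the checkpoint is illustrative:

```python
from transformers import CLIPImageProcessor

# Unknown kwargs are returned rather than applied, exactly as before.
image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-base-patch32",
    revision="main",
    token=None,
    foo=False,
    return_unused_kwargs=True,
)
assert unused_kwargs == {"foo": False}
```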
38 changes: 25 additions & 13 deletions src/transformers/modeling_flax_utils.py
@@ -18,9 +18,10 @@
import json
import os
import re
import warnings
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Set, Tuple, Union
from typing import Any, Dict, Optional, Set, Tuple, Union

import flax.linen as nn
import jax
@@ -485,6 +486,13 @@ def from_pretrained(
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
r"""
@@ -558,7 +566,7 @@ def from_pretrained(
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*):
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -603,23 +611,27 @@ def from_pretrained(
>>> config = BertConfig.from_json_file("./pt_model/config.json")
>>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
```"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token

if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
@@ -645,7 +657,7 @@ def from_pretrained(
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
@@ -715,7 +727,7 @@ def from_pretrained(
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"use_auth_token": token,
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
@@ -746,7 +758,7 @@ def from_pretrained(
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"use_auth_token": use_auth_token,
"use_auth_token": token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
@@ -797,7 +809,7 @@ def from_pretrained(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
use_auth_token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
@@ -986,7 +998,7 @@ def from_pretrained(
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
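Finally, a sketch of the Flax side (not part of the diff), mirroring the docstring example above; the checkpoint names and cache path are illustrative:

```python
from transformers import BertConfig, FlaxBertModel

config = BertConfig.from_pretrained("bert-base-cased", revision="main", token=None)
model = FlaxBertModel.from_pretrained(
    "bert-base-cased",
    config=config,                 # `config` is now an explicit keyword argument
    cache_dir="./hf-cache",
    ignore_mismatched_sizes=False,
    force_download=False,
    local_files_only=False,
    token=None,
    revision="main",
)
```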