Use new huggingface_hub tools for download models #18438

Merged: 16 commits, Aug 5, 2022
src/transformers/configuration_utils.py (27 additions, 84 deletions)

@@ -25,25 +25,9 @@
 
 from packaging import version
 
-from requests import HTTPError
-
 from . import __version__
 from .dynamic_module_utils import custom_object_save
-from .utils import (
-    CONFIG_NAME,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
-    EntryNotFoundError,
-    PushToHubMixin,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    cached_path,
-    copy_func,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
-    is_torch_available,
-    logging,
-)
+from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
 
 
 logger = logging.get_logger(__name__)
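Note: the effect of this import change is easiest to see at the call site. Below is a minimal sketch (not part of the diff) of the before/after resolution flow; `bert-base-uncased` is just an example repo id, and the removed two-step calls are shown as comments.

```python
from transformers.utils import cached_file

# Before this PR, resolving a file took two steps:
#   url = hf_bucket_url("bert-base-uncased", filename="config.json")
#   local_path = cached_path(url, cache_dir=None, force_download=False)
# After it, a single call maps (repo id, filename) to a local cached path,
# downloading only when the file is not already in the cache:
local_path = cached_file("bert-base-uncased", "config.json")
print(local_path)
```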
@@ -584,77 +568,36 @@ def _get_config_dict(
         if from_pipeline is not None:
             user_agent["using_pipeline"] = from_pipeline
 
-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
-            pretrained_model_name_or_path
-        ):
-            config_file = pretrained_model_name_or_path
+        if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+            resolved_config_file = pretrained_model_name_or_path
         else:
             configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
 
-            if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
-                config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
-            else:
-                config_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=configuration_file,
-                    revision=revision,
-                    subfolder=subfolder if len(subfolder) > 0 else None,
-                    mirror=None,
-                )
-
-        try:
-            # Load from URL or cache if already cached
-            resolved_config_file = cached_path(
-                config_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-            )
-
-        except RepositoryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
-                "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
-                "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
-                "`use_auth_token=True`."
-            )
-        except RevisionNotFoundError:
-            raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
-                f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
-                "available revisions."
-            )
-        except EntryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
-            )
-        except HTTPError as err:
-            raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
-            )
-        except ValueError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
-                f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
-                f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
-                " library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
-            )
-        except EnvironmentError:
-            raise EnvironmentError(
-                f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                f"containing a {configuration_file} file"
-            )
+            try:
+                # Load from URL or cache if already cached
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    configuration_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                )
+            except EnvironmentError:
+                raise
+            except Exception:
+                raise EnvironmentError(
+                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
+                    f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
+                    f" containing a {configuration_file} file"
+                )
 
         try:
             # Load config dict
@@ -664,10 +607,10 @@
                 f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
             )
 
-        if resolved_config_file == config_file:
-            logger.info(f"loading configuration file {config_file}")
+        if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
+            logger.info(f"loading configuration file {resolved_config_file}")
         else:
-            logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
+            logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
 
         return config_dict, kwargs
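Note: the per-exception handling (`RepositoryNotFoundError`, `RevisionNotFoundError`, `EntryNotFoundError`, `HTTPError`, `ValueError`) disappears here because those user-facing messages now live inside `cached_file`, which surfaces them as `EnvironmentError`s, as the new `except EnvironmentError: raise` pass-through suggests. A hedged sketch of the resulting caller-side contract, using a deliberately invalid repo id:

```python
from transformers.utils import cached_file

try:
    # "some-user/no-such-repo" is a hypothetical, nonexistent repo id.
    path = cached_file("some-user/no-such-repo", "config.json")
except EnvironmentError as err:
    # cached_file has already formatted the repo/revision/entry-specific message.
    print(f"Could not fetch config: {err}")
```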

src/transformers/modeling_utils.py (45 additions, 101 deletions)

@@ -31,7 +31,6 @@
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss
 
-from requests import HTTPError
 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 from transformers.utils.import_utils import is_sagemaker_mp_enabled
 
@@ -51,24 +50,18 @@
 from .utils import (
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     ContextManagers,
-    EntryNotFoundError,
     ModelOutput,
     PushToHubMixin,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    cached_path,
+    cached_file,
     copy_func,
     has_file,
-    hf_bucket_url,
     is_accelerate_available,
     is_offline_mode,
-    is_remote_url,
     logging,
     replace_return_docstrings,
 )
@@ -1904,9 +1897,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
                     f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
                 )
-            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
-                pretrained_model_name_or_path
-            ):
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
                 archive_file = pretrained_model_name_or_path
             elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
                 if not from_tf:
@@ -1923,64 +1914,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     filename = FLAX_WEIGHTS_NAME
                 else:
                     filename = WEIGHTS_NAME
 
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=filename,
-                    revision=revision,
-                    mirror=mirror,
-                    subfolder=subfolder if len(subfolder) > 0 else None,
-                )
-
-            try:
-                # Load from URL or cache if already cached
-                resolved_archive_file = cached_path(
-                    archive_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                )
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                    "login` and pass `use_auth_token=True`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                    "this model name. Check the model page at "
-                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                if filename == WEIGHTS_NAME:
-                    try:
-                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
-                        archive_file = hf_bucket_url(
-                            pretrained_model_name_or_path,
-                            filename=WEIGHTS_INDEX_NAME,
-                            revision=revision,
-                            mirror=mirror,
-                            subfolder=subfolder if len(subfolder) > 0 else None,
-                        )
-                        resolved_archive_file = cached_path(
-                            archive_file,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            resume_download=resume_download,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                            user_agent=user_agent,
-                        )
-                        is_sharded = True
-                    except EntryNotFoundError:
-                        # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
-                        # message.
-                        has_file_kwargs = {
+            try:
+                # Load from URL or cache if already cached
+                cached_file_kwargs = dict(
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+                archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                if archive_file is None and filename == WEIGHTS_NAME:
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    archive_file = cached_file(
+                        pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                    )
+                    if archive_file is not None:
+                        is_sharded = True
+                if archive_file is None:
+                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+                    # message.
+                    has_file_kwargs = {
@@ -2006,45 +1965,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                             f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
                             f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
                         )
-                else:
-                    raise EnvironmentError(
-                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
-                    )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
-                    f"{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                    f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
-                    f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
-                    " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
-                    f"{FLAX_WEIGHTS_NAME}."
-                )
+            except EnvironmentError:
+                raise
+            except Exception:
+                raise EnvironmentError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
+                    f" {FLAX_WEIGHTS_NAME}."
+                )
 
-            if resolved_archive_file == archive_file:
+            if os.path.isdir(pretrained_model_name_or_path):
                 logger.info(f"loading weights file {archive_file}")
             else:
-                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+                logger.info(f"loading weights file {filename} from cache at {archive_file}")
         else:
-            resolved_archive_file = None
+            archive_file = None
 
         # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
         if is_sharded:
-            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
-            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+            # archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            archive_file, sharded_metadata = get_checkpoint_shard_files(
                 pretrained_model_name_or_path,
-                resolved_archive_file,
+                archive_file,
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,
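Note: the sharded-checkpoint fallback above relies on `_raise_exceptions_for_missing_entries=False`, which makes `cached_file` return `None` for a missing file instead of raising an entry-not-found error. A minimal sketch of that probing pattern (using `gpt2` purely as an example of a single-file checkpoint):

```python
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file

repo_id = "gpt2"  # example repo id; any model on the Hub works here
is_sharded = False

# Probe for a single-file checkpoint first; None means "not found" rather than an error.
archive_file = cached_file(repo_id, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
if archive_file is None:
    # Fall back to the index file of a sharded checkpoint.
    archive_file = cached_file(repo_id, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
    is_sharded = archive_file is not None
```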
@@ -2061,7 +2005,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if from_pt:
                 if not is_sharded and state_dict is None:
                     # Time to load the checkpoint
-                    state_dict = load_state_dict(resolved_archive_file)
+                    state_dict = load_state_dict(archive_file)
 
             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
@@ -2077,7 +2021,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                             elif not is_sharded:
                                 torch_dtype = get_state_dict_dtype(state_dict)
                             else:
-                                one_state_dict = load_state_dict(resolved_archive_file[0])
+                                one_state_dict = load_state_dict(archive_file[0])
                                 torch_dtype = get_state_dict_dtype(one_state_dict)
                                 del one_state_dict  # free CPU memory
                         else:
@@ -2135,16 +2079,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         if from_tf:
-            if resolved_archive_file.endswith(".index"):
+            if archive_file.endswith(".index"):
                 # Load from a TensorFlow 1.X checkpoint - provided by original authors
-                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
+                model = cls.load_tf_weights(model, config, archive_file[:-6])  # Remove the '.index'
             else:
                 # Load from our TensorFlow 2.0 checkpoints
                 try:
                     from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
 
                     model, loading_info = load_tf2_checkpoint_in_pytorch_model(
-                        model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
+                        model, archive_file, allow_missing_keys=True, output_loading_info=True
                     )
                 except ImportError:
                     logger.error(
@@ -2157,7 +2101,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             try:
                 from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
 
-                model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+                model = load_flax_checkpoint_in_pytorch_model(model, archive_file)
             except ImportError:
                 logger.error(
                     "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
@@ -2175,7 +2119,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 model,
                 state_dict,
                 loaded_state_dict_keys,  # XXX: rename?
-                resolved_archive_file,
+                archive_file,
                 pretrained_model_name_or_path,
                 ignore_mismatched_sizes=ignore_mismatched_sizes,
                 sharded_metadata=sharded_metadata,
src/transformers/utils/__init__.py (1 addition, 0 deletions)

@@ -59,6 +59,7 @@
     PushToHubMixin,
     RepositoryNotFoundError,
     RevisionNotFoundError,
+    cached_file,
     cached_path,
     default_cache_path,
     define_sagemaker_information,
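Note: re-exporting `cached_file` alongside `cached_path` keeps both entry points importable while callers migrate; presumably the old helper can be dropped once nothing depends on it. For example:

```python
# Both helpers remain importable from transformers.utils during the transition.
from transformers.utils import cached_file, cached_path
```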