Skip to content

Commit

Permalink
Use new huggingface_hub tools for download models (huggingface#18438)
Browse files Browse the repository at this point in the history
* Draft new cached_file

* Initial draft for config and model

* Small fixes

* Fix first batch of tests

* Look in cache when internet is down

* Fix last tests

* Bad black, not fixing all quality errors

* Make diff less

* Implement change for TF and Flax models

* Add tokenizer and feature extractor

* For compatibility with main

* Add utils to move the cache and auto-do it at first use.

* Quality

* Deal with empty commit shas

* Deal with empty etag

* Address review comments
  • Loading branch information
sgugger authored and oneraghavan committed Sep 26, 2022
1 parent 042c795 commit 27909c8
Show file tree
Hide file tree
Showing 13 changed files with 662 additions and 545 deletions.
118 changes: 34 additions & 84 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,9 @@

from packaging import version

from requests import HTTPError

from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
is_torch_available,
logging,
)
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -591,77 +575,43 @@ def _get_config_dict(
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline

if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
config_file = pretrained_model_name_or_path

is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Soecial case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)

if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
else:
config_file = hf_bucket_url(
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_file(
pretrained_model_name_or_path,
filename=configuration_file,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
subfolder=subfolder,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
f" containing a {configuration_file} file"
)

try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {configuration_file} file"
)

try:
# Load config dict
Expand All @@ -671,10 +621,10 @@ def _get_config_dict(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)

if resolved_config_file == config_file:
logger.info(f"loading configuration file {config_file}")
if is_local:
logger.info(f"loading configuration file {resolved_config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

return config_dict, kwargs

Expand Down
103 changes: 35 additions & 68 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,15 @@

import numpy as np

from requests import HTTPError

from .dynamic_module_utils import custom_object_save
from .utils import (
FEATURE_EXTRACTOR_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
cached_path,
cached_file,
copy_func,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
logging,
Expand Down Expand Up @@ -388,64 +380,40 @@ def get_feature_extractor_dict(
local_files_only = True

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
else:
feature_extractor_file = hf_bucket_url(
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
)

try:
# Load from URL or cache if already cached
resolved_feature_extractor_file = cached_path(
feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
" the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {FEATURE_EXTRACTOR_NAME} file"
)
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_feature_extractor_file = cached_file(
pretrained_model_name_or_path,
feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
)

try:
# Load feature_extractor dict
Expand All @@ -458,12 +426,11 @@ def get_feature_extractor_dict(
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)

if resolved_feature_extractor_file == feature_extractor_file:
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
if is_local:
logger.info(f"loading configuration file {resolved_feature_extractor_file}")
else:
logger.info(
f"loading feature extractor configuration file {feature_extractor_file} from cache at"
f" {resolved_feature_extractor_file}"
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)

return feature_extractor_dict, kwargs
Expand Down
Loading

0 comments on commit 27909c8

Please sign in to comment.