This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

fix cached_path for hub downloads (#5141)
* fix cached_path for hub downloads

* fix test name

* fix type hint

* Update allennlp/common/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
epwalsh and LysandreJik authored Apr 22, 2021
1 parent f877fdc commit 7fc5a91
Showing 3 changed files with 124 additions and 51 deletions.
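
For context, a minimal usage sketch of what this fix enables. The first two URLs come from the docstring and tests in the diff; the `@main` revision and the `config.json` path in the last call are illustrative assumptions, following the `@revision` parsing added in `_hf_hub_download` below.

```python
from allennlp.common.file_utils import cached_path

# Download a single file from a model repo on the HuggingFace Hub.
weights = cached_path("hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin")

# Download a whole repo as a snapshot; this returns a local directory.
repo_dir = cached_path("hf://lysandre/test-simple-tagger-tiny")

# Pin a revision by appending "@<revision>" to the repo id (illustrative).
config = cached_path("hf://epwalsh/bert-xsmall-dummy@main/config.json")
```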
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Fixed `cached_path()` for "hf://" files.


## [v2.3.1](https://github.com/allenai/allennlp/releases/tag/v2.3.1) - 2021-04-20

154 changes: 103 additions & 51 deletions allennlp/common/file_utils.py
@@ -44,19 +44,17 @@
import boto3
import botocore
import torch
from botocore.exceptions import ClientError, EndpointConnectionError
from filelock import FileLock as _FileLock
import numpy as np
from overrides import overrides
import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from requests.packages.urllib3.util.retry import Retry
import lmdb
from torch import Tensor
from huggingface_hub import hf_hub_url, cached_download, snapshot_download
from allennlp.version import VERSION
import huggingface_hub as hf_hub

from allennlp.version import VERSION
from allennlp.common.tqdm import Tqdm

logger = logging.getLogger(__name__)
@@ -208,15 +206,34 @@ def cached_path(
force_extract: bool = False,
) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Given something that might be a URL or local path, determine which.
If it's a remote resource, download the file and cache it, and
then return the path to the cached file. If it's already a local path,
make sure the file exists and return the path.
For URLs, "http://", "https://", "s3://", and "hf://" are all supported.
The latter corresponds to the HuggingFace Hub.
For example, to download the PyTorch weights for the model `epwalsh/bert-xsmall-dummy`
on HuggingFace, you could do:
```python
cached_path("hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin")
```
For paths or URLs that point to a tarfile or zipfile, you can also add a path
to a specific file to the `url_or_filename`, preceded by a "!", and the archive will
be automatically extracted (provided you set `extract_archive` to `True`),
returning the local path to the specific file. For example:
```python
cached_path("model.tar.gz!weights.th", extract_archive=True)
```
# Parameters
url_or_filename : `Union[str, Path]`
A URL or local file to parse and possibly download.
A URL or path to parse and possibly download.
cache_dir : `Union[str, Path]`, optional (default = `None`)
The directory to cache downloads.
@@ -235,47 +252,11 @@ def cached_path(
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)

extraction_path: Optional[str] = None

if not isinstance(url_or_filename, str):
url_or_filename = str(url_or_filename)

if url_or_filename.startswith("hf://"):
# Remove the hf:// prefix
identifier = url_or_filename[5:]

filename: Optional[str]
if len(identifier.split("/")) > 2:
filename = "/".join(identifier.split("/")[2:])
model_identifier = "/".join(identifier.split("/")[:2])
else:
filename = None
model_identifier = identifier

revision: Optional[str]
if "@" in model_identifier:
repo_id = model_identifier.split("@")[0]
revision = model_identifier.split("@")[1]
else:
repo_id = model_identifier
revision = None

if filename is not None:
url = hf_hub_url(repo_id=repo_id, filename=filename, revision=revision)
url_or_filename = str(
cached_download(
url=url,
library_name="allennlp",
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)
else:
extraction_path = snapshot_download(
repo_id, revision=revision, cache_dir=CACHE_DIRECTORY
)

file_path: str
extraction_path: Optional[str] = None

# If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
exclamation_index = url_or_filename.find("!")
@@ -300,7 +281,7 @@

parsed = urlparse(url_or_filename)

if parsed.scheme in ("http", "https", "s3") and extraction_path is None:
if parsed.scheme in ("http", "https", "s3", "hf"):
# URL, so get it from the cache (downloading if necessary)
file_path = get_from_cache(url_or_filename, cache_dir)

@@ -309,7 +290,7 @@
# For example ~/.allennlp/cache/234234.21341 -> ~/.allennlp/cache/234234.21341-extracted
extraction_path = file_path + "-extracted"

elif extraction_path is None:
else:
url_or_filename = os.path.expanduser(url_or_filename)

if os.path.exists(url_or_filename):
@@ -418,7 +399,7 @@ def _s3_request(func: Callable):
def wrapper(url: str, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
except botocore.exceptions.ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url))
else:
@@ -476,7 +457,7 @@ def _http_etag(url: str) -> Optional[str]:
with _session_with_backoff() as session:
response = session.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError(
raise OSError(
"HEAD request failed for url {} with status code {}".format(url, response.status_code)
)
return response.headers.get("ETag")
@@ -858,6 +839,53 @@ def from_path(cls, path: Union[str, Path]) -> "_Meta":
return cls(**data)


def _hf_hub_download(
url, model_identifier: str, filename: Optional[str], cache_dir: Union[str, Path]
) -> str:
revision: Optional[str]
if "@" in model_identifier:
repo_id = model_identifier.split("@")[0]
revision = model_identifier.split("@")[1]
else:
repo_id = model_identifier
revision = None

if filename is not None:
hub_url = hf_hub.hf_hub_url(repo_id=repo_id, filename=filename, revision=revision)
cache_path = str(
hf_hub.cached_download(
url=hub_url,
library_name="allennlp",
library_version=VERSION,
cache_dir=cache_dir,
)
)
# HF writes its own meta '.json' file which uses the same format we used to use and still
# support, but is missing some fields that we like to have.
# So we overwrite it when we can.
with FileLock(cache_path + ".lock", read_only_ok=True):
meta = _Meta.from_path(cache_path + ".json")
# The file HF writes will have 'resource' set to the 'http' URL corresponding to the 'hf://' URL,
# but we want 'resource' to be the original 'hf://' URL.
if meta.resource != url:
meta.resource = url
meta.to_file()
else:
cache_path = str(hf_hub.snapshot_download(repo_id, revision=revision, cache_dir=cache_dir))
# Need to write the meta file for snapshot downloads if it doesn't exist.
with FileLock(cache_path + ".lock", read_only_ok=True):
if not os.path.exists(cache_path + ".json"):
meta = _Meta(
resource=url,
cached_path=cache_path,
creation_time=time.time(),
extraction_dir=True,
size=_get_resource_size(cache_path),
)
meta.to_file()
return cache_path


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
"""
@@ -867,13 +895,37 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
if cache_dir is None:
cache_dir = CACHE_DIRECTORY

if url.startswith("hf://"):
# Remove the 'hf://' prefix
identifier = url[5:]

if identifier.count("/") > 1:
filename = "/".join(identifier.split("/")[2:])
model_identifier = "/".join(identifier.split("/")[:2])
return _hf_hub_download(url, model_identifier, filename, cache_dir)
elif identifier.count("/") == 1:
# 'hf://' URLs like 'hf://xxxx/yyyy' are potentially ambiguous,
# because this could refer to either:
# 1. the file 'yyyy' in the 'xxxx' repository, or
# 2. the repo 'yyyy' under the user/org name 'xxxx'.
# We default to (1), but if we get a 404 error then we try (2).
try:
model_identifier, filename = identifier.split("/")
return _hf_hub_download(url, model_identifier, filename, cache_dir)
except requests.exceptions.HTTPError as exc:
if exc.response.status_code == 404:
return _hf_hub_download(url, identifier, None, cache_dir)
raise
else:
return _hf_hub_download(url, identifier, None, cache_dir)

# Get eTag to add to filename, if it exists.
try:
if url.startswith("s3://"):
etag = _s3_etag(url)
else:
etag = _http_etag(url)
except (ConnectionError, EndpointConnectionError):
except (requests.exceptions.ConnectionError, botocore.exceptions.EndpointConnectionError):
# We might be offline, in which case we don't want to throw an error
# just yet. Instead, we'll try to use the latest cached version of the
# target resource, if it exists. We'll only throw an exception if we
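
Before the test changes, here is a small standalone sketch (a hypothetical helper, not part of this diff) of how `get_from_cache()` above interprets the portion of the URL after `hf://`:

```python
from typing import Optional, Tuple


def interpret_hf_identifier(identifier: str) -> Tuple[str, str, Optional[str]]:
    """Illustrative only: mirrors the branching in get_from_cache() for 'hf://' URLs."""
    parts = identifier.split("/")
    if len(parts) > 2:
        # e.g. "epwalsh/bert-xsmall-dummy/pytorch_model.bin":
        # repo "epwalsh/bert-xsmall-dummy", file "pytorch_model.bin".
        return "file", "/".join(parts[:2]), "/".join(parts[2:])
    if len(parts) == 2:
        # e.g. "t5-small/config.json" vs "lysandre/test-simple-tagger-tiny":
        # ambiguous. get_from_cache() first treats it as (repo=parts[0],
        # filename=parts[1]) and falls back to a whole-repo snapshot on a 404.
        return "file-or-snapshot", parts[0], parts[1]
    # e.g. "t5-small": a whole-repo snapshot download.
    return "snapshot", identifier, None


print(interpret_hf_identifier("epwalsh/bert-xsmall-dummy/pytorch_model.bin"))
print(interpret_hf_identifier("t5-small/config.json"))
print(interpret_hf_identifier("t5-small"))
```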
17 changes: 17 additions & 0 deletions tests/common/file_utils_test.py
@@ -588,3 +588,20 @@ def test_cached_download(self):
def test_snapshot_download(self):
predictor = Predictor.from_path("hf://lysandre/test-simple-tagger-tiny")
assert predictor._dataset_reader._token_indexers["tokens"].namespace == "test_tokens"

def test_cached_download_no_user_or_org(self):
path = cached_path("hf://t5-small/config.json", cache_dir=self.TEST_DIR)
assert os.path.isfile(path)
assert pathlib.Path(os.path.dirname(path)) == self.TEST_DIR
assert os.path.isfile(path + ".json")
meta = _Meta.from_path(path + ".json")
assert meta.etag is not None
assert meta.resource == "hf://t5-small/config.json"

def test_snapshot_download_no_user_or_org(self):
path = cached_path("hf://t5-small", cache_dir=self.TEST_DIR)
assert os.path.isdir(path)
assert pathlib.Path(os.path.dirname(path)) == self.TEST_DIR
assert os.path.isfile(path + ".json")
meta = _Meta.from_path(path + ".json")
assert meta.resource == "hf://t5-small"
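
As these tests show, `cached_path()` writes a `.json` meta sidecar next to each cached resource, and for Hub downloads its `resource` field keeps the original `hf://` URL rather than the resolved `http` URL. A quick sketch of inspecting it, mirroring the tests above; the `cache_dir` value is illustrative:

```python
import os

from allennlp.common.file_utils import _Meta, cached_path

# Cache a single file from the Hub and read the meta sidecar written next to it.
path = cached_path("hf://t5-small/config.json", cache_dir="/tmp/allennlp-cache")
meta = _Meta.from_path(path + ".json")

print(os.path.isfile(path))  # True: a single cached file
print(meta.resource)         # "hf://t5-small/config.json" (the original hf:// URL)
```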
