Skip to content

Commit

Permalink
Add HF_HUB_OFFLINE env var (#22)
Browse files Browse the repository at this point in the history
* offline mode simulator

* add HF_HUB_OFFLINE env var

* add tests

* doc tweak

* Re-align from transformers and @aaugustin

cc @lhoestq

Co-authored-by: Julien Chaumond <julien@huggingface.co>
  • Loading branch information
lhoestq and julien-c authored Mar 18, 2021
1 parent 5b1c1e6 commit be902d8
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 15 deletions.
11 changes: 11 additions & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os


# Possible values for env variables

# Upper-case spellings that are treated as "enabled" for boolean env variables.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# The same spellings, plus the "AUTO" sentinel.
ENV_VARS_TRUE_AND_AUTO_VALUES = {*ENV_VARS_TRUE_VALUES, "AUTO"}

# Constants for file downloads

PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
Expand Down Expand Up @@ -30,3 +35,9 @@
# Downloads are cached under "<hf_cache_home>/hub" unless the
# HUGGINGFACE_HUB_CACHE environment variable points somewhere else.
default_cache_path = os.path.join(hf_cache_home, "hub")
HUGGINGFACE_HUB_CACHE = os.environ.get("HUGGINGFACE_HUB_CACHE", default_cache_path)

# Offline mode is on only when the HF_HUB_OFFLINE environment variable holds one
# of the accepted "true" spellings (compared case-insensitively); when the
# variable is unset or holds anything else, offline mode stays off.
HF_HUB_OFFLINE = (
    os.environ.get("HF_HUB_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES
)
93 changes: 82 additions & 11 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import sys
import tempfile
import time
from contextlib import contextmanager
from functools import partial
from hashlib import sha256
Expand All @@ -16,6 +17,7 @@

import requests
from filelock import FileLock
from huggingface_hub import constants

from . import __version__
from .constants import (
Expand Down Expand Up @@ -171,20 +173,89 @@ def http_user_agent(
return ua


class OfflineModeIsEnabled(ConnectionError):
    """Error signalling that a network call was attempted while offline mode is on."""


def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
    """Fail fast when ``constants.HF_HUB_OFFLINE`` is truthy.

    Args:
        msg: Optional detail appended (after a space) to the base error message.

    Raises:
        OfflineModeIsEnabled: a ``ConnectionError`` subclass, raised only when
            offline mode is enabled.
    """
    if not constants.HF_HUB_OFFLINE:
        return

    message = "Offline mode is enabled."
    if msg is not None:
        message = "Offline mode is enabled. " + str(msg)
    raise OfflineModeIsEnabled(message)


def _request_with_retry(
    method: str,
    url: str,
    max_retries: int = 0,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
    timeout: float = 10.0,
    **params,
) -> requests.Response:
    """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.

    Note that if the environment variable HF_HUB_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised.

    Args:
        method (str): HTTP method, such as 'GET' or 'HEAD'
        url (str): The URL of the resource to fetch
        max_retries (int): Maximum number of retries, defaults to 0 (no retries)
        base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
            retries then grows exponentially, capped by max_wait_time.
        max_wait_time (float): Maximum amount of time between two retries, in seconds
        timeout (float): Timeout (in seconds) forwarded to `requests.request`
        **params: Params to pass to `requests.request`

    Returns:
        The `requests.Response` of the first attempt that did not time out while connecting.

    Raises:
        OfflineModeIsEnabled: if offline mode is enabled (checked before any request is sent).
        requests.exceptions.ConnectTimeout: if the attempt following the last allowed retry also times out.
    """
    _raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
    tries = 0
    while True:
        tries += 1
        try:
            return requests.request(
                method=method.upper(), url=url, timeout=timeout, **params
            )
        except requests.exceptions.ConnectTimeout:
            if tries > max_retries:
                # Retry budget exhausted: propagate with the original traceback.
                raise
            logger.info(
                f"{method} request to {url} timed out, retrying... [{tries}/{max_retries}]"
            )
            # Exponential backoff: base, 2*base, 4*base, ... never longer than
            # max_wait_time (hence `min`, not `max`).
            sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1))
            time.sleep(sleep_time)


def http_get(
url: str,
temp_file: BinaryIO,
proxies=None,
resume_size=0,
headers: Optional[Dict[str, str]] = None,
timeout=10.0,
max_retries=0,
):
"""
Donwload remote file. Do not gobble up errors.
"""
headers = copy.deepcopy(headers)
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
r = _request_with_retry(
method="GET",
url=url,
stream=True,
proxies=proxies,
headers=headers,
timeout=timeout,
max_retries=max_retries,
)
r.raise_for_status()
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
Expand Down Expand Up @@ -254,8 +325,9 @@ def cached_download(
etag = None
if not local_files_only:
try:
r = requests.head(
url,
r = _request_with_retry(
method="HEAD",
url=url,
headers=headers,
allow_redirects=False,
proxies=proxies,
Expand All @@ -276,15 +348,14 @@ def cached_download(
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
# Actually raise for those subclasses of ConnectionError
raise
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
) as exc:
# Actually raise for those subclasses of ConnectionError:
if isinstance(exc, requests.exceptions.SSLError) or isinstance(
exc, requests.exceptions.ProxyError
):
raise exc
OfflineModeIsEnabled,
):
# Otherwise, our Internet connection is down.
# etag is None
pass
Expand All @@ -297,7 +368,7 @@ def cached_download(
# etag is None == we don't have a connection or we passed local_files_only.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
if os.path.exists(cache_path) and not force_download:
return cache_path
else:
matching_files = [
Expand All @@ -307,7 +378,7 @@ def cached_download(
)
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
if len(matching_files) > 0 and not force_download:
return os.path.join(cache_dir, matching_files[-1])
else:
# If files cannot be found and local_files_only=True,
Expand Down
26 changes: 22 additions & 4 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
)
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url

from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SAMPLE_DATASET_IDENTIFIER
from .testing_utils import (
DUMMY_UNKWOWN_IDENTIFIER,
SAMPLE_DATASET_IDENTIFIER,
OfflineSimulationMode,
offline,
)


MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
Expand Down Expand Up @@ -51,13 +56,26 @@

class CachedDownloadTests(unittest.TestCase):
def test_bogus_url(self):
    """A URL that cannot be reached must surface as a ValueError mentioning "Connection error"."""
    bogus_url = "https://bogus"
    with self.assertRaisesRegex(ValueError, "Connection error"):
        cached_download(bogus_url)

def test_no_connection(self):
    """Exercise cached_download under every simulated offline mode."""
    url_invalid_revision = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID
    )
    url_valid_revision = hf_hub_url(
        MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
    )

    # Fetch the valid file once, before any offline simulation starts.
    self.assertIsNotNone(cached_download(url_valid_revision, force_download=True))

    for simulation_mode in OfflineSimulationMode:
        with offline(mode=simulation_mode):
            # The invalid-revision URL is expected to fail.
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(url_invalid_revision)
            # force_download is expected to fail even for the valid URL.
            with self.assertRaisesRegex(ValueError, "Connection error"):
                cached_download(url_valid_revision, force_download=True)
            # A plain call on the valid URL must still return a path.
            self.assertIsNotNone(cached_download(url_valid_revision))

def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_hub_url(MODEL_ID, filename="missing.bin")
Expand Down
36 changes: 36 additions & 0 deletions tests/test_offline_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from io import BytesIO

import pytest

import requests
from huggingface_hub.file_download import http_get

from .testing_utils import (
OfflineSimulationMode,
RequestWouldHangIndefinitelyError,
offline,
)


def test_offline_with_timeout():
    """In CONNECTION_TIMES_OUT mode, requests either refuse to hang or time out."""
    hub_url = "https://huggingface.co"
    with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT):
        # Without a timeout the mock raises instead of hanging forever.
        with pytest.raises(RequestWouldHangIndefinitelyError):
            requests.request("GET", hub_url)
        # With an explicit timeout the connection attempt times out.
        with pytest.raises(requests.exceptions.ConnectTimeout):
            requests.request("GET", hub_url, timeout=1.0)
        # http_get is expected to surface the same timeout error.
        with pytest.raises(requests.exceptions.ConnectTimeout):
            http_get(hub_url, BytesIO())


def test_offline_with_connection_error():
    """In CONNECTION_FAILS mode, both raw requests and http_get raise ConnectionError."""
    hub_url = "https://huggingface.co"
    with offline(OfflineSimulationMode.CONNECTION_FAILS):
        with pytest.raises(requests.exceptions.ConnectionError):
            requests.request("GET", hub_url)
        with pytest.raises(requests.exceptions.ConnectionError):
            http_get(hub_url, BytesIO())


def test_offline_with_datasets_offline_mode_enabled():
    """With HF_HUB_OFFLINE patched to True, http_get raises a ConnectionError."""
    hub_url = "https://huggingface.co"
    with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1):
        with pytest.raises(ConnectionError):
            http_get(hub_url, BytesIO())
70 changes: 70 additions & 0 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import unittest
from contextlib import contextmanager
from distutils.util import strtobool
from enum import Enum
from unittest.mock import patch


SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
Expand Down Expand Up @@ -55,3 +58,70 @@ def require_git_lfs(test_case):
return unittest.skip("test of git lfs workflow")(test_case)
else:
return test_case


class RequestWouldHangIndefinitelyError(Exception):
    """Raised by the offline timeout mock when a request is made without a timeout."""


class OfflineSimulationMode(Enum):
    """Ways in which the `offline` test helper can simulate a missing connection."""

    # Every network call fails with a connection error.
    CONNECTION_FAILS = 0
    # Connections hang until they time out.
    CONNECTION_TIMES_OUT = 1
    # The library-level HF_HUB_OFFLINE flag is switched on.
    HF_HUB_OFFLINE_SET_TO_1 = 2


@contextmanager
def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16):
    """
    Simulate offline mode.

    There are three offline simulation modes:

    CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call.
        Connection errors are created by mocking socket.socket
    CONNECTION_TIMES_OUT: the connection hangs until it times out.
        The default timeout value is low (1e-16) to speed up the tests.
        Timeout errors are created by mocking requests.request
    HF_HUB_OFFLINE_SET_TO_1: huggingface_hub.constants.HF_HUB_OFFLINE is patched to True,
        as if the HF_HUB_OFFLINE environment variable were set to 1.
        This makes the http/ftp calls of the library instantly fail and raise an OfflineModeIsEnabled error.

    Args:
        mode: Member of OfflineSimulationMode selecting the simulation.
        timeout: Timeout (in seconds) forced onto every request in CONNECTION_TIMES_OUT mode.

    Raises:
        ValueError: if `mode` is not a member of OfflineSimulationMode.
    """
    import socket

    from requests import request as online_request
    from requests.exceptions import ConnectionError as RequestsConnectionError
    from requests.exceptions import Timeout as RequestsTimeout

    def timeout_request(method, url, **kwargs):
        # Change the url to an invalid url so that the connection hangs
        invalid_url = "https://10.255.255.1"
        if kwargs.get("timeout") is None:
            raise RequestWouldHangIndefinitelyError(
                f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout."
            )
        kwargs["timeout"] = timeout
        try:
            return online_request(method, invalid_url, **kwargs)
        except (RequestsConnectionError, RequestsTimeout) as e:
            # Only these requests errors are rewritten; anything else propagates
            # untouched instead of being masked by an AttributeError below.
            # The following changes in the error are just here to make the offline timeout error prettier
            e.request.url = url
            max_retry_error = e.args[0]
            max_retry_error.args = (
                max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),
            )
            e.args = (max_retry_error,)
            raise

    def offline_socket(*args, **kwargs):
        raise socket.error("Offline mode is enabled.")

    if mode is OfflineSimulationMode.CONNECTION_FAILS:
        # inspired from https://stackoverflow.com/a/18601897
        with patch("socket.socket", offline_socket):
            yield
    elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT:
        # inspired from https://stackoverflow.com/a/904609
        with patch("requests.request", timeout_request):
            yield
    elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1:
        with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True):
            yield
    else:
        raise ValueError("Please use a value from the OfflineSimulationMode enum.")

0 comments on commit be902d8

Please sign in to comment.