From be902d841e44ea748766595e4bc2584dbe0b16ad Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Thu, 18 Mar 2021 17:55:44 +0100 Subject: [PATCH] Add HF_HUB_OFFLINE env var (#22) * offline mode simulator * add HF_HUB_OFFLINE env var * add tests * doc tweak * Re-align from transformers and @aaugustin cc @lhoestq Co-authored-by: Julien Chaumond --- src/huggingface_hub/constants.py | 11 ++++ src/huggingface_hub/file_download.py | 93 ++++++++++++++++++++++++---- tests/test_file_download.py | 26 ++++++-- tests/test_offline_utils.py | 36 +++++++++++ tests/testing_utils.py | 70 +++++++++++++++++++++ 5 files changed, 221 insertions(+), 15 deletions(-) create mode 100644 tests/test_offline_utils.py diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index fb64c4988a..b1af0766df 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -1,6 +1,11 @@ import os +# Possible values for env variables + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + # Constants for file downloads PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" @@ -30,3 +35,9 @@ default_cache_path = os.path.join(hf_cache_home, "hub") HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) + +HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "AUTO").upper() +if HF_HUB_OFFLINE in ENV_VARS_TRUE_VALUES: + HF_HUB_OFFLINE = True +else: + HF_HUB_OFFLINE = False diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 31d1602e84..026ef3bb5d 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -6,6 +6,7 @@ import os import sys import tempfile +import time from contextlib import contextmanager from functools import partial from hashlib import sha256 @@ -16,6 +17,7 @@ import requests from filelock import FileLock +from huggingface_hub import constants from . import __version__ from .constants import ( @@ -171,12 +173,73 @@ def http_user_agent( return ua +class OfflineModeIsEnabled(ConnectionError): + pass + + +def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None): + """Raise a OfflineModeIsEnabled error (subclass of ConnectionError) if HF_HUB_OFFLINE is True.""" + if constants.HF_HUB_OFFLINE: + raise OfflineModeIsEnabled( + "Offline mode is enabled." + if msg is None + else "Offline mode is enabled. " + str(msg) + ) + + +def _request_with_retry( + method: str, + url: str, + max_retries: int = 0, + base_wait_time: float = 0.5, + max_wait_time: float = 2, + timeout: float = 10.0, + **params, +) -> requests.Response: + """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff. + + Note that if the environment variable HF_HUB_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised. + + Args: + method (str): HTTP method, such as 'GET' or 'HEAD' + url (str): The URL of the ressource to fetch + max_retries (int): Maximum number of retries, defaults to 0 (no retries) + base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between + retries then grows exponentially, capped by max_wait_time. + max_wait_time (float): Maximum amount of time between two retries, in seconds + **params: Params to pass to `requests.request` + """ + _raise_if_offline_mode_is_enabled(f"Tried to reach {url}") + tries, success = 0, False + while not success: + tries += 1 + try: + response = requests.request( + method=method.upper(), url=url, timeout=timeout, **params + ) + success = True + except requests.exceptions.ConnectTimeout as err: + if tries > max_retries: + raise err + else: + logger.info( + f"{method} request to {url} timed out, retrying... [{tries/max_retries}]" + ) + sleep_time = max( + max_wait_time, base_wait_time * 2 ** (tries - 1) + ) # Exponential backoff + time.sleep(sleep_time) + return response + + def http_get( url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None, + timeout=10.0, + max_retries=0, ): """ Donwload remote file. Do not gobble up errors. @@ -184,7 +247,15 @@ def http_get( headers = copy.deepcopy(headers) if resume_size > 0: headers["Range"] = "bytes=%d-" % (resume_size,) - r = requests.get(url, stream=True, proxies=proxies, headers=headers) + r = _request_with_retry( + method="GET", + url=url, + stream=True, + proxies=proxies, + headers=headers, + timeout=timeout, + max_retries=max_retries, + ) r.raise_for_status() content_length = r.headers.get("Content-Length") total = resume_size + int(content_length) if content_length is not None else None @@ -254,8 +325,9 @@ def cached_download( etag = None if not local_files_only: try: - r = requests.head( - url, + r = _request_with_retry( + method="HEAD", + url=url, headers=headers, allow_redirects=False, proxies=proxies, @@ -276,15 +348,14 @@ def cached_download( # between the HEAD and the GET (unlikely, but hey). if 300 <= r.status_code <= 399: url_to_download = r.headers["Location"] + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, - ) as exc: - # Actually raise for those subclasses of ConnectionError: - if isinstance(exc, requests.exceptions.SSLError) or isinstance( - exc, requests.exceptions.ProxyError - ): - raise exc + OfflineModeIsEnabled, + ): # Otherwise, our Internet connection is down. # etag is None pass @@ -297,7 +368,7 @@ def cached_download( # etag is None == we don't have a connection or we passed local_files_only. # try to get the last downloaded one if etag is None: - if os.path.exists(cache_path): + if os.path.exists(cache_path) and not force_download: return cache_path else: matching_files = [ @@ -307,7 +378,7 @@ def cached_download( ) if not file.endswith(".json") and not file.endswith(".lock") ] - if len(matching_files) > 0: + if len(matching_files) > 0 and not force_download: return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 9ad0347979..f53b41b289 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -22,7 +22,12 @@ ) from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url -from .testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SAMPLE_DATASET_IDENTIFIER +from .testing_utils import ( + DUMMY_UNKWOWN_IDENTIFIER, + SAMPLE_DATASET_IDENTIFIER, + OfflineSimulationMode, + offline, +) MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER @@ -51,13 +56,26 @@ class CachedDownloadTests(unittest.TestCase): def test_bogus_url(self): - # This lets us simulate no connection - # as the error raised is the same - # `ConnectionError` url = "https://bogus" with self.assertRaisesRegex(ValueError, "Connection error"): _ = cached_download(url) + def test_no_connection(self): + invalid_url = hf_hub_url( + MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID + ) + valid_url = hf_hub_url( + MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT + ) + self.assertIsNotNone(cached_download(valid_url, force_download=True)) + for offline_mode in OfflineSimulationMode: + with offline(mode=offline_mode): + with self.assertRaisesRegex(ValueError, "Connection error"): + _ = cached_download(invalid_url) + with self.assertRaisesRegex(ValueError, "Connection error"): + _ = cached_download(valid_url, force_download=True) + self.assertIsNotNone(cached_download(valid_url)) + def test_file_not_found(self): # Valid revision (None) but missing file. url = hf_hub_url(MODEL_ID, filename="missing.bin") diff --git a/tests/test_offline_utils.py b/tests/test_offline_utils.py new file mode 100644 index 0000000000..076bb29a7d --- /dev/null +++ b/tests/test_offline_utils.py @@ -0,0 +1,36 @@ +from io import BytesIO + +import pytest + +import requests +from huggingface_hub.file_download import http_get + +from .testing_utils import ( + OfflineSimulationMode, + RequestWouldHangIndefinitelyError, + offline, +) + + +def test_offline_with_timeout(): + with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): + with pytest.raises(RequestWouldHangIndefinitelyError): + requests.request("GET", "https://huggingface.co") + with pytest.raises(requests.exceptions.ConnectTimeout): + requests.request("GET", "https://huggingface.co", timeout=1.0) + with pytest.raises(requests.exceptions.ConnectTimeout): + http_get("https://huggingface.co", BytesIO()) + + +def test_offline_with_connection_error(): + with offline(OfflineSimulationMode.CONNECTION_FAILS): + with pytest.raises(requests.exceptions.ConnectionError): + requests.request("GET", "https://huggingface.co") + with pytest.raises(requests.exceptions.ConnectionError): + http_get("https://huggingface.co", BytesIO()) + + +def test_offline_with_datasets_offline_mode_enabled(): + with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1): + with pytest.raises(ConnectionError): + http_get("https://huggingface.co", BytesIO()) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 1c98239a0c..21668ef2e8 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,6 +1,9 @@ import os import unittest +from contextlib import contextmanager from distutils.util import strtobool +from enum import Enum +from unittest.mock import patch SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" @@ -55,3 +58,70 @@ def require_git_lfs(test_case): return unittest.skip("test of git lfs workflow")(test_case) else: return test_case + + +class RequestWouldHangIndefinitelyError(Exception): + pass + + +class OfflineSimulationMode(Enum): + CONNECTION_FAILS = 0 + CONNECTION_TIMES_OUT = 1 + HF_HUB_OFFLINE_SET_TO_1 = 2 + + +@contextmanager +def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16): + """ + Simulate offline mode. + + There are three offline simulatiom modes: + + CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call. + Connection errors are created by mocking socket.socket + CONNECTION_TIMES_OUT: the connection hangs until it times out. + The default timeout value is low (1e-16) to speed up the tests. + Timeout errors are created by mocking requests.request + HF_HUB_OFFLINE_SET_TO_1: the HF_HUB_OFFLINE_SET_TO_1 environment variable is set to 1. + This makes the http/ftp calls of the library instantly fail and raise an OfflineModeEmabled error. + """ + import socket + + from requests import request as online_request + + def timeout_request(method, url, **kwargs): + # Change the url to an invalid url so that the connection hangs + invalid_url = "https://10.255.255.1" + if kwargs.get("timeout") is None: + raise RequestWouldHangIndefinitelyError( + f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout." + ) + kwargs["timeout"] = timeout + try: + return online_request(method, invalid_url, **kwargs) + except Exception as e: + # The following changes in the error are just here to make the offline timeout error prettier + e.request.url = url + max_retry_error = e.args[0] + max_retry_error.args = ( + max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"), + ) + e.args = (max_retry_error,) + raise + + def offline_socket(*args, **kwargs): + raise socket.error("Offline mode is enabled.") + + if mode is OfflineSimulationMode.CONNECTION_FAILS: + # inspired from https://stackoverflow.com/a/18601897 + with patch("socket.socket", offline_socket): + yield + elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: + # inspired from https://stackoverflow.com/a/904609 + with patch("requests.request", timeout_request): + yield + elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1: + with patch("huggingface_hub.constants.HF_HUB_OFFLINE", True): + yield + else: + raise ValueError("Please use a value from the OfflineSimulationMode enum.")