Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Degraded but fully working cache-system when symlinks are not supported #1067

Merged
merged 16 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os
import re
import shutil
import sys
import tempfile
import warnings
Expand Down Expand Up @@ -171,6 +172,41 @@ def get_jinja_version():
return _jinja_version


# Check, once at import time, whether the platform supports symlinks by trying to
# create one inside a throwaway directory.
_are_symlinks_supported = True
with tempfile.TemporaryDirectory() as tmpdir:
    src_path = Path(tmpdir) / "dummy_file_src"
    src_path.touch()
    dst_path = Path(tmpdir) / "dummy_file_dst"
    try:
        os.symlink(src_path, dst_path)
    except OSError:
        # Likely running on Windows
        _are_symlinks_supported = False

        # Tell the user about the degraded cache, unless they opted out by setting
        # the `DISABLE_SYMLINKS_WARNING` environment variable to a non-empty value.
        if not os.environ.get("DISABLE_SYMLINKS_WARNING"):
            message = (
                "`huggingface_hub` cache-system uses symlinks by default to efficiently"
                " store duplicated files but your machine doesn't support them. Caching"
                " files will still work but in a degraded version that might require"
                " more space on your disk. This warning can be disabled by setting the"
                " `DISABLE_SYMLINKS_WARNING` environment variable. For more details,"
                " see https://huggingface.co/docs/huggingface_hub/package_reference/utilities."
            )
            if os.name == "nt":
                # Windows-specific hint: symlinks require extra privileges there.
                message += (
                    "\nTo support symlinks on Windows, you either need to activate"
                    " Developer Mode or to run Python as an administrator. In order to"
                    " activate Developer Mode, see this article:"
                    " https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
                )
            warnings.warn(message)


def are_symlinks_supported() -> bool:
    """Return the module-level `_are_symlinks_supported` flag.

    Exposed as a function so that callers (and tests, through patching) go
    through a single access point instead of reading the flag directly.
    """
    return _are_symlinks_supported


# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
Expand Down Expand Up @@ -859,7 +895,7 @@ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
return etag.strip('"')


def _create_relative_symlink(src: str, dst: str) -> None:
def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
"""Create a symbolic link named dst pointing to src as a relative path to dst.

The relative part is mostly because it seems more elegant to the author.
Expand All @@ -869,25 +905,29 @@ def _create_relative_symlink(src: str, dst: str) -> None:
├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
│ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
│ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd

If symlinks cannot be created on this platform (most likely to be Windows), the
workaround is to avoid symlinks by having the actual file in `dst`. If it is a new
file (`new_blob=True`), we move it to `dst`. If it is not a new file
(`new_blob=False`), we don't know if the blob file is already referenced elsewhere.
To avoid breaking existing cache, the file is duplicated on the disk.

In case symlinks are not supported, a warning message is displayed to the user once
when loading `huggingface_hub`. The warning message can be disabled with the
`DISABLE_SYMLINKS_WARNING` environment variable.
"""
relative_src = os.path.relpath(src, start=os.path.dirname(dst))
try:
os.remove(dst)
except OSError:
pass
try:

if are_symlinks_supported():
os.symlink(relative_src, dst)
except OSError:
# Likely running on Windows
if os.name == "nt":
raise OSError(
"Windows requires Developer Mode to be activated, or to run Python as "
"an administrator, in order to create symlinks.\nIn order to "
"activate Developer Mode, see this article: "
"https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
)
else:
raise
elif new_blob:
os.replace(src, dst)
else:
shutil.copyfile(src, dst)


def _cache_commit_hash_for_specific_revision(
Expand Down Expand Up @@ -1266,7 +1306,7 @@ def hf_hub_download(
if os.path.exists(blob_path) and not force_download:
# we have the blob already, but not the pointer
logger.info("creating pointer to %s from %s", blob_path, pointer_path)
_create_relative_symlink(blob_path, pointer_path)
_create_relative_symlink(blob_path, pointer_path, new_blob=False)
return pointer_path

# Prevent parallel downloads of the same file with a lock.
Expand Down Expand Up @@ -1322,7 +1362,7 @@ def _resumable_file_manager() -> "io.BufferedWriter":
os.replace(temp_file.name, blob_path)

logger.info("creating pointer to %s from %s", blob_path, pointer_path)
_create_relative_symlink(blob_path, pointer_path)
_create_relative_symlink(blob_path, pointer_path, new_blob=True)

try:
os.remove(lock_path)
Expand Down
11 changes: 0 additions & 11 deletions src/huggingface_hub/utils/_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,6 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:

snapshots_path = repo_path / "snapshots"
refs_path = repo_path / "refs"
blobs_path = repo_path / "blobs"

if not snapshots_path.exists() or not snapshots_path.is_dir():
raise CorruptedCacheException(
Expand Down Expand Up @@ -679,22 +678,12 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo:
if file_path.is_dir():
continue

if not file_path.is_symlink():
raise CorruptedCacheException(
f"Revision folder corrupted. Found a non-symlink file: {file_path}"
)

blob_path = Path(file_path).resolve()
if not blob_path.exists():
raise CorruptedCacheException(
f"Blob missing (broken symlink): {blob_path}"
)

if blobs_path not in blob_path.parents:
raise CorruptedCacheException(
f"Blob symlink points outside of blob directory: {blob_path}"
)

if blob_path not in blob_stats:
blob_stats[blob_path] = blob_path.stat()

Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator

import pytest

from _pytest.fixtures import SubRequest


@pytest.fixture
def fx_cache_dir(request: SubRequest) -> Generator[None, None, None]:
    """Expose a temporary directory as a `cache_dir` attribute on the test class.

    The directory is created before the test runs and is removed when the
    fixture resumes after the `yield`. The stored path is resolved.

    Example:
    ```py
    @pytest.mark.usefixtures("fx_cache_dir")
    class TestWithCache(unittest.TestCase):
        cache_dir: Path

        def test_cache_dir(self) -> None:
            self.assertTrue(self.cache_dir.is_dir())
    ```
    """
    with TemporaryDirectory() as tmp_dir:
        resolved_dir = Path(tmp_dir).resolve()
        # Attach to the class so `unittest.TestCase` methods can read `self.cache_dir`.
        request.cls.cache_dir = resolved_dir
        yield
89 changes: 89 additions & 0 deletions tests/test_cache_no_symlinks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import unittest
from pathlib import Path
from unittest.mock import Mock, patch

import pytest

from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.utils import logging

from .testing_constants import TOKEN
from .testing_utils import DUMMY_MODEL_ID, with_production_testing


logger = logging.get_logger(__name__)
MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout"


def get_file_contents(path):
    """Return the full text content of the file at `path`.

    The file is decoded as UTF-8 explicitly rather than with the platform
    default encoding, so the result does not depend on the machine's locale
    (notably on Windows).
    """
    with open(path, encoding="utf-8") as f:
        return f.read()


@with_production_testing
@pytest.mark.usefixtures("fx_cache_dir")
@patch("huggingface_hub.file_download.are_symlinks_supported")
class TestCacheLayoutIfSymlinksNotSupported(unittest.TestCase):
    """Check the cache layout produced by `hf_hub_download` without symlinks.

    `are_symlinks_supported` is patched for every test method, so each test
    chooses what the download code is told about symlink support.
    """

    cache_dir: Path

    def _download_config(self) -> Path:
        """Download `CONFIG_NAME` from the dummy model into `self.cache_dir`."""
        return Path(
            hf_hub_download(
                DUMMY_MODEL_ID,
                filename=CONFIG_NAME,
                cache_dir=self.cache_dir,
                local_files_only=False,
                use_auth_token=TOKEN,
            )
        )

    def test_download_no_symlink_new_file(
        self, mock_are_symlinks_supported: Mock
    ) -> None:
        mock_are_symlinks_supported.return_value = False
        downloaded = self._download_config()

        # The snapshot entry is a regular file, not a symlink.
        self.assertFalse(downloaded.is_symlink())
        self.assertTrue(downloaded.is_file())

        # Nothing is left in the blobs directory.
        blobs_dir = downloaded.parents[2] / "blobs"
        self.assertEqual(len(list(blobs_dir.glob("*"))), 0)

    def test_download_no_symlink_existing_file(
        self, mock_are_symlinks_supported: Mock
    ) -> None:
        # First download with symlinks reported as supported: snapshot -> blob.
        mock_are_symlinks_supported.return_value = True
        pointer = self._download_config()
        self.assertTrue(pointer.is_symlink())
        blob = pointer.resolve()
        self.assertTrue(blob.is_file())

        # Remove the snapshot entry so the next download must recreate it.
        pointer.unlink()

        # Download again, but symlinks are no longer supported
        # (example: not an admin anymore).
        mock_are_symlinks_supported.return_value = False
        new_pointer = self._download_config()

        # The file exists but is not a symlink.
        self.assertFalse(new_pointer.is_symlink())
        self.assertTrue(new_pointer.is_file())

        # The blob has not been deleted either => the file is duplicated on disk.
        self.assertTrue(blob.is_file())
11 changes: 0 additions & 11 deletions tests/test_utils_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import time
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator
from unittest.mock import Mock

import pytest

from _pytest.fixtures import SubRequest
from huggingface_hub._snapshot_download import snapshot_download
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, scan_cache_dir
Expand All @@ -32,14 +29,6 @@
REPO_A_MAIN_README_BLOB_HASH = "4baf04727c45b660add228b2934001991bd34b29"


@pytest.fixture
def fx_cache_dir(request: SubRequest) -> Generator[None, None, None]:
"""Add a `cache_dir` attribute pointing to a temporary directory."""
with TemporaryDirectory() as cache_dir:
request.cls.cache_dir = Path(cache_dir).resolve()
yield


@pytest.mark.usefixtures("fx_cache_dir")
class TestMissingCacheUtils(unittest.TestCase):
cache_dir: Path
Expand Down