Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix download local dir edge case (remove lru_cache) #2629

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/huggingface_hub/_local_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import os
import time
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -179,7 +178,6 @@ def save(self, paths: LocalUploadFilePaths) -> None:
self.timestamp = new_timestamp


@lru_cache(maxsize=128) # ensure singleton
def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFilePaths:
"""Compute paths to the files related to a download process.

Expand Down Expand Up @@ -220,7 +218,6 @@ def get_local_download_paths(local_dir: Path, filename: str) -> LocalDownloadFil
return LocalDownloadFilePaths(file_path=file_path, lock_path=lock_path, metadata_path=metadata_path)


@lru_cache(maxsize=128) # ensure singleton
def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePaths:
"""Compute paths to the files related to an upload process.

Expand Down Expand Up @@ -404,7 +401,6 @@ def write_download_metadata(local_dir: Path, filename: str, commit_hash: str, et
f.write(f"{commit_hash}\n{etag}\n{time.time()}\n")


@lru_cache()
def _huggingface_dir(local_dir: Path) -> Path:
"""Return the path to the `.cache/huggingface` directory in a local directory."""
# Wrap in lru_cache to avoid overwriting the .gitignore file if called multiple times
Expand Down
34 changes: 20 additions & 14 deletions tests/test_local_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ def test_local_download_paths(tmp_path: Path):
assert paths.incomplete_path("etag123").parent.is_dir()


def test_local_download_paths_are_cached(tmp_path: Path):
"""Test local download paths are cached."""
# No need for an exact singleton here.
# We just want to avoid recreating the dataclass on consecutive calls (happens often
# in the process).
def test_local_download_paths_are_recreated_each_time(tmp_path: Path):
    """Check that every call recomputes the download paths and re-creates their parent directories.

    The directories made by the first call are removed before the second call, which
    must create them again instead of handing back a stale result.
    """
    paths1 = get_local_download_paths(tmp_path, "path/in/repo.txt")
    assert paths1.file_path.parent.is_dir()
    assert paths1.metadata_path.parent.is_dir()

    paths1.file_path.parent.rmdir()
    paths1.metadata_path.parent.rmdir()

    # No identity check between the two results: the point of this test is that the
    # second call does the work again rather than returning the first object.
    paths2 = get_local_download_paths(tmp_path, "path/in/repo.txt")
    assert paths2.file_path.parent.is_dir()
    assert paths2.metadata_path.parent.is_dir()


@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test.")
Expand Down Expand Up @@ -198,14 +201,17 @@ def test_local_upload_paths(tmp_path: Path):
assert paths.lock_path.parent.is_dir()


def test_local_upload_paths_are_cached(tmp_path: Path):
"""Test local upload paths are cached."""
# No need for an exact singleton here.
# We just want to avoid recreating the dataclass on consecutive calls (happens often
# in the process).
paths1 = get_local_download_paths(tmp_path, "path/in/repo.txt")
paths2 = get_local_download_paths(tmp_path, "path/in/repo.txt")
assert paths1 is paths2
def test_local_upload_paths_are_recreated_each_time(tmp_path: Path):
    """Check that a repeated call re-creates the upload directories after they were deleted."""
    first = get_local_upload_paths(tmp_path, "path/in/repo.txt")
    file_dir = first.file_path.parent
    metadata_dir = first.metadata_path.parent
    assert file_dir.is_dir()
    assert metadata_dir.is_dir()

    # Remove the directories the first call produced.
    file_dir.rmdir()
    metadata_dir.rmdir()

    # Asking for the same paths again must bring both directories back.
    second = get_local_upload_paths(tmp_path, "path/in/repo.txt")
    assert second.file_path.parent.is_dir()
    assert second.metadata_path.parent.is_dir()


@pytest.mark.skipif(os.name != "nt", reason="Windows-specific test.")
Expand Down
Loading