Skip to content

Commit

Permalink
Set custom huggingface-hub cache (#3075)
Browse files Browse the repository at this point in the history
* Test huggingface-hub cache

* Set custom cache for huggingface-hub in addition to datasets

* Fix quality due to Optional[Path]

* Fix test_pre_compute_post_compute
  • Loading branch information
albertvillanova authored Sep 30, 2024
1 parent be8f4b4 commit bbd2338
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional

import datasets.config
import huggingface_hub.constants
from libcommon.dtos import JobInfo

from worker.config import AppConfig
Expand All @@ -28,14 +29,18 @@ def __init__(
)

def set_datasets_cache(self, cache_subdirectory: Optional[Path]) -> None:
datasets.config.HF_DATASETS_CACHE = cache_subdirectory
if cache_subdirectory is None:
return
datasets.config.HF_DATASETS_CACHE = cache_subdirectory / "datasets"
logging.debug(f"datasets data cache set to: {datasets.config.HF_DATASETS_CACHE}")
datasets.config.DOWNLOADED_DATASETS_PATH = (
datasets.config.HF_DATASETS_CACHE / datasets.config.DOWNLOADED_DATASETS_DIR
)
datasets.config.EXTRACTED_DATASETS_PATH = (
datasets.config.HF_DATASETS_CACHE / datasets.config.EXTRACTED_DATASETS_DIR
)
huggingface_hub.constants.HF_HUB_CACHE = cache_subdirectory / "hub"
logging.debug(f"huggingface_hub cache set to: {huggingface_hub.constants.HF_HUB_CACHE}")

def pre_compute(self) -> None:
super().pre_compute()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional

import datasets.config
import huggingface_hub.constants
import pytest
from libcommon.dtos import Priority
from libcommon.resources import CacheMongoResource, QueueMongoResource
Expand Down Expand Up @@ -74,7 +75,8 @@ def test_set_datasets_cache(app_config: AppConfig, get_job_runner: GetJobRunner)
base_path = job_runner.base_cache_directory
dummy_path = base_path / "dummy"
job_runner.set_datasets_cache(dummy_path)
assert str(datasets.config.HF_DATASETS_CACHE).startswith(str(dummy_path))
assert datasets.config.HF_DATASETS_CACHE.is_relative_to(dummy_path)
assert huggingface_hub.constants.HF_HUB_CACHE.is_relative_to(dummy_path)


def test_pre_compute_post_compute(app_config: AppConfig, get_job_runner: GetJobRunner) -> None:
Expand All @@ -85,16 +87,22 @@ def test_pre_compute_post_compute(app_config: AppConfig, get_job_runner: GetJobR
job_runner.pre_compute()
datasets_cache_subdirectory = job_runner.cache_subdirectory
assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=True)
assert str(datasets.config.HF_DATASETS_CACHE).startswith(str(datasets_base_path))
assert datasets.config.HF_DATASETS_CACHE.is_relative_to(datasets_base_path)
assert "dummy-job-runner-user-dataset" in str(datasets.config.HF_DATASETS_CACHE)
job_runner.post_compute()
assert_datasets_cache_path(path=datasets_base_path, exists=True)
assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=False, equals=False)
assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=False)


def assert_datasets_cache_path(path: Optional[Path], exists: bool, equals: bool = True) -> None:
def assert_datasets_cache_path(path: Optional[Path], exists: bool) -> None:
assert path is not None
assert path.exists() is exists
assert (datasets.config.HF_DATASETS_CACHE == path) is equals
assert (datasets.config.DOWNLOADED_DATASETS_PATH == path / datasets.config.DOWNLOADED_DATASETS_DIR) is equals
assert (datasets.config.EXTRACTED_DATASETS_PATH == path / datasets.config.EXTRACTED_DATASETS_DIR) is equals
if exists:
datasets_cache_path = path / "datasets"
hub_cache_path = path / "hub"
assert datasets.config.HF_DATASETS_CACHE == datasets_cache_path
assert (
datasets.config.DOWNLOADED_DATASETS_PATH == datasets_cache_path / datasets.config.DOWNLOADED_DATASETS_DIR
)
assert datasets.config.EXTRACTED_DATASETS_PATH == datasets_cache_path / datasets.config.EXTRACTED_DATASETS_DIR
assert huggingface_hub.constants.HF_HUB_CACHE == hub_cache_path

0 comments on commit bbd2338

Please sign in to comment.