From bbd23388c504f172e3df9c02b4f9fed065f57a4e Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 30 Sep 2024 13:23:40 +0200 Subject: [PATCH] Set custom huggingface-hub cache (#3075) * Test huggingface-hub cache * Set custom cache for huggingface-hub in addition to datasets * Fix quality due to Optional[Path] * Fix test_pre_compute_post_compute --- .../_job_runner_with_datasets_cache.py | 7 +++++- .../test__job_runner_with_datasets_cache.py | 22 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/services/worker/src/worker/job_runners/_job_runner_with_datasets_cache.py b/services/worker/src/worker/job_runners/_job_runner_with_datasets_cache.py index d8a6f3696d..d6f3e0e84f 100644 --- a/services/worker/src/worker/job_runners/_job_runner_with_datasets_cache.py +++ b/services/worker/src/worker/job_runners/_job_runner_with_datasets_cache.py @@ -6,6 +6,7 @@ from typing import Optional import datasets.config +import huggingface_hub.constants from libcommon.dtos import JobInfo from worker.config import AppConfig @@ -28,7 +29,9 @@ def __init__( ) def set_datasets_cache(self, cache_subdirectory: Optional[Path]) -> None: - datasets.config.HF_DATASETS_CACHE = cache_subdirectory + if cache_subdirectory is None: + return + datasets.config.HF_DATASETS_CACHE = cache_subdirectory / "datasets" logging.debug(f"datasets data cache set to: {datasets.config.HF_DATASETS_CACHE}") datasets.config.DOWNLOADED_DATASETS_PATH = ( datasets.config.HF_DATASETS_CACHE / datasets.config.DOWNLOADED_DATASETS_DIR @@ -36,6 +39,8 @@ def set_datasets_cache(self, cache_subdirectory: Optional[Path]) -> None: datasets.config.EXTRACTED_DATASETS_PATH = ( datasets.config.HF_DATASETS_CACHE / datasets.config.EXTRACTED_DATASETS_DIR ) + huggingface_hub.constants.HF_HUB_CACHE = cache_subdirectory / "hub" + logging.debug(f"huggingface_hub cache set to: {huggingface_hub.constants.HF_HUB_CACHE}") def pre_compute(self) -> None: super().pre_compute() diff --git a/services/worker/tests/job_runners/test__job_runner_with_datasets_cache.py b/services/worker/tests/job_runners/test__job_runner_with_datasets_cache.py index 5ac69cf83e..f660e32d8e 100644 --- a/services/worker/tests/job_runners/test__job_runner_with_datasets_cache.py +++ b/services/worker/tests/job_runners/test__job_runner_with_datasets_cache.py @@ -6,6 +6,7 @@ from typing import Optional import datasets.config +import huggingface_hub.constants import pytest from libcommon.dtos import Priority from libcommon.resources import CacheMongoResource, QueueMongoResource @@ -74,7 +75,8 @@ def test_set_datasets_cache(app_config: AppConfig, get_job_runner: GetJobRunner) base_path = job_runner.base_cache_directory dummy_path = base_path / "dummy" job_runner.set_datasets_cache(dummy_path) - assert str(datasets.config.HF_DATASETS_CACHE).startswith(str(dummy_path)) + assert datasets.config.HF_DATASETS_CACHE.is_relative_to(dummy_path) + assert huggingface_hub.constants.HF_HUB_CACHE.is_relative_to(dummy_path) def test_pre_compute_post_compute(app_config: AppConfig, get_job_runner: GetJobRunner) -> None: @@ -85,16 +87,22 @@ def test_pre_compute_post_compute(app_config: AppConfig, get_job_runner: GetJobR job_runner.pre_compute() datasets_cache_subdirectory = job_runner.cache_subdirectory assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=True) - assert str(datasets.config.HF_DATASETS_CACHE).startswith(str(datasets_base_path)) + assert datasets.config.HF_DATASETS_CACHE.is_relative_to(datasets_base_path) assert "dummy-job-runner-user-dataset" in str(datasets.config.HF_DATASETS_CACHE) job_runner.post_compute() assert_datasets_cache_path(path=datasets_base_path, exists=True) - assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=False, equals=False) + assert_datasets_cache_path(path=datasets_cache_subdirectory, exists=False) -def assert_datasets_cache_path(path: Optional[Path], exists: bool, equals: bool = True) -> None: +def assert_datasets_cache_path(path: Optional[Path], exists: bool) -> None: assert path is not None assert path.exists() is exists - assert (datasets.config.HF_DATASETS_CACHE == path) is equals - assert (datasets.config.DOWNLOADED_DATASETS_PATH == path / datasets.config.DOWNLOADED_DATASETS_DIR) is equals - assert (datasets.config.EXTRACTED_DATASETS_PATH == path / datasets.config.EXTRACTED_DATASETS_DIR) is equals + if exists: + datasets_cache_path = path / "datasets" + hub_cache_path = path / "hub" + assert datasets.config.HF_DATASETS_CACHE == datasets_cache_path + assert ( + datasets.config.DOWNLOADED_DATASETS_PATH == datasets_cache_path / datasets.config.DOWNLOADED_DATASETS_DIR + ) + assert datasets.config.EXTRACTED_DATASETS_PATH == datasets_cache_path / datasets.config.EXTRACTED_DATASETS_DIR + assert huggingface_hub.constants.HF_HUB_CACHE == hub_cache_path