diff --git a/setup.py b/setup.py
index 6c1835d3b15..a39f012b23f 100644
--- a/setup.py
+++ b/setup.py
@@ -171,6 +171,7 @@
     "jax>=0.3.14; sys_platform != 'win32'",
     "jaxlib>=0.3.14; sys_platform != 'win32'",
     "lz4",
+    "moto[server]",
     "pyspark>=3.4",  # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
     "py7zr",
     "rarfile>=4.0",
diff --git a/src/datasets/download/download_config.py b/src/datasets/download/download_config.py
index c1fdf9be6a2..1f4d46ab009 100644
--- a/src/datasets/download/download_config.py
+++ b/src/datasets/download/download_config.py
@@ -93,8 +93,6 @@ def __post_init__(self, use_auth_token):
                 FutureWarning,
             )
             self.token = use_auth_token
-        if "hf" not in self.storage_options:
-            self.storage_options["hf"] = {"token": self.token, "endpoint": config.HF_ENDPOINT}
 
     def copy(self) -> "DownloadConfig":
         return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index d51a344ee29..38c4bbf8f44 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -198,6 +198,9 @@ def cached_path(
 
     if is_remote_url(url_or_filename):
         # URL, so get it from the cache (downloading if necessary)
+        url_or_filename, storage_options = _prepare_path_and_storage_options(
+            url_or_filename, download_config=download_config
+        )
         output_path = get_from_cache(
             url_or_filename,
             cache_dir=cache_dir,
@@ -210,7 +213,7 @@ def cached_path(
             max_retries=download_config.max_retries,
             token=download_config.token,
             ignore_url_params=download_config.ignore_url_params,
-            storage_options=download_config.storage_options,
+            storage_options=storage_options,
             download_desc=download_config.download_desc,
             disable_tqdm=download_config.disable_tqdm,
         )
diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py
index c0358187724..84e418eba3a 100644
--- a/tests/test_file_utils.py
+++ b/tests/test_file_utils.py
@@ -1,7 +1,7 @@
 import os
 import re
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 import zstandard as zstd
@@ -77,6 +77,27 @@ def tmpfs_file(tmpfs):
     return FILE_PATH
 
 
+@pytest.mark.parametrize("protocol", ["hf", "s3"])
+def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
+    # GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf'
+    mock_fsspec_head = MagicMock(return_value={})
+    mock_fsspec_get = MagicMock(return_value=None)
+    monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head)
+    monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get)
+    cache_dir = tmp_path / "cache"
+    storage_options = {} if protocol == "hf" else {"s3": {"anon": True}}
+    download_config = DownloadConfig(cache_dir=cache_dir, storage_options=storage_options)
+    urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
+    url = urls[protocol]
+    _ = cached_path(url, download_config=download_config)
+    assert True
+    for mock in [mock_fsspec_head, mock_fsspec_get]:
+        assert mock.called
+        assert mock.call_count == 1
+        assert mock.call_args.args[0] == url
+        assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol]
+
+
 @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
 def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
     input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
diff --git a/tests/test_load.py b/tests/test_load.py
index 3bc272b4c60..38d6852fcf5 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -54,6 +54,7 @@
     assert_arrow_memory_doesnt_increase,
     assert_arrow_memory_increases,
     offline,
+    require_moto,
     require_not_windows,
     require_pil,
     require_sndfile,
@@ -1686,6 +1687,51 @@ def test_load_from_disk_with_default_in_memory(
         _ = load_from_disk(dataset_path)
 
 
+@pytest.fixture
+def moto_server(monkeypatch):
+    from moto.server import ThreadedMotoServer
+
+    monkeypatch.setattr(
+        "os.environ",
+        {
+            "AWS_ENDPOINT_URL": "http://localhost:5000",
+            "AWS_DEFAULT_REGION": "us-east-1",
+            "AWS_ACCESS_KEY_ID": "FOO",
+            "AWS_SECRET_ACCESS_KEY": "BAR",
+        },
+    )
+    server = ThreadedMotoServer()
+    server.start()
+    try:
+        yield
+    finally:
+        server.stop()
+
+
+@require_moto
+def test_load_file_from_s3(moto_server):
+    # we need server mode here because of an aiobotocore incompatibility with moto.mock_aws
+    # (https://github.com/getmoto/moto/issues/6836)
+    import boto3
+
+    # Create a mock S3 bucket
+    bucket_name = "test-bucket"
+    s3 = boto3.client("s3", region_name="us-east-1")
+    s3.create_bucket(Bucket=bucket_name)
+
+    # Upload a file to the mock bucket
+    key = "test-file.csv"
+    csv_data = "Island\nIsabela\nBaltra"
+
+    s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)
+
+    # Load the file from the mock bucket
+    ds = datasets.load_dataset("csv", data_files={"train": "s3://test-bucket/test-file.csv"})
+
+    # Check if the loaded content matches the original content
+    assert list(ds["train"]) == [{"Island": "Isabela"}, {"Island": "Baltra"}]
+
+
 @pytest.mark.integration
 def test_remote_data_files():
     repo_id = "hf-internal-testing/raw_jsonl"
diff --git a/tests/utils.py b/tests/utils.py
index 73bbf77bf7a..18fe27d8af1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -75,7 +75,7 @@ def parse_flag_from_env(key, default=False):
 
 
 require_faiss = pytest.mark.skipif(find_spec("faiss") is None or sys.platform == "win32", reason="test requires faiss")
-
+require_moto = pytest.mark.skipif(find_spec("moto") is None, reason="test requires moto")
 require_numpy1_on_windows = pytest.mark.skipif(
     version.parse(importlib.metadata.version("numpy")) >= version.parse("2.0.0") and sys.platform == "win32",
     reason="test requires numpy < 2.0 on windows",