
Fix load_dataset for data_files with protocols other than HF #6862

Merged
1 change: 1 addition & 0 deletions setup.py
@@ -172,6 +172,7 @@
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
+"moto[server]",
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
2 changes: 1 addition & 1 deletion src/datasets/download/download_config.py
@@ -93,7 +93,7 @@ def __post_init__(self, use_auth_token):
FutureWarning,
)
self.token = use_auth_token
-if "hf" not in self.storage_options:
+if self.token is not None and "hf" not in self.storage_options:
Contributor Author:

I was worried that removing this altogether might break functionality for someone. This still might break CI but let's see before spending more time on it.

Member:

You might have to revert this, since we need the endpoint to be in the storage options even if token is None.

Contributor Author:

I'm trying to account for that here. This way _prepare_path_and_storage_options ensures that the endpoint is in the storage options without needing to populate it for all protocols in __post_init__.
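To illustrate the idea being discussed, here is a standalone sketch of how a helper in the spirit of _prepare_path_and_storage_options could inject the HF endpoint lazily, only for hf:// paths, instead of __post_init__ populating it for every protocol. This is a hypothetical simplification, not the actual datasets internals; the helper name and the hard-coded endpoint stand in for the real implementation and config.HF_ENDPOINT.

```python
from urllib.parse import urlparse

HF_ENDPOINT = "https://huggingface.co"  # stand-in for config.HF_ENDPOINT


def prepare_storage_options(urlpath, storage_options=None, token=None):
    """Hypothetical: fill in the HF endpoint only when the path uses the hf:// protocol."""
    storage_options = dict(storage_options or {})
    protocol = urlparse(urlpath).scheme or "file"
    if protocol == "hf":
        # The endpoint must be present even when token is None.
        hf_options = storage_options.setdefault("hf", {})
        hf_options.setdefault("endpoint", HF_ENDPOINT)
        if token is not None:
            hf_options.setdefault("token", token)
    return storage_options


# hf:// paths get the endpoint even without a token...
hf_opts = prepare_storage_options("hf://datasets/foo/bar.csv")
# ...while other protocols (e.g. s3://) are left untouched.
s3_opts = prepare_storage_options("s3://bucket/key.csv")
```

This keeps DownloadConfig protocol-agnostic while still guaranteeing the endpoint for HF paths at the point of use.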

Member:

I think we can completely remove the default "hf" storage option from DownloadConfig, now that it is properly handled by the call to _prepare_path_and_storage_options made in cached_path.

Member:

Do you agree, @lhoestq?

Member:

yes sounds good :)
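The net effect of the one-line change can be sketched as a standalone function (hypothetical helper name, with the endpoint hard-coded in place of config.HF_ENDPOINT): the default "hf" entry is now created only when a token is actually set, so protocols like s3:// no longer pick up a spurious HF entry.

```python
HF_ENDPOINT = "https://huggingface.co"  # stand-in for config.HF_ENDPOINT


def apply_hf_default(storage_options, token):
    """Hypothetical: mirrors the new guard in DownloadConfig.__post_init__."""
    storage_options = dict(storage_options)
    if token is not None and "hf" not in storage_options:
        storage_options["hf"] = {"token": token, "endpoint": HF_ENDPOINT}
    return storage_options


no_token = apply_hf_default({}, token=None)        # stays empty: no default "hf" entry
with_token = apply_hf_default({}, token="hf_abc")  # old behavior preserved for token users
```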

self.storage_options["hf"] = {"token": self.token, "endpoint": config.HF_ENDPOINT}

def copy(self) -> "DownloadConfig":
44 changes: 44 additions & 0 deletions tests/test_load.py
@@ -4,16 +4,19 @@
import shutil
import tempfile
import time
+from contextlib import contextmanager
from hashlib import sha256
from multiprocessing import Pool
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

+import boto3
import dill
import pyarrow as pa
import pytest
import requests
+from moto.server import ThreadedMotoServer

import datasets
from datasets import config, load_dataset, load_from_disk
@@ -1648,6 +1651,47 @@ def test_load_from_disk_with_default_in_memory(
_ = load_from_disk(dataset_path)


+@contextmanager
+def moto_server():
+    with patch.dict(
+        os.environ,
+        {
+            "AWS_ENDPOINT_URL": "http://localhost:5000",
+            "AWS_DEFAULT_REGION": "us-east-1",
+            "AWS_ACCESS_KEY_ID": "FOO",
+            "AWS_SECRET_ACCESS_KEY": "BAR",
+        },
+    ):
+        server = ThreadedMotoServer()
+        server.start()
+        try:
+            yield
+        finally:
+            server.stop()
+
+
+def test_load_file_from_s3():
+    # we need server mode here because of an aiobotocore incompatibility with moto.mock_aws
+    # (https://github.com/getmoto/moto/issues/6836)
+    with moto_server():
+        # Create a mock S3 bucket
+        bucket_name = "test-bucket"
+        s3 = boto3.client("s3", region_name="us-east-1")
+        s3.create_bucket(Bucket=bucket_name)
+
+        # Upload a file to the mock bucket
+        key = "test-file.csv"
+        csv_data = "Island\nIsabela\nBaltra"
+
+        s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)
+
+        # Load the file from the mock bucket
+        ds = datasets.load_dataset("csv", data_files={"train": "s3://test-bucket/test-file.csv"})
+
+        # Check if the loaded content matches the original content
+        assert list(ds["train"]) == [{"Island": "Isabela"}, {"Island": "Baltra"}]
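The test above relies on environment variables for the mock credentials and endpoint. As an aside, load_dataset also accepts a storage_options mapping that is forwarded to the fsspec filesystem for the path's protocol, so the same credentials could be passed explicitly. A sketch using s3fs-style option keys; the endpoint, bucket, and credential values are illustrative, not part of this PR:

```python
# Sketch: passing S3 credentials explicitly via storage_options instead of
# environment variables. Keys follow s3fs conventions ("key", "secret",
# "client_kwargs"); all values below are illustrative placeholders.
storage_options = {
    "key": "FOO",     # access key id
    "secret": "BAR",  # secret access key
    "client_kwargs": {"endpoint_url": "http://localhost:5000"},
}

# ds = datasets.load_dataset(
#     "csv",
#     data_files={"train": "s3://test-bucket/test-file.csv"},
#     storage_options=storage_options,
# )
```

Explicit storage_options would avoid mutating os.environ in the test, at the cost of duplicating the credentials in two places.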


@pytest.mark.integration
def test_remote_data_files():
repo_id = "hf-internal-testing/raw_jsonl"