Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load GitHub datasets from Hub #4059

Merged
merged 6 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 46 additions & 144 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,93 +424,6 @@ def get_module(self) -> MetricModule:
raise NotImplementedError


class GithubDatasetModuleFactory(_DatasetModuleFactory):
    """
    Get the module of a dataset from GitHub (legacy).
    The dataset script is downloaded from GitHub.
    This class will eventually be removed and a HubDatasetModuleFactory will be used instead.
    """

    def __init__(
        self,
        name: str,
        revision: Optional[Union[str, Version]] = None,
        download_config: Optional[DownloadConfig] = None,
        download_mode: Optional[DownloadMode] = None,
        dynamic_modules_path: Optional[str] = None,
    ):
        """
        Args:
            name: Canonical dataset name — must not contain a "/" namespace.
            revision: Git revision (branch, tag or commit sha) to fetch the script from.
            download_config: Download configuration; a copy is taken so the caller's
                object is never mutated.
            download_mode: Whether to reuse or force-redownload the cached module.
            dynamic_modules_path: Directory for the importable module; defaults to
                ``init_dynamic_modules()`` at ``get_module`` time.

        Raises:
            ValueError: If ``name`` contains a "/" (namespaced names belong to the Hub
                factories, not the GitHub one).
        """
        self.name = name
        self.revision = revision
        self.download_config = download_config.copy() if download_config else DownloadConfig()
        # GitHub raw downloads can be flaky; make sure we retry a few times.
        if self.download_config.max_retries < 3:
            self.download_config.max_retries = 3
        self.download_mode = download_mode
        self.dynamic_modules_path = dynamic_modules_path
        # A bare `assert` is stripped under `python -O`, which would let a namespaced
        # name slip through and build a bogus GitHub URL — raise explicitly instead.
        if self.name.count("/") != 0:
            raise ValueError(
                f"GithubDatasetModuleFactory expects a canonical dataset name without '/', got '{name}'"
            )
        increase_load_count(name, resource_type="dataset")

    def download_loading_script(self, revision: Optional[str]) -> str:
        """Download ``datasets/<name>/<name>.py`` at ``revision`` and return its local cached path."""
        file_path = hf_github_url(path=self.name, name=self.name + ".py", revision=revision)
        download_config = self.download_config.copy()
        if download_config.download_desc is None:
            download_config.download_desc = "Downloading builder script"
        return cached_path(file_path, download_config=download_config)

    def download_dataset_infos_file(self, revision: Optional[str]) -> Optional[str]:
        """Download the dataset infos file if available; return ``None`` when it is missing
        or unreachable (the infos file is optional metadata, so this is best-effort)."""
        dataset_infos = hf_github_url(path=self.name, name=config.DATASETDICT_INFOS_FILENAME, revision=revision)
        # Download the dataset infos file if available
        download_config = self.download_config.copy()
        if download_config.download_desc is None:
            download_config.download_desc = "Downloading metadata"
        try:
            return cached_path(
                dataset_infos,
                download_config=download_config,
            )
        except (FileNotFoundError, ConnectionError):
            return None

    def get_module(self) -> DatasetModule:
        """Download the loading script (falling back to the ``main`` branch when no
        revision was pinned), copy it into an importable directory, and return the
        resulting :class:`DatasetModule`."""
        # get script and other files
        revision = self.revision
        try:
            local_path = self.download_loading_script(revision)
        except FileNotFoundError:
            # Only fall back to `main` when the user did not pin a revision explicitly
            # (neither via the argument nor via the HF_SCRIPTS_VERSION env var).
            if revision is not None or os.getenv("HF_SCRIPTS_VERSION", None) is not None:
                raise
            revision = "main"
            local_path = self.download_loading_script(revision)
            logger.warning(
                f"Couldn't find a directory or a dataset named '{self.name}' in this version. "
                f"It was picked from the main branch on github instead."
            )
        dataset_infos_path = self.download_dataset_infos_file(revision)
        imports = get_imports(local_path)
        local_imports = _download_additional_modules(
            name=self.name,
            base_path=hf_github_url(path=self.name, name="", revision=revision),
            imports=imports,
            download_config=self.download_config,
        )
        additional_files = [(config.DATASETDICT_INFOS_FILENAME, dataset_infos_path)] if dataset_infos_path else []
        # copy the script and the files in an importable directory
        dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
        # `module_hash` rather than `hash`: don't shadow the builtin.
        module_path, module_hash = _create_importable_file(
            local_path=local_path,
            local_imports=local_imports,
            additional_files=additional_files,
            dynamic_modules_path=dynamic_modules_path,
            module_namespace="datasets",
            name=self.name,
            download_mode=self.download_mode,
        )
        # make the new module to be noticed by the import system
        importlib.invalidate_caches()
        builder_kwargs = {"hash": module_hash, "base_path": hf_hub_url(self.name, "", revision=self.revision)}
        return DatasetModule(module_path, module_hash, builder_kwargs)


class GithubMetricModuleFactory(_MetricModuleFactory):
"""Get the module of a metric. The metric script is downloaded from GitHub.

Expand Down Expand Up @@ -917,11 +830,10 @@ def __init__(
self.download_config = download_config or DownloadConfig()
self.download_mode = download_mode
self.dynamic_modules_path = dynamic_modules_path
assert self.name.count("/") == 1
increase_load_count(name, resource_type="dataset")

def download_loading_script(self) -> str:
file_path = hf_hub_url(repo_id=self.name, path=self.name.split("/")[1] + ".py", revision=self.revision)
file_path = hf_hub_url(repo_id=self.name, path=self.name.split("/")[-1] + ".py", revision=self.revision)
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = "Downloading builder script"
Expand Down Expand Up @@ -1197,67 +1109,57 @@ def dataset_module_factory(
elif is_relative_path(path) and path.count("/") <= 1:
try:
_raise_if_offline_mode_is_enabled()
if path.count("/") == 0: # even though the dataset is on the Hub, we get it from GitHub for now
# TODO(QL): use a Hub dataset module factory instead of GitHub
return GithubDatasetModuleFactory(
hf_api = HfApi(config.HF_ENDPOINT)
try:
if isinstance(download_config.use_auth_token, bool):
token = HfFolder.get_token() if download_config.use_auth_token else None
else:
token = download_config.use_auth_token
dataset_info = hf_api.dataset_info(
repo_id=path,
revision=revision,
token=token if token else "no-token",
timeout=100.0,
)
except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
if isinstance(
e,
(
OfflineModeIsEnabled,
requests.exceptions.ConnectTimeout,
requests.exceptions.ConnectionError,
),
):
raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})")
elif "404" in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg)
elif "401" in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
msg = msg + f" at revision '{revision}'" if revision else msg
raise FileNotFoundError(
msg
+ ". If the repo is private, make sure you are authenticated with `use_auth_token=True` after logging in with `huggingface-cli login`."
)
else:
raise e
if filename in [sibling.rfilename for sibling in dataset_info.siblings]:
return HubDatasetModuleFactoryWithScript(
path,
revision=revision,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
).get_module()
elif path.count("/") == 1: # community dataset on the Hub
hf_api = HfApi(config.HF_ENDPOINT)
try:
if isinstance(download_config.use_auth_token, bool):
token = HfFolder.get_token() if download_config.use_auth_token else None
else:
token = download_config.use_auth_token
dataset_info = hf_api.dataset_info(
repo_id=path,
revision=revision,
token=token if token else "no-token",
timeout=100.0,
)
except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
if isinstance(
e,
(
OfflineModeIsEnabled,
requests.exceptions.ConnectTimeout,
requests.exceptions.ConnectionError,
),
):
raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({type(e).__name__})")
elif "404" in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
raise FileNotFoundError(msg + f" at revision '{revision}'" if revision else msg)
elif "401" in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
msg = msg + f" at revision '{revision}'" if revision else msg
raise FileNotFoundError(
msg
+ ". If the repo is private, make sure you are authenticated with `use_auth_token=True` after logging in with `huggingface-cli login`."
)
else:
raise e
if filename in [sibling.rfilename for sibling in dataset_info.siblings]:
return HubDatasetModuleFactoryWithScript(
path,
revision=revision,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
).get_module()
else:
return HubDatasetModuleFactoryWithoutScript(
path,
revision=revision,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
else:
return HubDatasetModuleFactoryWithoutScript(
path,
revision=revision,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
except Exception as e1: # noqa: all the attempts failed, before raising the error we should check if the module is already cached.
try:
return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
Expand Down
23 changes: 7 additions & 16 deletions tests/test_load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import os
import re
import shutil
import tempfile
import time
Expand All @@ -13,7 +12,7 @@
import requests

import datasets
from datasets import SCRIPTS_VERSION, config, load_dataset, load_from_disk
from datasets import config, load_dataset, load_from_disk
from datasets.arrow_dataset import Dataset
from datasets.builder import DatasetBuilder
from datasets.data_files import DataFilesDict
Expand All @@ -24,7 +23,6 @@
from datasets.load import (
CachedDatasetModuleFactory,
CachedMetricModuleFactory,
GithubDatasetModuleFactory,
GithubMetricModuleFactory,
HubDatasetModuleFactoryWithoutScript,
HubDatasetModuleFactoryWithScript,
Expand All @@ -35,7 +33,6 @@
infer_module_for_data_files,
infer_module_for_data_files_in_archives,
)
from datasets.utils.file_utils import is_remote_url

from .utils import (
OfflineSimulationMode,
Expand Down Expand Up @@ -255,9 +252,9 @@ def setUp(self):
hf_modules_cache=self.hf_modules_cache,
)

def test_GithubDatasetModuleFactory(self):
def test_HubDatasetModuleFactoryWithScript_with_github_dataset(self):
# "wmt_t2t" has additional imports (internal)
factory = GithubDatasetModuleFactory(
factory = HubDatasetModuleFactoryWithScript(
"wmt_t2t", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
)
module_factory_result = factory.get_module()
Expand Down Expand Up @@ -479,7 +476,6 @@ def test_CachedMetricModuleFactory(self):
[
CachedDatasetModuleFactory,
CachedMetricModuleFactory,
GithubDatasetModuleFactory,
GithubMetricModuleFactory,
HubDatasetModuleFactoryWithoutScript,
HubDatasetModuleFactoryWithScript,
Expand Down Expand Up @@ -577,17 +573,16 @@ def test_offline_dataset_module_factory(self):
self.assertIn("Using the latest cached version of the module", self._caplog.text)

def test_load_dataset_from_github(self):
scripts_version = os.getenv("HF_SCRIPTS_VERSION", SCRIPTS_VERSION)
with self.assertRaises(FileNotFoundError) as context:
datasets.load_dataset("_dummy")
self.assertIn(
"https://raw.githubusercontent.com/huggingface/datasets/main/datasets/_dummy/_dummy.py",
"Dataset '_dummy' doesn't exist on the Hub",
str(context.exception),
)
with self.assertRaises(FileNotFoundError) as context:
datasets.load_dataset("_dummy", revision="0.0.0")
self.assertIn(
"https://raw.githubusercontent.com/huggingface/datasets/0.0.0/datasets/_dummy/_dummy.py",
"Dataset '_dummy' doesn't exist on the Hub at revision '0.0.0'",
str(context.exception),
)
for offline_simulation_mode in list(OfflineSimulationMode):
Expand All @@ -596,7 +591,7 @@ def test_load_dataset_from_github(self):
datasets.load_dataset("_dummy")
if offline_simulation_mode != OfflineSimulationMode.HF_DATASETS_OFFLINE_SET_TO_1:
self.assertIn(
f"https://raw.githubusercontent.com/huggingface/datasets/{scripts_version}/datasets/_dummy/_dummy.py",
"Couldn't reach '_dummy' on the Hub",
str(context.exception),
)

Expand Down Expand Up @@ -708,11 +703,7 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory
assert "Using the latest cached version of the module" in caplog.text
with pytest.raises(FileNotFoundError) as exc_info:
datasets.load_dataset(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST)
m_combined_path = re.search(
rf"http\S*{re.escape(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST + '/' + SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST + '.py')}\b",
str(exc_info.value),
)
assert m_combined_path is not None and is_remote_url(m_combined_path.group())
assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value)
assert os.path.abspath(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST) in str(exc_info.value)


Expand Down