Add repo_id to DatasetInfo #6268

Draft · wants to merge 3 commits into main
Changes from 2 commits
4 changes: 4 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -221,6 +221,10 @@ def task_templates(self):
     def version(self):
         return self._info.version

+    @property
+    def repo_id(self) -> str:
+        return self._info.repo_id
+

 class TensorflowDatasetMixin:
     _TF_DATASET_REFS = set()
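For context, a minimal usage sketch of the new property, assuming this patch is applied and that load_dataset forwards the Hub repository id into DatasetInfo.repo_id; the dataset name below is a made-up example:

from datasets import load_dataset

ds = load_dataset("username/some-dataset", split="train")  # hypothetical repo id, for illustration only
print(ds.repo_id)  # "username/some-dataset", read from ds._info through the new property
print(ds.version)  # existing sibling property, shown for comparison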
6 changes: 4 additions & 2 deletions src/datasets/builder.py
@@ -390,6 +390,8 @@ def __init__(
         # update info with user specified infos
         if features is not None:
             self.info.features = features
+        if repo_id is not None:
+            self.info.repo_id = repo_id

         # Prepare data dirs:
         # cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing)
@@ -417,7 +419,7 @@ def __init__(
             if len(os.listdir(self._cache_dir)) > 0:
                 if os.path.exists(path_join(self._cache_dir, config.DATASET_INFO_FILENAME)):
                     logger.info("Overwrite dataset info from restored data version if exists.")
-                    self.info = DatasetInfo.from_directory(self._cache_dir)
+                    self.info.update(DatasetInfo.from_directory(self._cache_dir))
                 else:  # dir exists but no data, remove the empty dir as data aren't available anymore
                     logger.warning(
                         f"Old caching folder {self._cache_dir} for dataset {self.dataset_name} exists but no data were found. Removing it. "
@@ -882,7 +884,7 @@ def download_and_prepare(
             logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
             # We need to update the info in case some splits were added in the meantime
             # for example when calling load_dataset from multiple workers.
-            self.info = self._load_info()
+            self.info.update(self._load_info())
             self.download_post_processing_resources(dl_manager)
             return

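The switch from plain assignment to self.info.update(...) matters because a repo_id set in __init__ would otherwise be wiped out when the info stored on disk is loaded back. A small sketch of that behavior, assuming this patch is applied and that DatasetInfo.update keeps fields that are None on the incoming info (its ignore_none default):

from datasets import DatasetInfo

info = DatasetInfo(description="demo", repo_id="username/demo")  # repo_id requires the field added in this PR
cached = DatasetInfo(description="demo")  # e.g. restored from the cache dir, repo_id is None

info.update(cached)  # merge instead of replace: None fields on `cached` are skipped
assert info.repo_id == "username/demo"  # the repo_id set at init survives

info = cached  # a plain assignment, the old behavior, would lose it
assert info.repo_id is None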
8 changes: 8 additions & 0 deletions src/datasets/info.py
@@ -153,6 +153,7 @@ class DatasetInfo:
     dataset_name: Optional[str] = None  # for packaged builders, to be different from builder_name
     config_name: Optional[str] = None
     version: Optional[Union[str, Version]] = None
+    repo_id: Optional[str] = None
     # Set later by `download_and_prepare`
     splits: Optional[dict] = None
     download_checksums: Optional[dict] = None
@@ -282,6 +283,12 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]):
         supervised_keys = None
         task_templates = None

+        repo_ids = {dset_info.repo_id for dset_info in dataset_infos}
+        if len(repo_ids) == 1:
+            repo_id = next(iter(repo_ids))
+        else:
+            repo_id = None
+
         # Find common task templates across all dataset infos
         all_task_templates = [info.task_templates for info in dataset_infos if info.task_templates is not None]
         if len(all_task_templates) > 1:
@@ -299,6 +306,7 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]):
             features=features,
             supervised_keys=supervised_keys,
             task_templates=task_templates,
+            repo_id=repo_id,
         )

     @classmethod
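A short sketch of the merge rule added above, assuming this patch is applied: a single distinct repo_id is carried over, while conflicting repo_ids fall back to None (the repo ids are made-up examples).

from datasets import DatasetInfo

a = DatasetInfo(repo_id="username/corpus")  # repo_id requires the field added in this PR
b = DatasetInfo(repo_id="username/corpus")
c = DatasetInfo(repo_id="other-user/corpus")

assert DatasetInfo.from_merge([a, b]).repo_id == "username/corpus"  # one distinct repo_id is kept
assert DatasetInfo.from_merge([a, c]).repo_id is None  # conflicting repo_ids drop to None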