Support downloading specific splits in load_dataset #6832

Open: wants to merge 18 commits into base: main
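This change threads the split argument from load_dataset() down to download_and_prepare(), so builders whose data_files are keyed by split can download and prepare only the requested splits. A minimal usage sketch, inferred from the diff below (the repository name is a placeholder, not part of the PR):

from datasets import load_dataset

# With this PR, load_dataset forwards `split` to
# builder_instance.download_and_prepare(split=split, ...), so only the data
# files mapped to the requested split are downloaded and prepared.
# "username/some_dataset" is a placeholder repository, not part of the PR.
train_ds = load_dataset("username/some_dataset", split="train")

# Omitting `split` keeps the previous behavior: every split is prepared.
dataset_dict = load_dataset("username/some_dataset")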
2 changes: 1 addition & 1 deletion src/datasets/arrow_reader.py
@@ -124,7 +124,7 @@ def make_file_instructions(
dataset_name=name,
split=info.name,
filetype_suffix=filetype_suffix,
shard_lengths=name2shard_lengths[info.name],
num_shards=len(name2shard_lengths[info.name] or ()),
)
for info in split_infos
}
163 changes: 134 additions & 29 deletions src/datasets/builder.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/datasets/info.py
@@ -181,7 +181,7 @@ def __post_init__(self):
else:
self.version = Version.from_dict(self.version)
if self.splits is not None and not isinstance(self.splits, SplitDict):
self.splits = SplitDict.from_split_dict(self.splits)
self.splits = SplitDict.from_split_dict(self.splits, self.dataset_name)
if self.supervised_keys is not None and not isinstance(self.supervised_keys, SupervisedKeysData):
if isinstance(self.supervised_keys, (tuple, list)):
self.supervised_keys = SupervisedKeysData(*self.supervised_keys)
1 change: 1 addition & 0 deletions src/datasets/load.py
@@ -2607,6 +2607,7 @@ def load_dataset(

# Download and prepare data
builder_instance.download_and_prepare(
split=split,
download_config=download_config,
download_mode=download_mode,
verification_mode=verification_mode,
22 changes: 11 additions & 11 deletions src/datasets/naming.py
@@ -16,7 +16,7 @@
"""Utilities for file names."""

import itertools
import os
import posixpath
import re


@@ -46,33 +46,33 @@ def snakecase_to_camelcase(name)


def filename_prefix_for_name(name):
if os.path.basename(name) != name:
if posixpath.basename(name) != name:
raise ValueError(f"Should be a dataset name, not a path: {name}")
return camelcase_to_snakecase(name)


def filename_prefix_for_split(name, split):
if os.path.basename(name) != name:
if posixpath.basename(name) != name:
raise ValueError(f"Should be a dataset name, not a path: {name}")
if not re.match(_split_re, split):
raise ValueError(f"Split name should match '{_split_re}'' but got '{split}'.")
return f"{filename_prefix_for_name(name)}-{split}"


def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix=None):
def filepattern_for_dataset_split(path, dataset_name, split, filetype_suffix=None):
prefix = filename_prefix_for_split(dataset_name, split)
filepath = posixpath.join(path, prefix)
filepath = f"{filepath}*"
if filetype_suffix:
prefix += f".{filetype_suffix}"
filepath = os.path.join(data_dir, prefix)
return f"{filepath}*"
filepath += f".{filetype_suffix}"
return filepath


def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None):
def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1):
prefix = filename_prefix_for_split(dataset_name, split)
prefix = os.path.join(path, prefix)
prefix = posixpath.join(path, prefix)

if shard_lengths:
num_shards = len(shard_lengths)
if num_shards > 1:
filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)]
if filetype_suffix:
filenames = [filename + f".{filetype_suffix}" for filename in filenames]
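For reference, filenames_for_dataset_split now derives shard file names from num_shards instead of shard_lengths, with paths joined via posixpath. A standalone illustration of the naming scheme (the cache path and shard count are made up, not taken from the PR):

import posixpath

# Illustration of the sharded-filename scheme used above (values are made up).
path = "/cache/my_dataset"
prefix = posixpath.join(path, "my_dataset-train")
num_shards = 3
filenames = [f"{prefix}-{i:05d}-of-{num_shards:05d}.arrow" for i in range(num_shards)]
print(filenames)
# ['/cache/my_dataset/my_dataset-train-00000-of-00003.arrow',
#  '/cache/my_dataset/my_dataset-train-00001-of-00003.arrow',
#  '/cache/my_dataset/my_dataset-train-00002-of-00003.arrow']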
12 changes: 9 additions & 3 deletions src/datasets/packaged_modules/arrow/arrow.py
@@ -1,6 +1,6 @@
import itertools
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import pyarrow as pa

@@ -24,12 +24,18 @@ class Arrow(datasets.ArrowBasedBuilder):
def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download_and_extract(data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
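The same pattern recurs in the csv, json, parquet, text, webdataset, and folder-based builders below: _available_splits exposes the split names when data_files is a dict, and _split_generators narrows data_files to the requested splits before downloading. A standalone sketch of that filtering step (the function and variable names here are illustrative, not part of the PR):

from typing import Dict, List, Optional

def filter_data_files(
    data_files: Dict[str, List[str]], splits: Optional[List[str]] = None
) -> Dict[str, List[str]]:
    # Mirrors the pattern added to each packaged builder: when specific splits
    # are requested and data_files is a dict keyed by split name, keep only
    # those splits so only their files are downloaded.
    if splits and isinstance(data_files, dict):
        return {split: data_files[split] for split in splits}
    return data_files

files = {"train": ["train-00000-of-00001.parquet"], "test": ["test-00000-of-00001.parquet"]}
print(filter_data_files(files, splits=["train"]))  # {'train': ['train-00000-of-00001.parquet']}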
9 changes: 7 additions & 2 deletions src/datasets/packaged_modules/cache/cache.py
@@ -169,10 +169,15 @@ def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs
if output_dir is not None and output_dir != self.cache_dir:
shutil.copytree(self.cache_dir, output_dir)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.info.splits]

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
# used to stream from cache
if isinstance(self.info.splits, datasets.SplitDict):
split_infos: List[datasets.SplitInfo] = list(self.info.splits.values())
if splits:
split_infos = [split_info for split_info in split_infos if split_info.name in splits]
else:
raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
return [
@@ -184,7 +189,7 @@ def _split_generators(self, dl_manager):
dataset_name=self.dataset_name,
split=split_info.name,
filetype_suffix="arrow",
shard_lengths=split_info.shard_lengths,
num_shards=len(split_info.shard_lengths or ()),
)
},
)
10 changes: 8 additions & 2 deletions src/datasets/packaged_modules/csv/csv.py
@@ -147,12 +147,18 @@ class Csv(datasets.ArrowBasedBuilder):
def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download_and_extract(data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py
@@ -54,7 +54,10 @@ class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
@@ -106,6 +109,8 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
)

data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
12 changes: 9 additions & 3 deletions src/datasets/packaged_modules/json/json.py
@@ -2,7 +2,7 @@
import itertools
import json
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import pyarrow as pa
import pyarrow.json as paj
@@ -44,12 +44,18 @@ def _info(self):
raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download_and_extract(data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
10 changes: 8 additions & 2 deletions src/datasets/packaged_modules/parquet/parquet.py
@@ -36,12 +36,18 @@ def _info(self):
)
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download_and_extract(data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
12 changes: 9 additions & 3 deletions src/datasets/packaged_modules/text/text.py
@@ -2,7 +2,7 @@
import warnings
from dataclasses import InitVar, dataclass
from io import StringIO
from typing import Optional
from typing import List, Optional

import pyarrow as pa

@@ -42,7 +42,10 @@ class Text(datasets.ArrowBasedBuilder):
def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].

If str or List[str], then the dataset returns only the 'train' split.
@@ -51,7 +54,10 @@ def _split_generators(self, dl_manager):
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download_and_extract(data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
12 changes: 9 additions & 3 deletions src/datasets/packaged_modules/webdataset/webdataset.py
@@ -1,7 +1,7 @@
import io
import json
from itertools import islice
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import pyarrow as pa
@@ -39,12 +39,18 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator):
def _info(self) -> datasets.DatasetInfo:
return datasets.DatasetInfo()

def _split_generators(self, dl_manager):
def _available_splits(self) -> Optional[List[str]]:
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None

def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
"""We handle string, list and dicts in datafiles"""
# Download the data files
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
data_files = dl_manager.download(self.config.data_files)
data_files = self.config.data_files
if splits and isinstance(data_files, dict):
data_files = {split: data_files[split] for split in splits}
data_files = dl_manager.download(data_files)
if isinstance(data_files, (str, list, tuple)):
tar_paths = data_files
if isinstance(tar_paths, str):
16 changes: 6 additions & 10 deletions src/datasets/utils/info_utils.py
@@ -53,11 +53,11 @@ def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dic
if expected_checksums is None:
logger.info("Unable to verify checksums.")
return
if len(set(expected_checksums) - set(recorded_checksums)) > 0:
raise ExpectedMoreDownloadedFiles(str(set(expected_checksums) - set(recorded_checksums)))
if len(set(recorded_checksums) - set(expected_checksums)) > 0:
raise UnexpectedDownloadedFile(str(set(recorded_checksums) - set(expected_checksums)))
bad_urls = [url for url in expected_checksums if expected_checksums[url] != recorded_checksums[url]]
bad_urls = [
url
for url in (set(recorded_checksums) & set(expected_checksums))
if expected_checksums[url] != recorded_checksums[url]
]
for_verification_name = " for " + verification_name if verification_name is not None else ""
if len(bad_urls) > 0:
raise NonMatchingChecksumError(
@@ -88,13 +88,9 @@ def verify_splits(expected_splits: Optional[dict], recorded_splits: dict):
if expected_splits is None:
logger.info("Unable to verify splits sizes.")
return
if len(set(expected_splits) - set(recorded_splits)) > 0:
raise ExpectedMoreSplits(str(set(expected_splits) - set(recorded_splits)))
if len(set(recorded_splits) - set(expected_splits)) > 0:
raise UnexpectedSplits(str(set(recorded_splits) - set(expected_splits)))
bad_splits = [
{"expected": expected_splits[name], "recorded": recorded_splits[name]}
for name in expected_splits
for name in (set(recorded_splits) & set(expected_splits))
if expected_splits[name].num_examples != recorded_splits[name].num_examples
]
if len(bad_splits) > 0:
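With the checks above relaxed to compare only keys present in both mappings, verifying a partially prepared dataset (for example, only the train split) no longer raises ExpectedMoreSplits or ExpectedMoreDownloadedFiles. A toy sketch of the new comparison (the numbers are made up, and plain ints stand in for SplitInfo.num_examples):

# Toy sketch of the relaxed verification: only keys present in both mappings
# are compared, so preparing a subset of splits no longer fails the size check.
expected = {"train": 100, "test": 20}   # stand-ins for SplitInfo.num_examples
recorded = {"train": 100}               # only the train split was prepared

bad_splits = [name for name in (set(recorded) & set(expected)) if expected[name] != recorded[name]]
assert bad_splits == []  # empty: the missing "test" split no longer triggers an error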
21 changes: 12 additions & 9 deletions tests/test_arrow_reader.py
@@ -1,4 +1,5 @@
import os
import posixpath
import tempfile
from pathlib import Path
from unittest import TestCase
@@ -103,8 +104,8 @@ def test_read_files(self):
reader = ReaderTest(tmp_dir, info)

files = [
{"filename": os.path.join(tmp_dir, "train")},
{"filename": os.path.join(tmp_dir, "test"), "skip": 10, "take": 10},
{"filename": posixpath.join(tmp_dir, "train")},
{"filename": posixpath.join(tmp_dir, "test"), "skip": 10, "take": 10},
]
dset = Dataset(**reader.read_files(files, original_instructions="train+test[10:20]"))
self.assertEqual(dset.num_rows, 110)
@@ -169,18 +170,18 @@ def test_make_file_instructions_basic():
assert isinstance(file_instructions, FileInstructions)
assert file_instructions.num_examples == 33
assert file_instructions.file_instructions == [
{"filename": os.path.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33}
{"filename": posixpath.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33}
]

split_infos = [SplitInfo(name="train", num_examples=100, shard_lengths=[10] * 10)]
file_instructions = make_file_instructions(name, split_infos, instruction, filetype_suffix, prefix_path)
assert isinstance(file_instructions, FileInstructions)
assert file_instructions.num_examples == 33
assert file_instructions.file_instructions == [
{"filename": os.path.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3},
{"filename": posixpath.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3},
]


@@ -217,7 +218,7 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran
if not isinstance(shard_lengths, list):
assert file_instructions.file_instructions == [
{
"filename": os.path.join(prefix_path, f"{name}-{split_name}.arrow"),
"filename": posixpath.join(prefix_path, f"{name}-{split_name}.arrow"),
"skip": read_range[0],
"take": read_range[1] - read_range[0],
}
@@ -226,7 +227,9 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran
file_instructions_list = []
shard_offset = 0
for i, shard_length in enumerate(shard_lengths):
filename = os.path.join(prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow")
filename = posixpath.join(
prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow"
)
if shard_offset <= read_range[0] < shard_offset + shard_length:
file_instructions_list.append(
{
1 change: 0 additions & 1 deletion tests/test_download_manager.py
@@ -136,7 +136,6 @@ def test_download_manager_delete_extracted_files(xz_file):
assert extracted_path == dl_manager.extracted_paths[xz_file]
extracted_path = Path(extracted_path)
parts = extracted_path.parts
# import pdb; pdb.set_trace()
assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None)
assert parts[-2] == extracted_subdir
assert extracted_path.exists()