Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set load_from_disk path type as PathLike #7081

Merged
merged 4 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ def save_to_disk(
If you want to store paths or urls, please use the Value("string") type.

Args:
dataset_path (`str`):
dataset_path (`path-like`):
Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`)
of the dataset directory where the dataset will be saved to.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down Expand Up @@ -1660,7 +1660,7 @@ def _build_local_temp_path(uri_or_path: str) -> Path:

@staticmethod
def load_from_disk(
dataset_path: str,
dataset_path: PathLike,
fs="deprecated",
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
Expand All @@ -1670,7 +1670,7 @@ def load_from_disk(
filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.

Args:
dataset_path (`str`):
dataset_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3://my-bucket/dataset/train"`)
of the dataset directory where the dataset will be loaded from.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down
9 changes: 4 additions & 5 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,10 +1231,9 @@ def save_to_disk(
If you want to store paths or urls, please use the Value("string") type.

Args:
dataset_dict_path (`str`):
Path (e.g. `dataset/train`) or remote URI
(e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be
saved to.
dataset_dict_path (`path-like`):
Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`)
of the dataset dict directory where the dataset dict will be saved to.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Instance of the remote filesystem where the dataset will be saved to.

Expand Down Expand Up @@ -1314,7 +1313,7 @@ def load_from_disk(
Load a dataset that was previously saved using [`save_to_disk`] from a filesystem using `fsspec.spec.AbstractFileSystem`.

Args:
dataset_dict_path (`str`):
dataset_dict_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3://my-bucket/dataset/train"`)
of the dataset dict directory where the dataset dict will be loaded from.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down
12 changes: 8 additions & 4 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from .utils.logging import get_logger
from .utils.metadata import MetadataConfigs
from .utils.py_utils import get_imports, lock_importable_file
from .utils.typing import PathLike
from .utils.version import Version


Expand Down Expand Up @@ -2168,16 +2169,19 @@ def load_dataset(


def load_from_disk(
dataset_path: str, fs="deprecated", keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None
dataset_path: PathLike,
fs="deprecated",
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
) -> Union[Dataset, DatasetDict]:
"""
Loads a dataset that was previously saved using [`~Dataset.save_to_disk`] from a dataset directory, or
from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.

Args:
dataset_path (`str`):
Path (e.g. `"dataset/train"`) or remote URI (e.g.
`"s3://my-bucket/dataset/train"`) of the [`Dataset`] or [`DatasetDict`] directory where the dataset will be
dataset_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3://my-bucket/dataset/train"`)
of the [`Dataset`] or [`DatasetDict`] directory where the dataset/dataset-dict will be
loaded from.
fs (`~filesystems.S3FileSystem` or `fsspec.spec.AbstractFileSystem`, *optional*):
Instance of the remote filesystem used to download the files from.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4119,7 +4119,7 @@ def test_dummy_dataset_serialize_fs(dataset, mockfs):
dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
assert mockfs.isdir(dataset_path)
assert mockfs.glob(dataset_path + "/*")
reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
reloaded = Dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
assert len(reloaded) == len(dataset)
assert reloaded.features == dataset.features
assert reloaded.to_dict() == dataset.to_dict()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def test_dummy_datasetdict_serialize_fs(mockfs):
dataset_dict.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
assert mockfs.isdir(dataset_path)
assert mockfs.glob(dataset_path + "/*")
reloaded = dataset_dict.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
reloaded = DatasetDict.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
assert list(reloaded) == list(dataset_dict)
for k in dataset_dict:
assert reloaded[k].features == dataset_dict[k].features
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ def test_load_from_disk_with_default_in_memory(
expected_in_memory = False

dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True, trust_remote_code=True)
dataset_path = os.path.join(tmp_path, "saved_dataset")
dataset_path = tmp_path / "saved_dataset"
dset.save_to_disk(dataset_path)

with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
Expand Down
Loading