Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract data on the fly in packaged builders #6784

Merged
merged 13 commits into from
Apr 16, 2024
3 changes: 1 addition & 2 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

from . import config
from .download import DownloadConfig
from .download.streaming_download_manager import _prepare_path_and_storage_options, xbasename, xjoin
from .naming import _split_re
from .splits import Split
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import is_local_path, is_relative_path
from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
from .utils.py_utils import glob_pattern_to_regex, string_to_dict


Expand Down
3 changes: 3 additions & 0 deletions src/datasets/download/download_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DownloadConfig:
was already extracted, re-extract the archive and override the folder where it was extracted.
delete_extracted (`bool`, defaults to `False`):
Whether to delete (or keep) the extracted files.
extract_on_the_fly (`bool`, defaults to `False`):
If `True`, extract compressed files while they are being read.
use_etag (`bool`, defaults to `True`):
Whether to use the ETag HTTP response header to validate the cached files.
num_proc (`int`, *optional*):
Expand Down Expand Up @@ -72,6 +74,7 @@ class DownloadConfig:
extract_compressed_file: bool = False
force_extract: bool = False
delete_extracted: bool = False
extract_on_the_fly: bool = False
use_etag: bool = True
num_proc: Optional[int] = None
max_retries: int = 1
Expand Down
194 changes: 6 additions & 188 deletions src/datasets/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@
import multiprocessing
import os
import posixpath
import tarfile
import warnings
import zipfile
from datetime import datetime
from functools import partial
from itertools import chain
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

import fsspec
from fsspec.core import url_to_fs
Expand All @@ -36,6 +33,8 @@
from ..utils import tqdm as hf_tqdm
from ..utils.deprecation_utils import DeprecatedEnum, deprecated
from ..utils.file_utils import (
ArchiveIterable,
FilesIterable,
cached_path,
get_from_cache,
hash_url_to_filename,
Expand All @@ -46,47 +45,13 @@
from ..utils.info_utils import get_size_checksum_dict
from ..utils.logging import get_logger, tqdm
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
from ..utils.track import TrackedIterable, tracked_str
from ..utils.track import tracked_str
from .download_config import DownloadConfig


logger = get_logger(__name__)


# Extensions of plain data files that are read as-is: if a path ends in one of
# these, no decompression protocol is ever inferred for it.
BASE_KNOWN_EXTENSIONS = [
    "txt",
    "csv",
    "json",
    "jsonl",
    "tsv",
    "conll",
    "conllu",
    "orig",
    "parquet",
    "pkl",
    "pickle",
    "rel",
    "xml",
]
# Leading magic bytes of a file mapped to the compression protocol name used to
# decompress it (protocol names match fsspec/datasets extraction protocols).
MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL = {
    bytes.fromhex("504B0304"): "zip",
    bytes.fromhex("504B0506"): "zip",  # empty archive
    bytes.fromhex("504B0708"): "zip",  # spanned archive
    bytes.fromhex("425A68"): "bz2",
    bytes.fromhex("1F8B"): "gzip",
    bytes.fromhex("FD377A585A00"): "xz",
    bytes.fromhex("04224D18"): "lz4",
    bytes.fromhex("28B52FFD"): "zstd",
}
# Magic bytes of recognized but unsupported formats: sniffing one of these
# raises NotImplementedError instead of silently returning raw bytes.
MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL = {
    b"Rar!": "rar",
}
# Longest known magic number — the number of bytes to read when sniffing a file.
MAGIC_NUMBER_MAX_LENGTH = max(
    len(magic_number)
    for magic_number in chain(MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL, MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL)
)


class DownloadMode(enum.Enum):
"""`Enum` for how to treat pre-existing downloads and data.

Expand Down Expand Up @@ -118,153 +83,6 @@ def help_message(self):
return "Use 'DownloadMode' instead."


def _get_path_extension(path: str) -> str:
# Get extension: train.json.gz -> gz
extension = path.split(".")[-1]
# Remove query params ("dl=1", "raw=true"): gz?dl=1 -> gz
# Remove shards infos (".txt_1", ".txt-00000-of-00100"): txt_1 -> txt
for symb in "?-_":
extension = extension.split(symb)[0]
return extension


def _get_extraction_protocol_with_magic_number(f) -> Optional[str]:
    """Sniff the compression protocol of a file-like object from its magic number.

    Returns the protocol name (e.g. "zip", "gzip"), or `None` if the stream is
    not seekable or no known magic number matches.

    Raises:
        NotImplementedError: if the magic number belongs to a recognized but
            unsupported format (e.g. rar).
    """
    # Bail out early on non-seekable streams, even before reading the magic
    # number (to avoid https://bugs.python.org/issue26440)
    try:
        f.seek(0)
    except (AttributeError, io.UnsupportedOperation):
        return None
    magic_number = f.read(MAGIC_NUMBER_MAX_LENGTH)
    f.seek(0)
    # Try prefixes from the longest down to a single byte
    for prefix_length in range(MAGIC_NUMBER_MAX_LENGTH, 0, -1):
        prefix = magic_number[:prefix_length]
        protocol = MAGIC_NUMBER_TO_COMPRESSION_PROTOCOL.get(prefix)
        if protocol is not None:
            return protocol
        protocol = MAGIC_NUMBER_TO_UNSUPPORTED_COMPRESSION_PROTOCOL.get(prefix)
        if protocol is not None:
            raise NotImplementedError(f"Compression protocol '{protocol}' not implemented.")
    return None


def _get_extraction_protocol(path: str) -> Optional[str]:
    """Return the compression protocol of a local file, or `None` if none applies.

    Paths with known data extensions — and tar archives, which are iterated
    rather than extracted — short-circuit to `None`; otherwise the file's magic
    number is sniffed.
    """
    path = str(path)
    extension = _get_path_extension(path)
    # TODO(mariosasko): The below check will be useful once we can preserve the original extension
    # in the new cache layout (use the `filename` parameter of `hf_hub_download`)
    is_known_plain_or_tar = (
        extension in BASE_KNOWN_EXTENSIONS
        or extension in ("tgz", "tar")
        or path.endswith((".tar.gz", ".tar.bz2", ".tar.xz"))
    )
    if is_known_plain_or_tar:
        return None
    with open(path, "rb") as f:
        return _get_extraction_protocol_with_magic_number(f)


class _IterableFromGenerator(TrackedIterable):
    """Iterable wrapping a generator function so iteration can be restarted.

    Each call to `__iter__` re-invokes the stored generator function with the
    stored arguments, yielding a fresh generator every time.
    """

    def __init__(self, generator: Callable, *args, **kwargs):
        super().__init__()
        # Keep attribute names stable: they are part of the tracking contract.
        self.generator = generator
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        for item in self.generator(*self.args, **self.kwargs):
            # Expose the most recent item for progress/tracking purposes
            self.last_item = item
            yield item
        self.last_item = None


class ArchiveIterable(_IterableFromGenerator):
    """An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""

    @staticmethod
    def _iter_tar(f):
        # Streaming mode ("r|*"): members are read sequentially, no seeking
        stream = tarfile.open(fileobj=f, mode="r|*")
        for tarinfo in stream:
            member_path = tarinfo.name
            # Only regular files with a usable name are yielded
            if not tarinfo.isreg() or member_path is None:
                continue
            if os.path.basename(member_path).startswith((".", "__")):
                continue  # skip hidden files
            yield member_path, stream.extractfile(tarinfo)
        # Drop references so extracted members can be garbage-collected
        stream.members = []
        del stream

    @staticmethod
    def _iter_zip(f):
        zipf = zipfile.ZipFile(f)
        for member in zipf.infolist():
            member_path = member.filename
            if member.is_dir() or member_path is None:
                continue
            if os.path.basename(member_path).startswith((".", "__")):
                continue  # skip hidden files
            yield member_path, zipf.open(member)

    @classmethod
    def _iter_from_fileobj(cls, f) -> Generator[Tuple, None, None]:
        # Sniff the magic number: zip gets random access, everything else is
        # treated as a (possibly compressed) tar stream
        if _get_extraction_protocol_with_magic_number(f) == "zip":
            yield from cls._iter_zip(f)
        else:
            yield from cls._iter_tar(f)

    @classmethod
    def _iter_from_path(cls, urlpath: str) -> Generator[Tuple, None, None]:
        compression = _get_extraction_protocol(urlpath)
        with open(urlpath, "rb") as f:
            if compression == "zip":
                yield from cls._iter_zip(f)
            else:
                yield from cls._iter_tar(f)

    @classmethod
    def from_buf(cls, fileobj) -> "ArchiveIterable":
        """Build an iterable over an already-open archive file object."""
        return cls(cls._iter_from_fileobj, fileobj)

    @classmethod
    def from_path(cls, urlpath_or_buf) -> "ArchiveIterable":
        """Build an iterable over an archive given its local path."""
        return cls(cls._iter_from_path, urlpath_or_buf)


class FilesIterable(_IterableFromGenerator):
    """An iterable of paths from a list of directories or files"""

    @classmethod
    def _iter_from_paths(cls, urlpaths: Union[str, List[str]]) -> Generator[str, None, None]:
        if not isinstance(urlpaths, list):
            urlpaths = [urlpaths]
        for urlpath in urlpaths:
            if os.path.isfile(urlpath):
                yield urlpath
                continue
            for dirpath, dirnames, filenames in os.walk(urlpath):
                # Prune hidden subdirectories in place so os.walk never descends
                # into them; sort for a deterministic traversal order.
                dirnames[:] = sorted(name for name in dirnames if not name.startswith((".", "__")))
                if os.path.basename(dirpath).startswith((".", "__")):
                    continue  # skip hidden directories
                for filename in sorted(filenames):
                    if filename.startswith((".", "__")):
                        continue  # skip hidden files
                    yield os.path.join(dirpath, filename)

    @classmethod
    def from_paths(cls, urlpaths) -> "FilesIterable":
        """Build an iterable of file paths from one path or a list of paths."""
        return cls(cls._iter_from_paths, urlpaths)


class DownloadManager:
is_streaming = False

Expand Down Expand Up @@ -530,7 +348,7 @@ def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]):
if hasattr(path_or_buf, "read"):
return ArchiveIterable.from_buf(path_or_buf)
else:
return ArchiveIterable.from_path(path_or_buf)
return ArchiveIterable.from_urlpath(path_or_buf)

def iter_files(self, paths: Union[str, List[str]]):
"""Iterate over file paths.
Expand All @@ -549,7 +367,7 @@ def iter_files(self, paths: Union[str, List[str]]):
>>> files = dl_manager.iter_files(files)
```
"""
return FilesIterable.from_paths(paths)
return FilesIterable.from_urlpaths(paths)

def extract(self, path_or_paths, num_proc="deprecated"):
"""Extract given path(s).
Expand Down
Loading
Loading