Add datachain.lib.tar.process_tar() generator #440

Merged 1 commit on Sep 15, 2024
33 changes: 33 additions & 0 deletions src/datachain/lib/tar.py
@@ -0,0 +1,33 @@
+import hashlib
+import tarfile
+from collections.abc import Iterator
+
+from datachain.lib.file import File, TarVFile
+
+
+def build_tar_member(parent: File, info: tarfile.TarInfo) -> File:
+    new_parent = parent.get_full_name()
+    etag_string = "-".join([parent.etag, info.name, str(info.mtime)])
+    etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
+    return File(
+        source=parent.source,
+        path=f"{new_parent}/{info.name}",
+        version=parent.version,
+        size=info.size,
+        etag=etag,
+        location=[
+            {
+                "vtype": TarVFile.get_vtype(),
+                "parent": parent.model_dump_custom(),
+                "size": info.size,
+                "offset": info.offset_data,
+            }
+        ],
+    )
+
+
+def process_tar(file: File) -> Iterator[File]:
+    with file.open() as fd:
+        with tarfile.open(fileobj=fd) as tar:
+            for entry in tar.getmembers():
+                yield build_tar_member(file, entry)
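
The `location` entry records each member's byte `offset` and `size` inside the parent archive, which is what lets `TarVFile` serve a member read later without re-walking the tar. A minimal stdlib-only sketch of that idea (the archive and its contents are made up here, and direct seeking like this only works for uncompressed archives):

import io
import tarfile

# Build a tiny in-memory tar so the sketch is self-contained.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    payload = b"meow"
    info = tarfile.TarInfo(name="cats/cat1")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))

# Walk the members once, as process_tar() does, noting offset/size.
buf.seek(0)
with tarfile.open(fileobj=buf) as tar:
    member = tar.getmembers()[0]
    offset, size = member.offset_data, member.size

# A later read can seek straight to the payload, no tarfile needed.
buf.seek(offset)
assert buf.read(size) == b"meow"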
62 changes: 3 additions & 59 deletions src/datachain/lib/webdataset.py
@@ -1,4 +1,3 @@
-import hashlib
 import json
 import tarfile
 import warnings
@@ -17,7 +16,8 @@
 from pydantic import Field
 
 from datachain.lib.data_model import DataModel
-from datachain.lib.file import File, TarVFile
+from datachain.lib.file import File
+from datachain.lib.tar import build_tar_member
 from datachain.lib.utils import DataChainError
 
 # The `json` method of the Pydantic `BaseModel` class has been deprecated
@@ -176,34 +176,11 @@ def produce(self):
                 self._tar_stream, self._core_extensions, self.state.stem
             )
 
-        file = self.build_file_record()
+        file = build_tar_member(self._tar_stream, self.state.core_file)
         wds = self._wds_class(**self.state.data | {"file": file})
         self.state = BuilderState()
         return wds
 
-    def build_file_record(self):
-        new_parent = self._tar_stream.get_full_name()
-        core_file = self.state.core_file
-        etag_string = "-".join(
-            [self._tar_stream.etag, core_file.name, str(core_file.mtime)]
-        )
-        etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
-        return File(
-            source=self._tar_stream.source,
-            path=f"{new_parent}/{core_file.name}",
-            version=self._tar_stream.version,
-            size=core_file.size,
-            etag=etag,
-            location=[
-                {
-                    "vtype": TarVFile.get_vtype(),
-                    "parent": self._tar_stream.model_dump_custom(),
-                    "size": core_file.size,
-                    "offset": core_file.offset_data,
-                }
-            ],
-        )
-
     def _get_type(self, ext):
         field = self._wds_class.model_fields.get(ext, None)
         if field is None:
@@ -217,39 +194,6 @@ def _get_type(self, ext):
         return anno
 
 
-class TarStream(File):
-    @staticmethod
-    def to_text(data):
-        return data.decode("utf-8")
-
-    _DATA_CONVERTERS: ClassVar[dict[type, Any]] = {
-        str: lambda data: TarStream.to_text(data),
-        int: lambda data: int(TarStream.to_text(data)),
-        float: lambda data: float(TarStream.to_text(data)),
-        bytes: lambda data: data,
-        dict: lambda data: json.loads(TarStream.to_text(data)),
-    }
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._tar = None
-
-    def open(self):
-        self._tar = tarfile.open(fileobj=super().open())  # noqa: SIM115
-        return self
-
-    def getmembers(self) -> list[tarfile.TarInfo]:
-        return self._tar.getmembers()
-
-    def read_member(self, member: tarfile.TarInfo, type):
-        fd = self._tar.extractfile(member)
-        data = fd.read()
-        converter = self._DATA_CONVERTERS.get(type, None)
-        if not converter:
-            raise ValueError("")
-        return converter(data)
-
-
 def get_tar_groups(stream, tar, core_extensions, spec, encoding="utf-8"):
     builder = Builder(stream, core_extensions, spec, tar, encoding)
 
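
With `build_tar_member` extracted into `datachain.lib.tar`, the webdataset builder no longer carries its own copy of the member-record logic, and the `TarStream` wrapper is dropped in favor of plain `tarfile` handles. The member etag is derived from the parent archive's etag plus the member's name and mtime; a small sketch of that scheme (all values here are hypothetical):

import hashlib

parent_etag = "d41d8cd9"  # hypothetical etag of the enclosing tar
name, mtime = "cats/cat1", 1726000000
etag_string = "-".join([parent_etag, name, str(mtime)])
etag = hashlib.md5(etag_string.encode(), usedforsecurity=False).hexdigest()
# Any change to the parent archive, member name, or mtime yields a new etag.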
18 changes: 17 additions & 1 deletion tests/func/test_datachain.py
@@ -19,9 +19,10 @@
     is_listing_dataset,
     parse_listing_uri,
 )
+from datachain.lib.tar import process_tar
 from datachain.lib.udf import Mapper
 from datachain.lib.utils import DataChainError
-from tests.utils import images_equal
+from tests.utils import TARRED_TREE, images_equal
 
 
 def _get_listing_datasets(session):
@@ -624,3 +625,18 @@ def name_len_interrupt(_name):
     captured = capfd.readouterr()
     assert "KeyboardInterrupt" in captured.err
     assert "semaphore" not in captured.err
+
+
+@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
+def test_process_and_open_tar(cloud_test_catalog):
+    ctc = cloud_test_catalog
+    dc = (
+        DataChain.from_storage(ctc.src_uri, session=ctc.session)
+        .gen(file=process_tar)
+        .filter(C("file.path").glob("*/cats/*"))
+    )
+    assert dc.count() == 2
+    assert {(file.read(), file.name) for file in dc.collect("file")} == {
+        (b"meow", "cat1"),
+        (b"mrow", "cat2"),
+    }
19 changes: 0 additions & 19 deletions tests/func/test_dataset_query.py
@@ -2582,25 +2582,6 @@ def test_checksum_udf(cloud_test_catalog, dogs_dataset):
     assert len(result) == 4
 
 
-@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
-def test_tar_loader(cloud_test_catalog):
-    ctc = cloud_test_catalog
-    catalog = ctc.catalog
-    catalog.index([ctc.src_uri])
-    catalog.create_dataset_from_sources("animals", [ctc.src_uri])
-    q = DatasetQuery(name="animals", version=1, catalog=catalog).generate(index_tar)
-    q.save("extracted")
-
-    q = DatasetQuery(name="extracted", catalog=catalog).filter(C.path.glob("*/cats/*"))
-    assert len(q.db_results()) == 2
-
-    ds = q.extract(Object(to_str), "path")
-    assert {(value, posixpath.basename(path)) for value, path in ds} == {
-        ("meow", "cat1"),
-        ("mrow", "cat2"),
-    }
-
-
 @pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
 def test_simple_dataset_query(cloud_test_catalog):
     ctc = cloud_test_catalog
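
For reference, the removed DatasetQuery/index_tar flow maps roughly onto the new API as in the sketch below. This is not part of the diff; the URI is a placeholder and the import paths are assumptions based on the tests above:

from datachain.lib.dc import C, DataChain
from datachain.lib.tar import process_tar

src_uri = "s3://my-bucket/animals"  # placeholder source URI

extracted = (
    DataChain.from_storage(src_uri)           # index the bucket
    .gen(file=process_tar)                    # expand each tar into member Files
    .filter(C("file.path").glob("*/cats/*"))  # same glob as the old test
    .save("extracted")                        # persist, as q.save() did
)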