IndexedFile -> ArrowRow (#445)
* indexedfile -> arrowvfile

* ArrowVFile -> ArrowRow
Dave Berenbaum authored Sep 20, 2024
1 parent 44c6a6f commit a925653
Showing 7 changed files with 47 additions and 25 deletions.
4 changes: 2 additions & 2 deletions docs/references/file.md
@@ -7,6 +7,8 @@ automatically when creating a `DataChain` from files, like in
classes include various metadata fields about the underlying file as well as methods to
read from the files and otherwise work with the file contents.

::: datachain.lib.file.ArrowRow

::: datachain.lib.file.ExportPlacement

::: datachain.lib.file.File
@@ -15,8 +17,6 @@ read from the files and otherwise work with the file contents.

::: datachain.lib.file.ImageFile

::: datachain.lib.file.IndexedFile

::: datachain.lib.file.TarVFile

::: datachain.lib.file.TextFile
4 changes: 2 additions & 2 deletions src/datachain/__init__.py
@@ -1,10 +1,10 @@
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
    ArrowRow,
    File,
    FileError,
    ImageFile,
    IndexedFile,
    TarVFile,
    TextFile,
)
@@ -16,6 +16,7 @@
__all__ = [
"AbstractUDF",
"Aggregator",
"ArrowRow",
"C",
"Column",
"DataChain",
@@ -26,7 +27,6 @@
"FileError",
"Generator",
"ImageFile",
"IndexedFile",
"Mapper",
"ModelStore",
"Session",
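For downstream users, the whole of this hunk reduces to one import rename. A minimal before/after sketch (the commented line shows the removed name; nothing else about the import path changes):

from datachain import ArrowRow  # new top-level export added by this commit
# from datachain import IndexedFile  # removed; this import now raises ImportError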
11 changes: 8 additions & 3 deletions src/datachain/lib/arrow.py
@@ -4,11 +4,11 @@
from typing import TYPE_CHECKING, Optional

import pyarrow as pa
from pyarrow.dataset import dataset
from pyarrow.dataset import CsvFileFormat, dataset
from tqdm import tqdm

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import File, IndexedFile
from datachain.lib.file import ArrowRow, File
from datachain.lib.model_store import ModelStore
from datachain.lib.udf import Generator

@@ -84,7 +84,12 @@ def process(self, file: File):
                vals_dict[field] = val
            vals = [self.output_schema(**vals_dict)]
            if self.source:
                yield [IndexedFile(file=file, index=index), *vals]
                kwargs: dict = self.kwargs
                # Can't serialize CsvFileFormat; may lose formatting options.
                if isinstance(kwargs.get("format"), CsvFileFormat):
                    kwargs["format"] = "csv"
                arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
                yield [arrow_file, *vals]
            else:
                yield vals
            index += 1
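The format downgrade in the branch above exists because a pyarrow CsvFileFormat instance cannot be serialized into ArrowRow.kwargs, so only the format name survives (per the in-code comment, custom CSV options are lost). A standalone sketch of just that branch, with illustrative variable names:

import pyarrow.dataset as pa_ds

kwargs = {"format": pa_ds.CsvFileFormat()}  # rich format object with parse/convert options
if isinstance(kwargs.get("format"), pa_ds.CsvFileFormat):
    kwargs["format"] = "csv"  # keep only the serializable name; CSV parse options are dropped
print(kwargs)  # {'format': 'csv'}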
4 changes: 2 additions & 2 deletions src/datachain/lib/dc.py
@@ -26,8 +26,8 @@
from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ArrowRow, File, get_file_type
from datachain.lib.file import ExportPlacement as FileExportPlacement
from datachain.lib.file import File, IndexedFile, get_file_type
from datachain.lib.listing import (
is_listing_dataset,
is_listing_expired,
@@ -1614,7 +1614,7 @@ def parse_tabular(
                for name, info in output.model_fields.items()
            }
        if source:
            output = {"source": IndexedFile} | output  # type: ignore[assignment,operator]
            output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
        return self.gen(
            ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
        )
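With this change, parse_tabular() attaches per-row provenance as a "source" signal typed as ArrowRow rather than IndexedFile. A hedged usage sketch; the bucket path is a placeholder and the call assumes the public from_storage/parse_tabular API:

from datachain import DataChain

chain = (
    DataChain.from_storage("s3://example-bucket/tables/")  # hypothetical location
    .parse_tabular(format="parquet")  # extra kwargs are forwarded to pyarrow.dataset
)
# With source=True (the default), each generated row carries a "source" ArrowRow
# alongside the parsed columns; source.read() can re-fetch that single record.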
28 changes: 23 additions & 5 deletions src/datachain/lib/file.py
@@ -17,6 +17,7 @@

from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from PIL import Image
from pyarrow.dataset import dataset
from pydantic import Field, field_validator

if TYPE_CHECKING:
@@ -439,14 +440,31 @@ def save(self, destination: str):
        self.read().save(destination)


class IndexedFile(DataModel):
    """Metadata indexed from tabular files.
    Includes `file` and `index` signals.
    """
class ArrowRow(DataModel):
    """`DataModel` for reading row from Arrow-supported file."""

    file: File
    index: int
    kwargs: dict

    @contextmanager
    def open(self):
        """Stream row contents from indexed file."""
        if self.file._caching_enabled:
            self.file.ensure_cached()
            path = self.file.get_local_path()
            ds = dataset(path, **self.kwargs)

        else:
            path = self.file.get_path()
            ds = dataset(path, filesystem=self.file.get_fs(), **self.kwargs)

        return ds.take([self.index]).to_reader()

    def read(self):
        """Returns row contents as dict."""
        with self.open() as record_batch:
            return record_batch.to_pylist()[0]


def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
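As a rough illustration of the new model on its own, the sketch below reads one record back through ArrowRow.read(). The path, index, and standalone construction are hypothetical; in the pipeline these objects are emitted by ArrowGenerator, and the File normally comes from a catalog-backed listing:

from datachain.lib.file import ArrowRow, File

source_file = File(path="data/example.parquet")  # hypothetical, locally readable file
row = ArrowRow(file=source_file, index=2, kwargs={"format": "parquet"})
print(row.read())  # dict of column values for row index 2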
17 changes: 8 additions & 9 deletions tests/unit/lib/test_arrow.py
@@ -13,7 +13,7 @@
    schema_to_output,
)
from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import File, IndexedFile
from datachain.lib.file import ArrowRow, File
from datachain.lib.hf import HFClassLabel


@@ -33,10 +33,11 @@ def test_arrow_generator(tmp_path, catalog, cache):
    objs = list(func.process(stream))

    assert len(objs) == len(ids)
    for index, (o, id, text) in enumerate(zip(objs, ids, texts)):
        assert isinstance(o[0], IndexedFile)
        assert isinstance(o[0].file, File)
        assert o[0].index == index
    for o, id, text in zip(objs, ids, texts):
        assert isinstance(o[0], ArrowRow)
        file_vals = o[0].read()
        assert file_vals["id"] == id
        assert file_vals["text"] == text
        assert o[1] == id
        assert o[2] == text

@@ -78,10 +79,8 @@ def test_arrow_generator_output_schema(tmp_path, catalog):
    objs = list(func.process(stream))

    assert len(objs) == len(ids)
    for index, (o, id, text, dict) in enumerate(zip(objs, ids, texts, dicts)):
        assert isinstance(o[0], IndexedFile)
        assert isinstance(o[0].file, File)
        assert o[0].index == index
    for o, id, text, dict in zip(objs, ids, texts, dicts):
        assert isinstance(o[0], ArrowRow)
        assert o[1].id == id
        assert o[1].text == text
        assert o[1].dict.a == dict["a"]
4 changes: 2 additions & 2 deletions tests/unit/test_module_exports.py
@@ -13,6 +13,7 @@ def test_module_exports():
    from datachain import (
        AbstractUDF,
        Aggregator,
        ArrowRow,
        C,
        Column,
        DataChain,
@@ -22,7 +23,6 @@
        FileError,
        Generator,
        ImageFile,
        IndexedFile,
        Mapper,
        Session,
        TarVFile,
@@ -67,6 +67,7 @@ def monkey_import_importerror(
    from datachain import (
        AbstractUDF,
        Aggregator,
        ArrowRow,
        C,
        Column,
        DataChain,
@@ -76,7 +77,6 @@
        FileError,
        Generator,
        ImageFile,
        IndexedFile,
        Mapper,
        Session,
        TarVFile,
