Skip to content

Commit

Permalink
from_parquet: use virtual filesystem to preserve partition information
Browse files Browse the repository at this point in the history
Also:

* does a minor refactor,
* removes the `nrows` _hack_, and
* disables prefetching when `nrows` is set, so that we don't download
    the whole dataset.
  • Loading branch information
skshetry committed Dec 25, 2024
1 parent 60256d6 commit 3a7e7d2
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 62 deletions.
141 changes: 83 additions & 58 deletions src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional

import fsspec.implementations.reference
import orjson
import pyarrow as pa
from fsspec.core import split_protocol
from pyarrow.dataset import CsvFileFormat, dataset
from tqdm import tqdm

Expand All @@ -16,6 +18,7 @@

if TYPE_CHECKING:
from datasets.features.features import Features
from fsspec.spec import AbstractFileSystem
from pydantic import BaseModel

from datachain.lib.data_model import DataType
Expand All @@ -25,7 +28,17 @@
DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"


class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
    """Reference filesystem whose ``_open`` delegates to the target filesystem.

    Every entry looked up in ``self.references`` must hold exactly one URI.
    """

    def _open(self, path, mode="rb", *args, **kwargs):
        """Open the file that ``path`` refers to through its own filesystem.

        `fsspec`'s ``ReferenceFileSystem._open`` reads the whole file
        in-memory; this override forwards the call to the filesystem
        registered in ``self.fss`` for the target URI's protocol instead.
        """
        # Single-element unpacking: raises if the reference is not exactly
        # one URI.
        [target] = self.references[path]
        scheme = split_protocol(target)[0]
        backend = self.fss[scheme]
        return backend._open(target, mode, *args, **kwargs)


class ArrowGenerator(Generator):
DEFAULT_BATCH_SIZE = 2**17 # same as `pyarrow.dataset.DEFAULT_BATCH_SIZE`

def __init__(
self,
input_schema: Optional["pa.Schema"] = None,
Expand Down Expand Up @@ -53,59 +66,83 @@ def __init__(
self.kwargs = kwargs

def process(self, file: File):
fs: Optional[AbstractFileSystem]
if file._caching_enabled:
file.ensure_cached()
path = file.get_local_path()
ds = dataset(path, schema=self.input_schema, **self.kwargs)
elif self.nrows:
path = _nrows_file(file, self.nrows)
ds = dataset(path, schema=self.input_schema, **self.kwargs)
cache_path = file.get_local_path()
fs_path = file.path
fs = ReferenceFileSystem({fs_path: [cache_path]})
else:
path = file.get_path()
ds = dataset(
path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
)
fs, fs_path = file.get_fs(), file.get_path()

ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)

hf_schema = _get_hf_schema(ds.schema)
use_datachain_schema = (
bool(ds.schema.metadata)
and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
)
index = 0
with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
for record_batch in ds.to_batches():
for record in record_batch.to_pylist():
if use_datachain_schema and self.output_schema:
vals = [_nested_model_instantiate(record, self.output_schema)]
else:
vals = list(record.values())
if self.output_schema:
fields = self.output_schema.model_fields
vals_dict = {}
for i, ((field, field_info), val) in enumerate(
zip(fields.items(), vals)
):
anno = field_info.annotation
if hf_schema:
from datachain.lib.hf import convert_feature

feat = list(hf_schema[0].values())[i]
vals_dict[field] = convert_feature(val, feat, anno)
elif ModelStore.is_pydantic(anno):
vals_dict[field] = anno(**val) # type: ignore[misc]
else:
vals_dict[field] = val
vals = [self.output_schema(**vals_dict)]
if self.source:
kwargs: dict = self.kwargs
# Can't serialize CsvFileFormat; may lose formatting options.
if isinstance(kwargs.get("format"), CsvFileFormat):
kwargs["format"] = "csv"
arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
yield [arrow_file, *vals]
else:
yield vals
index += 1
pbar.update(len(record_batch))

kw = {}
if self.nrows:
kw = {"batch_size": min(self.DEFAULT_BATCH_SIZE, self.nrows)}

def iter_records():
for record in ds.to_batches(**kw):
yield from record.to_pylist()

it = islice(iter_records(), self.nrows)
with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
for index, record in enumerate(pbar):
yield self._process_record(
record, file, index, hf_schema, use_datachain_schema
)

def _process_record(
    self,
    record: dict[str, Any],
    file: File,
    index: int,
    hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
    use_datachain_schema: bool,
):
    """Convert one pyarrow record into the list of values to yield.

    Args:
        record: One row, as a ``column name -> value`` mapping.
        file: The file the row was read from.
        index: Zero-based position of the row within ``file``.
        hf_schema: Hugging Face schema information, or ``None``.
        use_datachain_schema: Whether the record should be instantiated
            directly into ``self.output_schema``.

    Returns:
        The converted values, preceded by an ``ArrowRow`` when
        ``self.source`` is set.
    """
    if use_datachain_schema and self.output_schema:
        vals = [_nested_model_instantiate(record, self.output_schema)]
    else:
        vals = self._process_non_datachain_record(record, hf_schema)

    if not self.source:
        return vals

    # Work on a shallow copy: `self.kwargs` is instance state, and replacing
    # "format" in place would overwrite the CsvFileFormat (and its options)
    # for every later use of `self.kwargs`.
    kwargs: dict = dict(self.kwargs)
    # Can't serialize CsvFileFormat; may lose formatting options.
    if isinstance(kwargs.get("format"), CsvFileFormat):
        kwargs["format"] = "csv"
    arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
    return [arrow_file, *vals]

def _process_non_datachain_record(
self,
record: dict[str, Any],
hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
):
vals = list(record.values())
if not self.output_schema:
return vals

fields = self.output_schema.model_fields
vals_dict = {}
for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
anno = field_info.annotation
if hf_schema:
from datachain.lib.hf import convert_feature

feat = list(hf_schema[0].values())[i]
vals_dict[field] = convert_feature(val, feat, anno)
elif ModelStore.is_pydantic(anno):
vals_dict[field] = anno(**val) # type: ignore[misc]
else:
vals_dict[field] = val
return [self.output_schema(**vals_dict)]


def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
Expand Down Expand Up @@ -190,18 +227,6 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa:
raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")


def _nrows_file(file: File, nrows: int) -> str:
tf = NamedTemporaryFile(delete=False) # noqa: SIM115
with file.open(mode="r") as reader:
with open(tf.name, "a") as writer:
for row, line in enumerate(reader):
if row >= nrows:
break
writer.write(line)
writer.write("\n")
return tf.name


def _get_hf_schema(
schema: "pa.Schema",
) -> Optional[tuple["Features", dict[str, "DataType"]]]:
Expand Down
11 changes: 7 additions & 4 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def jmespath_to_name(s: str):
nrows=nrows,
)
}
return chain.gen(**signal_dict) # type: ignore[misc, arg-type]
# disable prefetch if nrows is set
settings = {"prefetch": 0} if nrows else {}
return chain.settings(**settings).gen(**signal_dict) # type: ignore[misc, arg-type]

def explode(
self,
Expand Down Expand Up @@ -1894,7 +1896,10 @@ def parse_tabular(

if source:
output = {"source": ArrowRow} | output # type: ignore[assignment,operator]
return self.gen(

# disable prefetch if nrows is set
settings = {"prefetch": 0} if nrows else {}
return self.settings(**settings).gen(
ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
)

Expand Down Expand Up @@ -1976,8 +1981,6 @@ def from_csv(
else:
msg = f"error parsing csv - incompatible output type {type(output)}"
raise DatasetPrepareError(chain.name, msg)
elif nrows:
nrows += 1

parse_options = ParseOptions(delimiter=delimiter)
read_options = ReadOptions(column_names=column_names)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/lib/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,32 @@ def test_arrow_generator_hf(tmp_path, catalog):
assert isinstance(obj[1].col, HFClassLabel)


@pytest.mark.parametrize("cache", [True, False])
def test_arrow_generator_partitioned(tmp_path, catalog, cache):
    """Rows read from a hive-partitioned dataset keep their partition column.

    The table is written partitioned by ``first_name``; each resulting
    parquet file is then processed on its own, with and without caching, and
    the single row it yields must equal one of the original rows, including
    ``first_name``.
    """
    pq_path = tmp_path / "parquets"
    pylist = [
        {"first_name": "Alice", "age": 25, "city": "New York"},
        {"first_name": "Bob", "age": 30, "city": "Los Angeles"},
        {"first_name": "Charlie", "age": 35, "city": "Chicago"},
    ]
    table = pa.Table.from_pylist(pylist)
    pq.write_to_dataset(table, pq_path, partition_cols=["first_name"])

    output, original_names = schema_to_output(table.schema)
    output_schema = dict_to_data_model("", output, original_names)
    func = ArrowGenerator(
        table.schema, output_schema=output_schema, partitioning="hive"
    )

    parquet_files = sorted(pq_path.rglob("*.parquet"))
    # Without this the loop below could run zero times and the test would
    # pass without checking anything.
    assert parquet_files

    for path in parquet_files:
        stream = File(path=path.as_posix(), source="file://")
        stream._set_stream(catalog, caching_enabled=cache)

        (o,) = list(func.process(stream))
        assert isinstance(o[0], ArrowRow)
        assert dict(o[1]) in pylist


@pytest.mark.parametrize(
"col_type,expected",
(
Expand Down

0 comments on commit 3a7e7d2

Please sign in to comment.