Commit 785dfab

from_parquet: use virtual filesystem to preserve partition information
skshetry committed Dec 24, 2024
1 parent 60256d6 commit 785dfab
Showing 2 changed files with 49 additions and 10 deletions.
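
The gist of the change: when a parquet file is read through a path that no longer carries its hive partition directories (for example, a local cache copy), pyarrow cannot recover partition columns such as first_name from the path, so they silently disappear. The commit keeps the original, partition-carrying path visible to pyarrow and maps it to the actual bytes with an fsspec reference filesystem. A minimal sketch of the idea, with made-up paths and using the stock fsspec class (the commit subclasses it, as shown in the diff below):

import pyarrow.dataset as pa_ds
from fsspec.implementations.reference import ReferenceFileSystem

original_path = "bucket/users/first_name=Alice/part-0.parquet"  # hypothetical remote path
cached_copy = "/tmp/datachain-cache/ab/cdef0123"                 # hypothetical local cache copy

# Present the file under its original path, but serve the bytes from the cache copy.
fs = ReferenceFileSystem({original_path: [cached_copy]})
ds = pa_ds.dataset(original_path, filesystem=fs, partitioning="hive")
# The "first_name=Alice" segment is still part of the path pyarrow sees, so the
# partition column survives even though every read goes to the local file.
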
33 changes: 23 additions & 10 deletions src/datachain/lib/arrow.py
@@ -2,8 +2,10 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional
 
+import fsspec.implementations.reference
 import orjson
 import pyarrow as pa
+from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
 
@@ -16,15 +18,23 @@
 
 if TYPE_CHECKING:
     from datasets.features.features import Features
+    from fsspec.spec import AbstractFileSystem
     from pydantic import BaseModel
 
     from datachain.lib.data_model import DataType
     from datachain.lib.dc import DataChain
 
 
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
+    def _open(self, path, mode="rb", *args, **kwargs):
+        # `fsspec`'s `ReferenceFileSystem._open` reads the whole file in-memory.
+        (uri,) = self.references[path]
+        protocol, _ = split_protocol(uri)
+        return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -53,18 +63,21 @@ def __init__(
         self.kwargs = kwargs
 
     def process(self, file: File):
+        fs: Optional[AbstractFileSystem]
         if file._caching_enabled:
            file.ensure_cached()
-            path = file.get_local_path()
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
-        elif self.nrows:
-            path = _nrows_file(file, self.nrows)
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            cache_path = file.get_local_path()
+            fs_path = file.path
+            fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-            path = file.get_path()
-            ds = dataset(
-                path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
-            )
+            fs, fs_path = file.get_fs(), file.get_path()
+
+        if self.nrows:
+            f = _nrows_file(file, self.nrows)
+            fs = ReferenceFileSystem({fs_path: [f]})
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
 
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
             bool(ds.schema.metadata)
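
Why the _open override above matters: the stock ReferenceFileSystem._open materializes the whole referenced file in memory, while a parquet reader only needs random access to the footer and to the column chunks it actually touches. Delegating to the target filesystem keeps reads lazy. A small sketch of the behavior, assuming the import path from this commit and a hypothetical cache mapping:

from datachain.lib.arrow import ReferenceFileSystem

# Hypothetical mapping: original (partitioned) path -> local cache copy.
refs = {"data/first_name=Alice/part-0.parquet": ["/tmp/cache/obj"]}
fs = ReferenceFileSystem(refs)

with fs.open("data/first_name=Alice/part-0.parquet") as f:
    f.seek(-4, 2)                # random access: jump straight to the parquet magic
    assert f.read(4) == b"PAR1"  # no full in-memory copy of the file was made first
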
26 changes: 26 additions & 0 deletions tests/unit/lib/test_arrow.py
@@ -106,6 +106,32 @@ def test_arrow_generator_hf(tmp_path, catalog):
         assert isinstance(obj[1].col, HFClassLabel)
 
 
+@pytest.mark.parametrize("cache", [True, False])
+def test_arrow_generator_partitioned(tmp_path, catalog, cache):
+    pq_path = tmp_path / "parquets"
+    pylist = [
+        {"first_name": "Alice", "age": 25, "city": "New York"},
+        {"first_name": "Bob", "age": 30, "city": "Los Angeles"},
+        {"first_name": "Charlie", "age": 35, "city": "Chicago"},
+    ]
+    table = pa.Table.from_pylist(pylist)
+    pq.write_to_dataset(table, pq_path, partition_cols=["first_name"])
+
+    output, original_names = schema_to_output(table.schema)
+    output_schema = dict_to_data_model("", output, original_names)
+    func = ArrowGenerator(
+        table.schema, output_schema=output_schema, partitioning="hive"
+    )
+
+    for path in pq_path.rglob("*.parquet"):
+        stream = File(path=path.as_posix(), source="file://")
+        stream._set_stream(catalog, caching_enabled=cache)
+
+        (o,) = list(func.process(stream))
+        assert isinstance(o[0], ArrowRow)
+        assert dict(o[1]) in pylist
+
+
 @pytest.mark.parametrize(
     "col_type,expected",
     (
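
For context on what the new test exercises: pq.write_to_dataset(..., partition_cols=["first_name"]) moves the partition column out of the parquet files and into directory names, so the only place first_name survives is the path itself. A quick way to inspect that layout (a sketch, assuming it runs from the test's tmp_path after the dataset has been written):

import pathlib

import pyarrow.parquet as pq

for p in sorted(pathlib.Path("parquets").rglob("*.parquet")):
    print(p)                        # e.g. parquets/first_name=Alice/<uuid>.parquet
    print(pq.read_schema(p).names)  # ['age', 'city']; first_name lives only in the path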
