from_parquet: fix partition keys extraction when caching is enabled
When caching is enabled, the dataset is opened using the path to the
cache. This means that the original partition keys get lost and need to
be re-derived from the original path.

This commit adds a helper function that extracts the partition keys from the
original file path and updates each record with them.

I could not find a better way than this _hacky_ solution. Feel free to
suggest a better way to handle this.
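
A minimal sketch of the idea, outside the actual change (the path and schema
below are made up for illustration): pyarrow's hive partitioning can parse the
`key=value` segments of the original path into an expression, and
`get_partition_keys` turns that expression into a plain dict.

```python
import pyarrow as pa
from pyarrow.dataset import get_partition_keys, partitioning

# Hypothetical partition schema; only the partition column matters here.
part = partitioning(schema=pa.schema([("first_name", pa.string())]), flavor="hive")

# The cached local path no longer contains the hive directories, so the
# keys have to be parsed from the original path instead.
expr = part.parse("parquets/first_name=Alice/part-0.parquet")
print(get_partition_keys(expr))  # {'first_name': 'Alice'}
```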

I have added a test. Without this fix, the test fails: `first_name` is missing
from the record and defaults to `None`. In real `DataChain` code, this would
violate the schema and raise a pydantic `ValidationError`.
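
For context, a hypothetical pipeline in which the bug would surface (the bucket
path and chaining are illustrative; check the exact `from_parquet`/`settings`
signatures for your DataChain version):

```python
from datachain import DataChain

# Hive-partitioned layout, e.g. .../parquets/first_name=Alice/part-0.parquet.
# With caching enabled, the parquet files are read from the local cache,
# which is where the partition keys used to get lost before this fix.
chain = DataChain.from_parquet(
    "s3://bucket/parquets/", partitioning="hive"
).settings(cache=True)
```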
skshetry committed Dec 24, 2024
1 parent 1b34bc0 commit 1862bd0
Showing 2 changed files with 54 additions and 1 deletion.
src/datachain/lib/arrow.py (28 additions & 1 deletion)
@@ -4,7 +4,7 @@
 
 import orjson
 import pyarrow as pa
-from pyarrow.dataset import CsvFileFormat, dataset
+from pyarrow.dataset import CsvFileFormat, dataset, get_partition_keys, partitioning
 from tqdm import tqdm
 
 from datachain.lib.data_model import dict_to_data_model
@@ -25,6 +25,23 @@
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+def _get_partition_keys_from_path(path, schema, flavor) -> dict[str, str]:
+    if not flavor:
+        return {}
+
+    if isinstance(flavor, str):
+        _partitioning = partitioning(schema=schema, flavor=flavor)
+    else:
+        _partitioning = flavor
+
+    try:
+        partitioning_expr = _partitioning.parse(path)
+    except AttributeError:
+        return {}
+
+    return get_partition_keys(partitioning_expr)
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -53,10 +70,19 @@ def __init__(
         self.kwargs = kwargs
 
     def process(self, file: File):
+        partition_keys = {}
         if file._caching_enabled:
            file.ensure_cached()
            path = file.get_local_path()
            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            flavor = self.kwargs.get("partitioning")
+            # Extract partition keys from the file's original path.
+            # Since the dataset is opened using the cached local path,
+            # the original partition keys may not be preserved and
+            # need to be re-derived.
+            partition_keys = _get_partition_keys_from_path(
+                file.get_path(), self.input_schema, flavor
+            )
         elif self.nrows:
            path = _nrows_file(file, self.nrows)
            ds = dataset(path, schema=self.input_schema, **self.kwargs)
@@ -74,6 +100,7 @@ def process(self, file: File):
         with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
             for record_batch in ds.to_batches():
                 for record in record_batch.to_pylist():
+                    record.update(partition_keys)
                     if use_datachain_schema and self.output_schema:
                         vals = [_nested_model_instantiate(record, self.output_schema)]
                     else:
tests/unit/lib/test_arrow.py (26 additions & 0 deletions)
@@ -106,6 +106,32 @@ def test_arrow_generator_hf(tmp_path, catalog):
     assert isinstance(obj[1].col, HFClassLabel)
 
 
+@pytest.mark.parametrize("cache", [True, False])
+def test_arrow_generator_partitioned(tmp_path, catalog, cache):
+    pq_path = tmp_path / "parquets"
+    pylist = [
+        {"first_name": "Alice", "age": 25, "city": "New York"},
+        {"first_name": "Bob", "age": 30, "city": "Los Angeles"},
+        {"first_name": "Charlie", "age": 35, "city": "Chicago"},
+    ]
+    table = pa.Table.from_pylist(pylist)
+    pq.write_to_dataset(table, pq_path, partition_cols=["first_name"])
+
+    output, original_names = schema_to_output(table.schema)
+    output_schema = dict_to_data_model("", output, original_names)
+    func = ArrowGenerator(
+        table.schema, output_schema=output_schema, partitioning="hive"
+    )
+
+    for path in pq_path.rglob("*.parquet"):
+        stream = File(path=path.as_posix(), source="file://")
+        stream._set_stream(catalog, caching_enabled=cache)
+
+        (o,) = list(func.process(stream))
+        assert isinstance(o[0], ArrowRow)
+        assert dict(o[1]) in pylist
+
+
 @pytest.mark.parametrize(
     "col_type,expected",
     (
