From 69597f59bb3aa9fe6a926fbad95c5cf6881e2433 Mon Sep 17 00:00:00 2001
From: skshetry <18718008+skshetry@users.noreply.github.com>
Date: Sat, 20 Jul 2024 06:29:13 +0545
Subject: [PATCH] datachain: implement to_parquet (#97)

First part of #91, critical for the release.

It uses `to_pandas()`, so it cannot write a parquet file larger than what
fits in memory.
---
 src/datachain/lib/dc.py          | 21 +++++++++++++++++++++
 tests/unit/lib/test_datachain.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index bf03a303e..766719de6 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1,10 +1,12 @@
 import copy
+import os
 import re
 from collections.abc import Iterable, Iterator, Sequence
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    BinaryIO,
     Callable,
     ClassVar,
     Literal,
@@ -1178,6 +1180,25 @@ def from_parquet(
             partitioning=partitioning,
         )
 
+    def to_parquet(
+        self,
+        path: Union[str, os.PathLike[str], BinaryIO],
+        partition_cols: Optional[Sequence[str]] = None,
+        **kwargs,
+    ) -> None:
+        """Save chain to parquet file.
+
+        Parameters:
+            path : Path or a file-like binary object to save the file.
+            partition_cols : Column names by which to partition the dataset.
+        """
+        _partition_cols = list(partition_cols) if partition_cols else None
+        return self.to_pandas().to_parquet(
+            path,
+            partition_cols=_partition_cols,
+            **kwargs,
+        )
+
     @classmethod
     def create_empty(
         cls,
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 1b064fbb4..1d3f6c937 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -817,6 +817,34 @@ def test_from_parquet_partitioned(tmp_dir, catalog):
     assert df1.equals(df)
 
 
+def test_to_parquet(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    dc = DataChain.from_pandas(df)
+
+    path = tmp_dir / "test.parquet"
+    dc.to_parquet(path)
+
+    assert path.is_file()
+    pd.testing.assert_frame_equal(pd.read_parquet(path), df)
+
+
+def test_to_parquet_partitioned(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    dc = DataChain.from_pandas(df)
+
+    path = tmp_dir / "parquets"
+    dc.to_parquet(path, partition_cols=["first_name"])
+
+    assert set(path.iterdir()) == {
+        path / f"first_name={name}" for name in df["first_name"]
+    }
+    df1 = pd.read_parquet(path)
+    df1 = df1.reindex(columns=df.columns)
+    df1["first_name"] = df1["first_name"].astype("str")
+    df1 = df1.sort_values("first_name").reset_index(drop=True)
+    pd.testing.assert_frame_equal(df1, df)
+
+
 @pytest.mark.parametrize("processes", [False, 2, True])
 def test_parallel(processes, catalog):
     prefix = "t & "
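
For context, a minimal usage sketch of the new method (not part of the patch). It assumes `DataChain` is importable from `datachain.lib.dc`, where the patch adds `to_parquet`; the column data below is made up for illustration and stands in for the tests' `DF_DATA`.

```python
import pandas as pd

from datachain.lib.dc import DataChain

# Hypothetical toy data, analogous to DF_DATA in the tests.
df = pd.DataFrame(
    {
        "first_name": ["Alice", "Bob", "Charlie"],
        "age": [40, 30, 25],
    }
)

dc = DataChain.from_pandas(df)

# Write the whole chain to a single parquet file. Because to_parquet()
# goes through to_pandas(), the chain is materialized in memory first.
dc.to_parquet("people.parquet")

# Write a partitioned dataset instead: one "first_name=<value>" directory
# per distinct value. Extra keyword arguments are forwarded to
# pandas.DataFrame.to_parquet.
dc.to_parquet("people_parquet/", partition_cols=["first_name"])
```

As the test `test_to_parquet_partitioned` shows, partition columns come back from `pd.read_parquet` as categoricals in a different column order, which is why it reindexes, casts, and sorts before comparing frames.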