From 483864e67bf564e4d70f571009dc94f8dbe0f6fa Mon Sep 17 00:00:00 2001 From: Dmitry Petrov Date: Thu, 11 Jul 2024 21:29:34 -0700 Subject: [PATCH] to_pandas() with multi headers --- src/datachain/lib/dc.py | 11 ++++++++++- src/datachain/lib/signal_schema.py | 10 ++++++++++ tests/unit/lib/test_datachain.py | 4 ++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 5316419c2..1340289ac 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -10,6 +10,7 @@ Union, ) +import pandas as pd import sqlalchemy from datachain.lib.feature import Feature, FeatureType @@ -36,7 +37,6 @@ from datachain.query.schema import Column, DatasetRow if TYPE_CHECKING: - import pandas as pd from typing_extensions import Self from datachain.catalog import Catalog @@ -739,6 +739,15 @@ def from_pandas( # type: ignore[override] return cls.from_features(name, session, **fr_map) + def to_pandas(self, levels: bool = True) -> "pd.DataFrame": + if not levels: + return super().to_pandas() + + headers = self.signals_schema.get_headers() + transposed_result = list(map(list, zip(*self.results()))) + data = {tuple(n): val for n, val in zip(headers, transposed_result)} + return pd.DataFrame(data) + def parse_tabular( self, output: Optional[dict[str, FeatureType]] = None, diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 1d3670576..2de1445d4 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -309,6 +309,16 @@ def print_tree(self, indent: int = 4, start_at: int = 0): sub_schema = SignalSchema({"* list of": args[0]}) sub_schema.print_tree(indent=indent, start_at=total_indent + indent) + def get_headers(self): + paths = [ + path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree + ] + max_length = max(len(path) for path in paths) + return [ + path + [""] * (max_length - len(path)) if len(path) < max_length else path + for path in 
paths + ] + @staticmethod def _type_to_str(type_): if get_origin(type_) == Union: diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 15a5ae96c..95b05df88 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -784,3 +784,7 @@ def test_extend_features(catalog): res = dc._extend_features("sum", "num") assert res == sum(range(len(features))) + + +def test_to_pandas(): + pass