Skip to content

Commit

Permalink
to_pandas() with multy headers
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov committed Jul 12, 2024
1 parent ef2347f commit 483864e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Union,
)

import pandas as pd
import sqlalchemy

from datachain.lib.feature import Feature, FeatureType
Expand All @@ -36,7 +37,6 @@
from datachain.query.schema import Column, DatasetRow

if TYPE_CHECKING:
import pandas as pd
from typing_extensions import Self

from datachain.catalog import Catalog
Expand Down Expand Up @@ -739,6 +739,15 @@ def from_pandas( # type: ignore[override]

return cls.from_features(name, session, **fr_map)

def to_pandas(self, levels: bool = True) -> "pd.DataFrame":
if not levels:
return super().to_pandas()

headers = self.signals_schema.get_headers()
transposed_result = list(map(list, zip(*self.results())))
data = {tuple(n): val for n, val in zip(headers, transposed_result)}
return pd.DataFrame(data)

def parse_tabular(
self,
output: Optional[dict[str, FeatureType]] = None,
Expand Down
10 changes: 10 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,16 @@ def print_tree(self, indent: int = 4, start_at: int = 0):
sub_schema = SignalSchema({"* list of": args[0]})
sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

def get_headers(self):
paths = [
path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
]
max_length = max(len(path) for path in paths)
return [
path + [""] * (max_length - len(path)) if len(path) < max_length else path
for path in paths
]

@staticmethod
def _type_to_str(type_):
if get_origin(type_) == Union:
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,7 @@ def test_extend_features(catalog):

res = dc._extend_features("sum", "num")
assert res == sum(range(len(features)))


def test_to_pandas():
pass

0 comments on commit 483864e

Please sign in to comment.