Skip to content

Commit

Permalink
To pandas - hierarchical multi header (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov authored Jul 17, 2024
1 parent 095952c commit 00c846a
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 109 deletions.
5 changes: 1 addition & 4 deletions examples/llm-claude-aggregate-query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import anthropic
import pandas as pd
from anthropic.types import Message

from datachain import Column, DataChain
Expand Down Expand Up @@ -55,6 +54,4 @@
)
)

with pd.option_context("display.max_columns", None):
df = chain.to_pandas()
print(df)
chain.show()
5 changes: 1 addition & 4 deletions examples/llm-claude-simple-query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os

import anthropic
import pandas as pd
from anthropic.types import Message
from pydantic import BaseModel

Expand Down Expand Up @@ -62,6 +61,4 @@ class Rating(BaseModel):
)
)

with pd.option_context("display.max_columns", None):
df = chain.to_pandas()
print(df)
chain.show()
5 changes: 1 addition & 4 deletions examples/llm-claude.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import anthropic
import pandas as pd
from anthropic.types import Message

from datachain import Column, DataChain, File
Expand Down Expand Up @@ -37,6 +36,4 @@
)
)

with pd.option_context("display.max_columns", None):
df = chain.to_pandas()
print(df)
chain.show()
34 changes: 33 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Union,
)

import pandas as pd
import sqlalchemy
from pydantic import BaseModel, create_model

Expand Down Expand Up @@ -38,9 +39,9 @@
detach,
)
from datachain.query.schema import Column, DatasetRow
from datachain.utils import inside_notebook

if TYPE_CHECKING:
import pandas as pd
from typing_extensions import Self

C = Column
Expand Down Expand Up @@ -731,6 +732,37 @@ def from_pandas( # type: ignore[override]

return cls.from_values(name, session, object_name=object_name, **fr_map)

def to_pandas(self, flatten=False) -> "pd.DataFrame":
headers, max_length = self.signals_schema.get_headers_with_length()
if flatten or max_length < 2:
df = pd.DataFrame.from_records(self.to_records())
if headers:
df.columns = [".".join(filter(None, header)) for header in headers]
return df

transposed_result = list(map(list, zip(*self.results())))
data = {tuple(n): val for n, val in zip(headers, transposed_result)}
return pd.DataFrame(data)

def show(self, limit: int = 20, flatten=False, transpose=False) -> None:
dc = self.limit(limit) if limit > 0 else self
df = dc.to_pandas(flatten)
if transpose:
df = df.T

with pd.option_context(
"display.max_columns", None, "display.multi_sparse", False
):
if inside_notebook():
from IPython.display import display

display(df)
else:
print(df)

if len(df) == limit:
print(f"\n[Limited by {len(df)} rows]")

def parse_tabular(
self,
output: OutputType = None,
Expand Down
10 changes: 10 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,16 @@ def print_tree(self, indent: int = 4, start_at: int = 0):
sub_schema = SignalSchema({"* list of": args[0]})
sub_schema.print_tree(indent=indent, start_at=total_indent + indent)

def get_headers_with_length(self):
paths = [
path for path, _, has_subtree, _ in self.get_flat_tree() if not has_subtree
]
max_length = max([len(path) for path in paths], default=0)
return [
path + [""] * (max_length - len(path)) if len(path) < max_length else path
for path in paths
], max_length

def __or__(self, other):
return self.__class__(self.values | other.values)

Expand Down
26 changes: 1 addition & 25 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)

import attrs
import pandas as pd
import sqlalchemy
from attrs import frozen
from dill import dumps, source
Expand All @@ -53,10 +52,9 @@
from datachain.dataset import DatasetStatus, RowDict
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
from datachain.progress import CombinedDownloadCallback
from datachain.query.schema import DEFAULT_DELIMITER
from datachain.sql.functions import rand
from datachain.storage import Storage, StorageURI
from datachain.utils import batched, determine_processes, inside_notebook
from datachain.utils import batched, determine_processes

from .metrics import metrics
from .schema import C, UDFParamSpec, normalize_param
Expand Down Expand Up @@ -1346,12 +1344,6 @@ async def get_params(row: RowDict) -> tuple:
def to_records(self) -> list[dict[str, Any]]:
return self.results(lambda cols, row: dict(zip(cols, row)))

def to_pandas(self) -> "pd.DataFrame":
records = self.to_records()
df = pd.DataFrame.from_records(records)
df.columns = [c.replace(DEFAULT_DELIMITER, ".") for c in df.columns]
return df

def shuffle(self) -> "Self":
# ToDo: implement shaffle based on seed and/or generating random column
return self.order_by(C.sys__rand)
Expand All @@ -1370,22 +1362,6 @@ def sample(self, n) -> "Self":

return sampled.limit(n)

def show(self, limit=20) -> None:
df = self.limit(limit).to_pandas()

options = ["display.max_colwidth", 50, "display.show_dimensions", False]
with pd.option_context(*options):
if inside_notebook():
from IPython.display import display

display(df)

else:
print(df.to_string())

if len(df) == limit:
print(f"[limited by {limit} objects]")

def clone(self, new_table=True) -> "Self":
obj = copy(self)
obj.steps = obj.steps.copy()
Expand Down
13 changes: 0 additions & 13 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3213,19 +3213,6 @@ def test_to_records(simple_ds_query):
assert simple_ds_query.to_records() == SIMPLE_DS_QUERY_RECORDS


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
indirect=True,
)
def test_to_pandas(simple_ds_query):
import pandas as pd

df = simple_ds_query.to_pandas()
expected = pd.DataFrame.from_records(SIMPLE_DS_QUERY_RECORDS)
assert (df == expected).all(axis=None)


@pytest.mark.parametrize("method", ["to_records", "extract"])
@pytest.mark.parametrize("save", [True, False])
@pytest.mark.parametrize(
Expand Down
103 changes: 54 additions & 49 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,8 @@ def test_from_features(catalog):
params="parent",
output={"file": File, "t1": MyFr},
)
df1 = ds.to_pandas()

assert df1[["t1.nnn", "t1.count"]].equals(
pd.DataFrame({"t1.nnn": ["n1", "n2", "n1"], "t1.count": [3, 5, 1]})
)
for i, (_, t1) in enumerate(ds.iterate()):
assert t1 == features[i]


def test_preserve_feature_schema(catalog):
Expand Down Expand Up @@ -212,33 +209,33 @@ class _TestFr(BaseModel):
params="t1",
output={"x": _TestFr},
)
# assert ds.collect() == 1

df = ds.to_pandas()
for i, (x,) in enumerate(ds.iterate()):
assert isinstance(x, _TestFr)

assert df["x.my_name"].tolist() == ["n1", "n2", "n1"]
assert np.allclose(df["x.sqrt"], [math.sqrt(x) for x in [3, 5, 1]])
with pytest.raises(KeyError):
df["x.t1.nnn"]
fr = features[i]
test_fr = _TestFr(file=File(name=""), sqrt=math.sqrt(fr.count), my_name=fr.nnn)
assert x == test_fr


def test_map(catalog):
class _TestFr(BaseModel):
sqrt: float
my_name: str

ds = DataChain.from_values(t1=features)

df = ds.map(
dc = DataChain.from_values(t1=features).map(
x=lambda m_fr: _TestFr(
sqrt=math.sqrt(m_fr.count),
my_name=m_fr.nnn + "_suf",
),
params="t1",
output={"x": _TestFr},
).to_pandas()
)

assert df["x.my_name"].tolist() == ["n1_suf", "n2_suf", "n1_suf"]
assert np.allclose(df["x.sqrt"], [math.sqrt(x) for x in [3, 5, 1]])
assert dc.collect_one("x") == [
_TestFr(sqrt=math.sqrt(fr.count), my_name=fr.nnn + "_suf") for fr in features
]


def test_agg(catalog):
Expand All @@ -247,26 +244,31 @@ class _TestFr(BaseModel):
cnt: int
my_name: str

df = (
DataChain.from_values(t1=features)
.agg(
x=lambda frs: [
_TestFr(
f=File(name=""),
cnt=sum(f.count for f in frs),
my_name="-".join([fr.nnn for fr in frs]),
)
],
partition_by=C.t1.nnn,
params="t1",
output={"x": _TestFr},
)
.to_pandas()
dc = DataChain.from_values(t1=features).agg(
x=lambda frs: [
_TestFr(
f=File(name=""),
cnt=sum(f.count for f in frs),
my_name="-".join([fr.nnn for fr in frs]),
)
],
partition_by=C.t1.nnn,
params="t1",
output={"x": _TestFr},
)

assert len(df) == 2
assert df["x.my_name"].tolist() == ["n1-n1", "n2"]
assert df["x.cnt"].tolist() == [4, 5]
assert dc.collect_one("x") == [
_TestFr(
f=File(name=""),
cnt=sum(fr.count for fr in features if fr.nnn == "n1"),
my_name="-".join([fr.nnn for fr in features if fr.nnn == "n1"]),
),
_TestFr(
f=File(name=""),
cnt=sum(fr.count for fr in features if fr.nnn == "n2"),
my_name="-".join([fr.nnn for fr in features if fr.nnn == "n2"]),
),
]


def test_agg_two_params(catalog):
Expand Down Expand Up @@ -294,10 +296,8 @@ class _TestFr(BaseModel):
output={"x": _TestFr},
)

df = ds.to_pandas()
assert len(df) == 2
assert df["x.my_name"].tolist() == ["n1-n1", "n2"]
assert df["x.cnt"].tolist() == [12, 15]
assert ds.collect_one("x.my_name") == ["n1-n1", "n2"]
assert ds.collect_one("x.cnt") == [12, 15]


def test_agg_simple_iterator(catalog):
Expand Down Expand Up @@ -356,10 +356,8 @@ def func(key, val) -> Iterator[tuple[File, _ImageGroup]]:
values = [1, 5, 9]
ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key"))

df = ds.to_pandas()
assert len(df) == 2
assert df["x_1.name"].tolist() == ["n1-n1", "n2"]
assert df["x_1.size"].tolist() == [10, 5]
assert ds.collect_one("x_1.name") == ["n1-n1", "n2"]
assert ds.collect_one("x_1.size") == [10, 5]


def test_agg_tuple_result_generator(catalog):
Expand All @@ -376,10 +374,8 @@ def func(key, val) -> Generator[tuple[File, _ImageGroup], None, None]:
values = [1, 5, 9]
ds = DataChain.from_values(key=keys, val=values).agg(x=func, partition_by=C("key"))

df = ds.to_pandas()
assert len(df) == 2
assert df["x_1.name"].tolist() == ["n1-n1", "n2"]
assert df["x_1.size"].tolist() == [10, 5]
assert ds.collect_one("x_1.name") == ["n1-n1", "n2"]
assert ds.collect_one("x_1.size") == [10, 5]


def test_iterate(catalog):
Expand Down Expand Up @@ -829,15 +825,15 @@ def test_from_features_object_name(tmp_dir, catalog):
values = ["odd" if num % 2 else "even" for num in fib]

dc = DataChain.from_values(fib=fib, odds=values, object_name="custom")
assert "custom.fib" in dc.to_pandas().columns
assert "custom.fib" in dc.to_pandas(flatten=True).columns


def test_parse_tabular_object_name(tmp_dir, catalog):
df = pd.DataFrame(DF_DATA)
path = tmp_dir / "test.parquet"
df.to_parquet(path)
dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="name")
assert "name.first_name" in dc.to_pandas().columns
dc = DataChain.from_storage(path.as_uri()).parse_tabular(object_name="tbl")
assert "tbl.first_name" in dc.to_pandas(flatten=True).columns


def test_sys_feature(tmp_dir, catalog):
Expand Down Expand Up @@ -868,3 +864,12 @@ def test_sys_feature(tmp_dir, catalog):
MyFr(nnn="n1", count=1),
]
assert "sys" not in ds_no_sys.catalog.get_dataset("ds_no_sys").feature_schema


def test_to_pandas_multi_level():
df = DataChain.from_values(t1=features).to_pandas()

assert "t1" in df.columns
assert "nnn" in df["t1"].columns
assert "count" in df["t1"].columns
assert df["t1"]["count"].tolist() == [3, 5, 1]
Loading

0 comments on commit 00c846a

Please sign in to comment.