Skip to content

Commit

Permalink
Merge pull request #58 from martindurant/dask_meta
Browse files Browse the repository at this point in the history
Dask meta via typetracers
  • Loading branch information
martindurant authored Jun 12, 2024
2 parents b1605bf + 74f4248 commit eb4ff65
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
27 changes: 24 additions & 3 deletions src/awkward_pandas/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,33 @@ class DaskAwkwardAccessor(AkAccessor):
dataframe_type = dd.DataFrame
aggregations = False # you need dask-awkward for that

@staticmethod
def _to_tt(data):
# self._obj._meta.convert_dtypes(dtype_backend="pyarrow")
data = data._meta if hasattr(data, "_meta") else data
arr = PandasAwkwardAccessor.to_arrow(data)
return ak.to_backend(ak.from_arrow(arr), "typetracer")

@classmethod
def _create_op(cls, op):
def run(self, *args, **kwargs):
orig = self._obj.head()
ar = (ar.head() if hasattr(ar, "ak") else ar for ar in args)
meta = PandasAwkwardAccessor._to_output(op(orig.ak.array, *ar, **kwargs))
try:
tt = self._to_tt(self._obj)
ar = (
ak.to_backend(ar) if isinstance(ar, (ak.Array, ak.Record)) else ar
for ar in args
)
ar = [self._to_tt(ar) if hasattr(ar, "ak") else ar for ar in ar]
out = op(tt, *ar, **kwargs)
meta = PandasAwkwardAccessor._to_output(
ak.typetracer.length_zero_if_typetracer(out)
)
except (ValueError, TypeError):
# could make our own fallback as follows, but dask will guess anyway
# orig = self._obj.head()
# ar = (ar.head() if hasattr(ar, "ak") else ar for ar in args)
# meta = PandasAwkwardAccessor._to_output(op(orig.ak.array, *ar, **kwargs))
meta = None

def inner(data):
import awkward_pandas.pandas # noqa: F401
Expand Down
7 changes: 7 additions & 0 deletions src/awkward_pandas/pandas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import awkward as ak
import pandas as pd
import pyarrow
import pyarrow as pa

from awkward_pandas.mixin import Accessor
Expand All @@ -13,6 +14,12 @@ class PandasAwkwardAccessor(Accessor):

@classmethod
def to_arrow(cls, data):
if isinstance(data, ak.Array):
return ak.to_arrow(data)
if isinstance(data, ak.Record):
return ak.to_arrow_table(data)
if isinstance(data, (pyarrow.Array, pyarrow.Table)):
return data
if cls.is_series(data):
return pa.array(data)
return pa.table(data)
Expand Down

0 comments on commit eb4ff65

Please sign in to comment.