Skip to content

Commit

Permalink
Add dt unary POC
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Apr 29, 2024
1 parent 10aa053 commit 17f05ca
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
33 changes: 27 additions & 6 deletions src/awkward_pandas/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@
import pyarrow.compute as pc


def _run_unary(layout, op, kind=None, **kw):
if layout.is_record:
[_run_unary(_, op, kind=kind, **kw) for _ in layout._contents]
elif layout.is_leaf and (kind is None or layout.dtype.kind == kind):
layout._data = ak.str._apply_through_arrow(op, layout, **kw).data
elif layout.is_option or layout.is_list:
_run_unary(layout.content, op, kind=kind, **kw)


def run_unary(arr: ak.Array, op, kind=None, **kw) -> ak.Array:
arr2 = ak.copy(arr)
_run_unary(arr2.layout, op, kind=kind, **kw)
return ak.Array(arr2)


class DatetimeAccessor:
def __init__(self, accessor) -> None:
self.accessor = accessor
Expand All @@ -19,15 +34,21 @@ def cast(self, target_type=None, safe=None, options=None):
>>> import pandas as pd
>>> import awkward_pandas.pandas
>>> s = pd.Series([0, 1, 2])
>>> s = pd.Series([[0, 1], [1, 0], [2]])
>>> s.ak.dt.cast("timestamp[s]")
0 1970-01-01 00:00:00
1 1970-01-01 00:00:01
2 1970-01-01 00:00:02
dtype: timestamp[s][pyarrow]
0 ['1970-01-01T00:00:00' '1970-01-01T00:00:01']
1 ['1970-01-01T00:00:01' '1970-01-01T00:00:00']
2 ['1970-01-01T00:00:02']
dtype: list<item: timestamp[s]>[pyarrow]
"""
return self.accessor.to_output(
pc.cast(self.accessor.arrow, target_type, safe, options)
run_unary(
self.accessor.array,
pc.cast,
target_type=target_type,
safe=safe,
options=options,
)
)

def ceil_temporal(
Expand Down
5 changes: 4 additions & 1 deletion src/awkward_pandas/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def encode(self, encoding: str = "utf-8"):
return self.accessor.to_output(encode(self.accessor.array, encoding=encoding))

def decode(self, encoding: str = "utf-8"):
"""Decode Series of bytes to Series of strings. Leaves non-bytestrings alone."""
"""Decode Series of bytes to Series of strings. Leaves non-bytestrings alone.
Validity of UTF8 is *not* checked.
"""
return self.accessor.to_output(decode(self.accessor.array, encoding=encoding))

@staticmethod
Expand Down
18 changes: 18 additions & 0 deletions tests/test_dt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import datetime

import pytest

import awkward_pandas.pandas # noqa

pd = pytest.importorskip("pandas")


def test_cast():
s = pd.Series([[0, 1], [1, 0], [2]])
out = s.ak.dt.cast("timestamp[s]")
assert str(out.dtype) == "list<item: timestamp[s]>[pyarrow]"
assert out.to_list() == [
[datetime.datetime(1970, 1, 1, 0, 0), datetime.datetime(1970, 1, 1, 0, 0, 1)],
[datetime.datetime(1970, 1, 1, 0, 0, 1), datetime.datetime(1970, 1, 1, 0, 0)],
[datetime.datetime(1970, 1, 1, 0, 0, 2)],
]

0 comments on commit 17f05ca

Please sign in to comment.