Skip to content

Commit

Permalink
Merge pull request #71 from martindurant/apply_where
Browse files Browse the repository at this point in the history
Add where= to apply()
  • Loading branch information
martindurant authored Jul 25, 2024
2 parents 8e4e1af + c185181 commit c79a272
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 12 deletions.
16 changes: 8 additions & 8 deletions src/akimbo/apply_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ def f(self, *args, where=None, **kwargs):

f.__doc__ = """Run vectorized functions on nested/ragged/complex array
where: None | str | Sequence[str, ...]
if None, will attempt to apply the kernel throughout the nested structure,
wherever correct types are encountered. If where is given, only the selected
part of the structure will be considered, but the output will retain
the original shape. A fieldname or sequence of fieldnames to descend into
the tree are acceptable
where: None | str | Sequence[str, ...]
if None, will attempt to apply the kernel throughout the nested structure,
wherever correct types are encountered. If where is given, only the selected
part of the structure will be considered, but the output will retain
the original shape. A fieldname or sequence of fieldnames to descend into
the tree are acceptable
Kernel documentation follows from the original function
-Kernel documentation follows from the original function-
===
===
""" + (
f.__doc__ or str(f)
)
Expand Down
8 changes: 7 additions & 1 deletion src/akimbo/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class DatetimeAccessor:
def __init__(self, accessor) -> None:
self.accessor = accessor

cast = dec(pc.cast)
# listed below https://arrow.apache.org/docs/python/generated/
# pyarrow.compute.ceil_temporal.html
cast = dec(pc.cast) # TODO: move to .ak
ceil_temporal = dec_t(pc.ceil_temporal)
floor_temporal = dec_t(pc.floor_temporal)
round_temporal = dec_t(pc.round_temporal)
# original (misspelled) attribute name, kept as an alias so that existing
# callers of ``reound_temporal`` keep working
reound_temporal = round_temporal
Expand Down Expand Up @@ -60,6 +62,10 @@ def __init__(self, accessor) -> None:
weeks_between = dec_t(pc.weeks_between)
years_between = dec_t(pc.years_between)

# TODO: strftime, strptime

# TODO: timezone conversion


def _to_arrow(array):
array = _make_unit_compatible(array)
Expand Down
26 changes: 23 additions & 3 deletions src/akimbo/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def is_dataframe(cls, data):

@classmethod
def _to_output(cls, data):
    """Hook called by ``to_output`` to produce the final result from ``data``.

    This version always raises ``NotImplementedError`` (presumably it is
    overridden by backend-specific subclasses -- TODO confirm).
    """
    # TODO: clarify protocol here; can data be in arrow already?
    raise NotImplementedError

def to_output(self, data=None):
Expand All @@ -158,13 +159,21 @@ def to_output(self, data=None):
return data
return self._to_output(data)

def apply(self, fn: Callable, where=None, **kwargs):
    """Perform arbitrary function on all the values of the series

    The function should take an ak array as input and produce an
    ak array or scalar.

    Parameters
    ----------
    fn: callable
        Applied to the awkward array, or to the selected part of it.
    where: None | str | Sequence[str]
        If not given, ``fn`` is applied to the whole array. Otherwise the
        field to operate on, as a dot-separated string (``"a.b"``) or a
        sequence of field names to descend into. Only that field is passed
        to ``fn`` and the result is put back in the same place, so the
        remaining fields of the input are kept.
    kwargs:
        Passed through to ``fn`` in both cases.
    """
    if not where:
        return self.to_output(fn(self.array, **kwargs))

    # Normalise to a tuple of field names. The same path is used both to
    # select the input and to place the output; passing a dotted string to
    # ak.with_field would create a field literally named "a.b".
    path = tuple(where.split(".")) if isinstance(where, str) else tuple(where)
    arr = self.array
    out = fn(arr[path], **kwargs)
    return self.to_output(ak.with_field(arr, out, where=path))

def __getitem__(self, item):
out = self.array.__getitem__(item)
Expand All @@ -175,8 +184,18 @@ def __dir__(self) -> Iterable[str]:
meths = series_methods if self.is_series(self._obj) else df_methods
return sorted(set(attrs) | set(meths))

def with_behavior(self, behavior):
def with_behavior(self, behavior, where=()):
"""Assign a behavior to this array-of-records"""
# TODO: compare usage with sub-accessors
# TODO: implement where= (assign directly to ._parameters["__record__"]
# of output's layout). In this case, behavior is a dict of locations to apply to,
# and we can continually add to it (or accept a dict)
# beh = self._behavior.copy()
# if isinstance(behavior, dict):
# beh.update(behavior)
# else:
# # str or type
# beh[where] = behaviour
return type(self)(self._obj, behavior)

with_name = with_behavior # alias - this is the upstream name
Expand Down Expand Up @@ -208,6 +227,7 @@ def array(self) -> ak.Array:

@classmethod
def register_accessor(cls, name, klass):
    """Register ``klass`` as a sub-accessor under ``name``.

    The class is stored in ``cls.subaccessors[name]``. Nothing here checks
    whether ``name`` is already present, so (assuming ``subaccessors`` is a
    plain mapping) an existing entry is replaced.
    """
    # TODO: check clobber?
    cls.subaccessors[name] = klass

def merge(self):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,30 @@ def test_apply():
assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]]


def test_apply_where():
    """apply(..., where="a") changes only field "a"; "b" is left alone."""
    records = [{"a": [1, 2, 3], "b": [1, 2, 3]} for _ in range(3)]
    series = pl.Series(records)
    negated = series.ak.apply(np.negative, where="a")
    first = negated[0]
    assert first == {"a": [-1, -2, -3], "b": [1, 2, 3]}


def test_merge_unmerge():
    """unmerge() exposes field "a" as a column and merge() round-trips it."""
    records = [{"a": [1, 2, 3], "b": [1, 2, 3]} for _ in range(3)]
    series = pl.Series(records)

    frame = series.ak.unmerge()
    assert frame["a"].to_list() == [[1, 2, 3]] * 3

    roundtrip = frame.ak.merge()
    assert series.to_list() == roundtrip.to_list()


def test_operator():
s = pl.Series([[1, 2, 3], [], [4, 5]])
s2 = s.ak + 1
Expand Down

0 comments on commit c79a272

Please sign in to comment.